Diffstat (limited to 'rust')
151 files changed, 5237 insertions, 4188 deletions
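Most of the helper changes in the diff below follow a single pattern: the existing `rust_helper_*` wrappers gain a `__rust_helper` annotation, and `rust/helpers/helpers.c` now defines that marker (currently to nothing, presumably so a future inline-helpers build can expand it differently) before including the per-subsystem helper files. As a reading aid, here is a condensed sketch of that pattern, not a complete file, assembled from the `rust/helpers/helpers.c` and `rust/helpers/mutex.c` hunks in this diff:

/* Sketch only: condensed from the helpers.c and mutex.c hunks below. */

/* rust/helpers/helpers.c: define the marker before pulling in the helpers. */
#define __rust_helper

#include <linux/mutex.h>

/*
 * rust/helpers/mutex.c: each wrapper exposes a C interface that bindgen
 * cannot bind directly (macros, inline functions) as an ordinary symbol
 * callable from Rust.
 */
__rust_helper void rust_helper_mutex_lock(struct mutex *lock)
{
	mutex_lock(lock);
}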
diff --git a/rust/Makefile b/rust/Makefile index 5d357dce1704..757974551359 100644 --- a/rust/Makefile +++ b/rust/Makefile @@ -117,6 +117,23 @@ syn-flags := \ --extern quote \ $(call cfgs-to-flags,$(syn-cfgs)) +pin_init_internal-cfgs := \ + kernel + +pin_init_internal-flags := \ + --extern proc_macro2 \ + --extern quote \ + --extern syn \ + $(call cfgs-to-flags,$(pin_init_internal-cfgs)) + +pin_init-cfgs := \ + kernel + +pin_init-flags := \ + --extern pin_init_internal \ + --extern macros \ + $(call cfgs-to-flags,$(pin_init-cfgs)) + # `rustdoc` did not save the target modifiers, thus workaround for # the time being (https://github.com/rust-lang/rust/issues/144521). rustdoc_modifiers_workaround := $(if $(call rustc-min-version,108800),-Cunsafe-allow-abi-mismatch=fixed-x18) @@ -211,15 +228,15 @@ rustdoc-ffi: $(src)/ffi.rs rustdoc-core FORCE +$(call if_changed,rustdoc) rustdoc-pin_init_internal: private rustdoc_host = yes -rustdoc-pin_init_internal: private rustc_target_flags = --cfg kernel \ +rustdoc-pin_init_internal: private rustc_target_flags = $(pin_init_internal-flags) \ --extern proc_macro --crate-type proc-macro rustdoc-pin_init_internal: $(src)/pin-init/internal/src/lib.rs \ - rustdoc-clean FORCE + rustdoc-clean rustdoc-proc_macro2 rustdoc-quote rustdoc-syn FORCE +$(call if_changed,rustdoc) rustdoc-pin_init: private rustdoc_host = yes -rustdoc-pin_init: private rustc_target_flags = --extern pin_init_internal \ - --extern macros --extern alloc --cfg kernel --cfg feature=\"alloc\" +rustdoc-pin_init: private rustc_target_flags = $(pin_init-flags) \ + --extern alloc --cfg feature=\"alloc\" rustdoc-pin_init: $(src)/pin-init/src/lib.rs rustdoc-pin_init_internal \ rustdoc-macros FORCE +$(call if_changed,rustdoc) @@ -272,14 +289,14 @@ rusttestlib-macros: $(src)/macros/lib.rs \ rusttestlib-proc_macro2 rusttestlib-quote rusttestlib-syn FORCE +$(call if_changed,rustc_test_library) -rusttestlib-pin_init_internal: private rustc_target_flags = --cfg kernel \ +rusttestlib-pin_init_internal: private rustc_target_flags = $(pin_init_internal-flags) \ --extern proc_macro rusttestlib-pin_init_internal: private rustc_test_library_proc = yes -rusttestlib-pin_init_internal: $(src)/pin-init/internal/src/lib.rs FORCE +rusttestlib-pin_init_internal: $(src)/pin-init/internal/src/lib.rs \ + rusttestlib-proc_macro2 rusttestlib-quote rusttestlib-syn FORCE +$(call if_changed,rustc_test_library) -rusttestlib-pin_init: private rustc_target_flags = --extern pin_init_internal \ - --extern macros --cfg kernel +rusttestlib-pin_init: private rustc_target_flags = $(pin_init-flags) rusttestlib-pin_init: $(src)/pin-init/src/lib.rs rusttestlib-macros \ rusttestlib-pin_init_internal $(obj)/$(libpin_init_internal_name) FORCE +$(call if_changed,rustc_test_library) @@ -383,6 +400,7 @@ bindgen_skip_c_flags := -mno-fp-ret-in-387 -mpreferred-stack-boundary=% \ -fno-inline-functions-called-once -fsanitize=bounds-strict \ -fstrict-flex-arrays=% -fmin-function-alignment=% \ -fzero-init-padding-bits=% -mno-fdpic \ + -fdiagnostics-show-context -fdiagnostics-show-context=% \ --param=% --param asan-% -fno-isolate-erroneous-paths-dereference # Derived from `scripts/Makefile.clang`. 
@@ -547,8 +565,9 @@ $(obj)/$(libmacros_name): $(src)/macros/lib.rs $(obj)/libproc_macro2.rlib \ $(obj)/libquote.rlib $(obj)/libsyn.rlib FORCE +$(call if_changed_dep,rustc_procmacro) -$(obj)/$(libpin_init_internal_name): private rustc_target_flags = --cfg kernel -$(obj)/$(libpin_init_internal_name): $(src)/pin-init/internal/src/lib.rs FORCE +$(obj)/$(libpin_init_internal_name): private rustc_target_flags = $(pin_init_internal-flags) +$(obj)/$(libpin_init_internal_name): $(src)/pin-init/internal/src/lib.rs \ + $(obj)/libproc_macro2.rlib $(obj)/libquote.rlib $(obj)/libsyn.rlib FORCE +$(call if_changed_dep,rustc_procmacro) quiet_cmd_rustc_library = $(if $(skip_clippy),RUSTC,$(RUSTC_OR_CLIPPY_QUIET)) L $@ @@ -642,8 +661,7 @@ $(obj)/compiler_builtins.o: $(src)/compiler_builtins.rs $(obj)/core.o FORCE +$(call if_changed_rule,rustc_library) $(obj)/pin_init.o: private skip_gendwarfksyms = 1 -$(obj)/pin_init.o: private rustc_target_flags = --extern pin_init_internal \ - --extern macros --cfg kernel +$(obj)/pin_init.o: private rustc_target_flags = $(pin_init-flags) $(obj)/pin_init.o: $(src)/pin-init/src/lib.rs $(obj)/compiler_builtins.o \ $(obj)/$(libpin_init_internal_name) $(obj)/$(libmacros_name) FORCE +$(call if_changed_rule,rustc_library) diff --git a/rust/bindings/bindings_helper.h b/rust/bindings/bindings_helper.h index a067038b4b42..083cc44aa952 100644 --- a/rust/bindings/bindings_helper.h +++ b/rust/bindings/bindings_helper.h @@ -56,9 +56,10 @@ #include <linux/fdtable.h> #include <linux/file.h> #include <linux/firmware.h> -#include <linux/interrupt.h> #include <linux/fs.h> #include <linux/i2c.h> +#include <linux/interrupt.h> +#include <linux/io-pgtable.h> #include <linux/ioport.h> #include <linux/jiffies.h> #include <linux/jump_label.h> @@ -80,6 +81,7 @@ #include <linux/sched.h> #include <linux/security.h> #include <linux/slab.h> +#include <linux/sys_soc.h> #include <linux/task_work.h> #include <linux/tracepoint.h> #include <linux/usb.h> diff --git a/rust/bindings/lib.rs b/rust/bindings/lib.rs index 0c57cf9b4004..19f57c5b2fa2 100644 --- a/rust/bindings/lib.rs +++ b/rust/bindings/lib.rs @@ -67,3 +67,16 @@ mod bindings_helper { } pub use bindings_raw::*; + +pub const compat_ptr_ioctl: Option< + unsafe extern "C" fn(*mut file, ffi::c_uint, ffi::c_ulong) -> ffi::c_long, +> = { + #[cfg(CONFIG_COMPAT)] + { + Some(bindings_raw::compat_ptr_ioctl) + } + #[cfg(not(CONFIG_COMPAT))] + { + None + } +}; diff --git a/rust/helpers/atomic.c b/rust/helpers/atomic.c index cf06b7ef9a1c..4b24eceef5fc 100644 --- a/rust/helpers/atomic.c +++ b/rust/helpers/atomic.c @@ -11,11 +11,6 @@ #include <linux/atomic.h> -// TODO: Remove this after INLINE_HELPERS support is added. 
-#ifndef __rust_helper
-#define __rust_helper
-#endif
-
 __rust_helper int rust_helper_atomic_read(const atomic_t *v)
 {
@@ -1037,4 +1032,4 @@ rust_helper_atomic64_dec_if_positive(atomic64_t *v)
 }
 #endif /* _RUST_ATOMIC_API_H */
 
-// 615a0e0c98b5973a47fe4fa65e92935051ca00ed
+// e4edb6174dd42a265284958f00a7cea7ddb464b1
diff --git a/rust/helpers/atomic_ext.c b/rust/helpers/atomic_ext.c
new file mode 100644
index 000000000000..7d0c2bd340da
--- /dev/null
+++ b/rust/helpers/atomic_ext.c
@@ -0,0 +1,139 @@
+// SPDX-License-Identifier: GPL-2.0
+
+#include <asm/barrier.h>
+#include <asm/rwonce.h>
+#include <linux/atomic.h>
+
+__rust_helper s8 rust_helper_atomic_i8_read(s8 *ptr)
+{
+	return READ_ONCE(*ptr);
+}
+
+__rust_helper s8 rust_helper_atomic_i8_read_acquire(s8 *ptr)
+{
+	return smp_load_acquire(ptr);
+}
+
+__rust_helper s16 rust_helper_atomic_i16_read(s16 *ptr)
+{
+	return READ_ONCE(*ptr);
+}
+
+__rust_helper s16 rust_helper_atomic_i16_read_acquire(s16 *ptr)
+{
+	return smp_load_acquire(ptr);
+}
+
+__rust_helper void rust_helper_atomic_i8_set(s8 *ptr, s8 val)
+{
+	WRITE_ONCE(*ptr, val);
+}
+
+__rust_helper void rust_helper_atomic_i8_set_release(s8 *ptr, s8 val)
+{
+	smp_store_release(ptr, val);
+}
+
+__rust_helper void rust_helper_atomic_i16_set(s16 *ptr, s16 val)
+{
+	WRITE_ONCE(*ptr, val);
+}
+
+__rust_helper void rust_helper_atomic_i16_set_release(s16 *ptr, s16 val)
+{
+	smp_store_release(ptr, val);
+}
+
+/*
+ * xchg helpers depend on ARCH_SUPPORTS_ATOMIC_RMW and on the
+ * architecture providing xchg() support for i8 and i16.
+ *
+ * The architectures that currently support Rust (x86_64, armv7,
+ * arm64, riscv, and loongarch) satisfy these requirements.
+ */
+__rust_helper s8 rust_helper_atomic_i8_xchg(s8 *ptr, s8 new)
+{
+	return xchg(ptr, new);
+}
+
+__rust_helper s16 rust_helper_atomic_i16_xchg(s16 *ptr, s16 new)
+{
+	return xchg(ptr, new);
+}
+
+__rust_helper s8 rust_helper_atomic_i8_xchg_acquire(s8 *ptr, s8 new)
+{
+	return xchg_acquire(ptr, new);
+}
+
+__rust_helper s16 rust_helper_atomic_i16_xchg_acquire(s16 *ptr, s16 new)
+{
+	return xchg_acquire(ptr, new);
+}
+
+__rust_helper s8 rust_helper_atomic_i8_xchg_release(s8 *ptr, s8 new)
+{
+	return xchg_release(ptr, new);
+}
+
+__rust_helper s16 rust_helper_atomic_i16_xchg_release(s16 *ptr, s16 new)
+{
+	return xchg_release(ptr, new);
+}
+
+__rust_helper s8 rust_helper_atomic_i8_xchg_relaxed(s8 *ptr, s8 new)
+{
+	return xchg_relaxed(ptr, new);
+}
+
+__rust_helper s16 rust_helper_atomic_i16_xchg_relaxed(s16 *ptr, s16 new)
+{
+	return xchg_relaxed(ptr, new);
+}
+
+/*
+ * try_cmpxchg helpers depend on ARCH_SUPPORTS_ATOMIC_RMW and on the
+ * architecture providing try_cmpxchg() support for i8 and i16.
+ *
+ * The architectures that currently support Rust (x86_64, armv7,
+ * arm64, riscv, and loongarch) satisfy these requirements.
+ */ +__rust_helper bool rust_helper_atomic_i8_try_cmpxchg(s8 *ptr, s8 *old, s8 new) +{ + return try_cmpxchg(ptr, old, new); +} + +__rust_helper bool rust_helper_atomic_i16_try_cmpxchg(s16 *ptr, s16 *old, s16 new) +{ + return try_cmpxchg(ptr, old, new); +} + +__rust_helper bool rust_helper_atomic_i8_try_cmpxchg_acquire(s8 *ptr, s8 *old, s8 new) +{ + return try_cmpxchg_acquire(ptr, old, new); +} + +__rust_helper bool rust_helper_atomic_i16_try_cmpxchg_acquire(s16 *ptr, s16 *old, s16 new) +{ + return try_cmpxchg_acquire(ptr, old, new); +} + +__rust_helper bool rust_helper_atomic_i8_try_cmpxchg_release(s8 *ptr, s8 *old, s8 new) +{ + return try_cmpxchg_release(ptr, old, new); +} + +__rust_helper bool rust_helper_atomic_i16_try_cmpxchg_release(s16 *ptr, s16 *old, s16 new) +{ + return try_cmpxchg_release(ptr, old, new); +} + +__rust_helper bool rust_helper_atomic_i8_try_cmpxchg_relaxed(s8 *ptr, s8 *old, s8 new) +{ + return try_cmpxchg_relaxed(ptr, old, new); +} + +__rust_helper bool rust_helper_atomic_i16_try_cmpxchg_relaxed(s16 *ptr, s16 *old, s16 new) +{ + return try_cmpxchg_relaxed(ptr, old, new); +} diff --git a/rust/helpers/auxiliary.c b/rust/helpers/auxiliary.c index 8b5e0fea4493..dd1c130843a0 100644 --- a/rust/helpers/auxiliary.c +++ b/rust/helpers/auxiliary.c @@ -2,12 +2,14 @@ #include <linux/auxiliary_bus.h> -void rust_helper_auxiliary_device_uninit(struct auxiliary_device *adev) +__rust_helper void +rust_helper_auxiliary_device_uninit(struct auxiliary_device *adev) { return auxiliary_device_uninit(adev); } -void rust_helper_auxiliary_device_delete(struct auxiliary_device *adev) +__rust_helper void +rust_helper_auxiliary_device_delete(struct auxiliary_device *adev) { return auxiliary_device_delete(adev); } diff --git a/rust/helpers/barrier.c b/rust/helpers/barrier.c index cdf28ce8e511..fed8853745c8 100644 --- a/rust/helpers/barrier.c +++ b/rust/helpers/barrier.c @@ -2,17 +2,17 @@ #include <asm/barrier.h> -void rust_helper_smp_mb(void) +__rust_helper void rust_helper_smp_mb(void) { smp_mb(); } -void rust_helper_smp_wmb(void) +__rust_helper void rust_helper_smp_wmb(void) { smp_wmb(); } -void rust_helper_smp_rmb(void) +__rust_helper void rust_helper_smp_rmb(void) { smp_rmb(); } diff --git a/rust/helpers/binder.c b/rust/helpers/binder.c index 224d38a92f1d..a2327f1b3c94 100644 --- a/rust/helpers/binder.c +++ b/rust/helpers/binder.c @@ -7,20 +7,21 @@ #include <linux/list_lru.h> #include <linux/task_work.h> -unsigned long rust_helper_list_lru_count(struct list_lru *lru) +__rust_helper unsigned long rust_helper_list_lru_count(struct list_lru *lru) { return list_lru_count(lru); } -unsigned long rust_helper_list_lru_walk(struct list_lru *lru, - list_lru_walk_cb isolate, void *cb_arg, - unsigned long nr_to_walk) +__rust_helper unsigned long rust_helper_list_lru_walk(struct list_lru *lru, + list_lru_walk_cb isolate, + void *cb_arg, + unsigned long nr_to_walk) { return list_lru_walk(lru, isolate, cb_arg, nr_to_walk); } -void rust_helper_init_task_work(struct callback_head *twork, - task_work_func_t func) +__rust_helper void rust_helper_init_task_work(struct callback_head *twork, + task_work_func_t func) { init_task_work(twork, func); } diff --git a/rust/helpers/bitmap.c b/rust/helpers/bitmap.c index a50e2f082e47..e4e9f4361270 100644 --- a/rust/helpers/bitmap.c +++ b/rust/helpers/bitmap.c @@ -2,6 +2,7 @@ #include <linux/bitmap.h> +__rust_helper void rust_helper_bitmap_copy_and_extend(unsigned long *to, const unsigned long *from, unsigned int count, unsigned int size) { diff --git 
a/rust/helpers/bitops.c b/rust/helpers/bitops.c index 5d0861d29d3f..271b8a712dee 100644 --- a/rust/helpers/bitops.c +++ b/rust/helpers/bitops.c @@ -1,23 +1,69 @@ // SPDX-License-Identifier: GPL-2.0 #include <linux/bitops.h> +#include <linux/find.h> +__rust_helper void rust_helper___set_bit(unsigned long nr, unsigned long *addr) { __set_bit(nr, addr); } +__rust_helper void rust_helper___clear_bit(unsigned long nr, unsigned long *addr) { __clear_bit(nr, addr); } +__rust_helper void rust_helper_set_bit(unsigned long nr, volatile unsigned long *addr) { set_bit(nr, addr); } +__rust_helper void rust_helper_clear_bit(unsigned long nr, volatile unsigned long *addr) { clear_bit(nr, addr); } + +/* + * The rust_helper_ prefix is intentionally omitted below so that the + * declarations in include/linux/find.h are compatible with these helpers. + * + * Note that the below #ifdefs mean that the helper is only created if C does + * not provide a definition. + */ +#ifdef find_first_zero_bit +__rust_helper +unsigned long _find_first_zero_bit(const unsigned long *p, unsigned long size) +{ + return find_first_zero_bit(p, size); +} +#endif /* find_first_zero_bit */ + +#ifdef find_next_zero_bit +__rust_helper +unsigned long _find_next_zero_bit(const unsigned long *addr, + unsigned long size, unsigned long offset) +{ + return find_next_zero_bit(addr, size, offset); +} +#endif /* find_next_zero_bit */ + +#ifdef find_first_bit +__rust_helper +unsigned long _find_first_bit(const unsigned long *addr, unsigned long size) +{ + return find_first_bit(addr, size); +} +#endif /* find_first_bit */ + +#ifdef find_next_bit +__rust_helper +unsigned long _find_next_bit(const unsigned long *addr, unsigned long size, + unsigned long offset) +{ + return find_next_bit(addr, size, offset); +} +#endif /* find_next_bit */ diff --git a/rust/helpers/blk.c b/rust/helpers/blk.c index cc9f4e6a2d23..20c512e46a7a 100644 --- a/rust/helpers/blk.c +++ b/rust/helpers/blk.c @@ -3,12 +3,12 @@ #include <linux/blk-mq.h> #include <linux/blkdev.h> -void *rust_helper_blk_mq_rq_to_pdu(struct request *rq) +__rust_helper void *rust_helper_blk_mq_rq_to_pdu(struct request *rq) { return blk_mq_rq_to_pdu(rq); } -struct request *rust_helper_blk_mq_rq_from_pdu(void *pdu) +__rust_helper struct request *rust_helper_blk_mq_rq_from_pdu(void *pdu) { return blk_mq_rq_from_pdu(pdu); } diff --git a/rust/helpers/bug.c b/rust/helpers/bug.c index a62c96f507d1..b51e60772578 100644 --- a/rust/helpers/bug.c +++ b/rust/helpers/bug.c @@ -2,12 +2,12 @@ #include <linux/bug.h> -__noreturn void rust_helper_BUG(void) +__rust_helper __noreturn void rust_helper_BUG(void) { BUG(); } -bool rust_helper_WARN_ON(bool cond) +__rust_helper bool rust_helper_WARN_ON(bool cond) { return WARN_ON(cond); } diff --git a/rust/helpers/build_bug.c b/rust/helpers/build_bug.c index 44e579488037..14dbc55bb539 100644 --- a/rust/helpers/build_bug.c +++ b/rust/helpers/build_bug.c @@ -2,7 +2,7 @@ #include <linux/errname.h> -const char *rust_helper_errname(int err) +__rust_helper const char *rust_helper_errname(int err) { return errname(err); } diff --git a/rust/helpers/completion.c b/rust/helpers/completion.c index b2443262a2ae..0126767cc3be 100644 --- a/rust/helpers/completion.c +++ b/rust/helpers/completion.c @@ -2,7 +2,7 @@ #include <linux/completion.h> -void rust_helper_init_completion(struct completion *x) +__rust_helper void rust_helper_init_completion(struct completion *x) { init_completion(x); } diff --git a/rust/helpers/cpu.c b/rust/helpers/cpu.c index 824e0adb19d4..5759349b2c88 100644 --- 
a/rust/helpers/cpu.c +++ b/rust/helpers/cpu.c @@ -2,7 +2,7 @@ #include <linux/smp.h> -unsigned int rust_helper_raw_smp_processor_id(void) +__rust_helper unsigned int rust_helper_raw_smp_processor_id(void) { return raw_smp_processor_id(); } diff --git a/rust/helpers/cpufreq.c b/rust/helpers/cpufreq.c index 7c1343c4d65e..0e16aeef2b5a 100644 --- a/rust/helpers/cpufreq.c +++ b/rust/helpers/cpufreq.c @@ -3,7 +3,8 @@ #include <linux/cpufreq.h> #ifdef CONFIG_CPU_FREQ -void rust_helper_cpufreq_register_em_with_opp(struct cpufreq_policy *policy) +__rust_helper void +rust_helper_cpufreq_register_em_with_opp(struct cpufreq_policy *policy) { cpufreq_register_em_with_opp(policy); } diff --git a/rust/helpers/cpumask.c b/rust/helpers/cpumask.c index eb10598a0242..5deced5b975e 100644 --- a/rust/helpers/cpumask.c +++ b/rust/helpers/cpumask.c @@ -2,67 +2,80 @@ #include <linux/cpumask.h> +__rust_helper void rust_helper_cpumask_set_cpu(unsigned int cpu, struct cpumask *dstp) { cpumask_set_cpu(cpu, dstp); } +__rust_helper void rust_helper___cpumask_set_cpu(unsigned int cpu, struct cpumask *dstp) { __cpumask_set_cpu(cpu, dstp); } +__rust_helper void rust_helper_cpumask_clear_cpu(int cpu, struct cpumask *dstp) { cpumask_clear_cpu(cpu, dstp); } +__rust_helper void rust_helper___cpumask_clear_cpu(int cpu, struct cpumask *dstp) { __cpumask_clear_cpu(cpu, dstp); } +__rust_helper bool rust_helper_cpumask_test_cpu(int cpu, struct cpumask *srcp) { return cpumask_test_cpu(cpu, srcp); } +__rust_helper void rust_helper_cpumask_setall(struct cpumask *dstp) { cpumask_setall(dstp); } +__rust_helper bool rust_helper_cpumask_empty(struct cpumask *srcp) { return cpumask_empty(srcp); } +__rust_helper bool rust_helper_cpumask_full(struct cpumask *srcp) { return cpumask_full(srcp); } +__rust_helper unsigned int rust_helper_cpumask_weight(struct cpumask *srcp) { return cpumask_weight(srcp); } +__rust_helper void rust_helper_cpumask_copy(struct cpumask *dstp, const struct cpumask *srcp) { cpumask_copy(dstp, srcp); } +__rust_helper bool rust_helper_alloc_cpumask_var(cpumask_var_t *mask, gfp_t flags) { return alloc_cpumask_var(mask, flags); } +__rust_helper bool rust_helper_zalloc_cpumask_var(cpumask_var_t *mask, gfp_t flags) { return zalloc_cpumask_var(mask, flags); } #ifndef CONFIG_CPUMASK_OFFSTACK +__rust_helper void rust_helper_free_cpumask_var(cpumask_var_t mask) { free_cpumask_var(mask); diff --git a/rust/helpers/cred.c b/rust/helpers/cred.c index fde7ae20cdd1..a56a7b753623 100644 --- a/rust/helpers/cred.c +++ b/rust/helpers/cred.c @@ -2,12 +2,12 @@ #include <linux/cred.h> -const struct cred *rust_helper_get_cred(const struct cred *cred) +__rust_helper const struct cred *rust_helper_get_cred(const struct cred *cred) { return get_cred(cred); } -void rust_helper_put_cred(const struct cred *cred) +__rust_helper void rust_helper_put_cred(const struct cred *cred) { put_cred(cred); } diff --git a/rust/helpers/device.c b/rust/helpers/device.c index 9a4316bafedf..a8ab931a9bd1 100644 --- a/rust/helpers/device.c +++ b/rust/helpers/device.c @@ -2,26 +2,26 @@ #include <linux/device.h> -int rust_helper_devm_add_action(struct device *dev, - void (*action)(void *), - void *data) +__rust_helper int rust_helper_devm_add_action(struct device *dev, + void (*action)(void *), + void *data) { return devm_add_action(dev, action, data); } -int rust_helper_devm_add_action_or_reset(struct device *dev, - void (*action)(void *), - void *data) +__rust_helper int rust_helper_devm_add_action_or_reset(struct device *dev, + void (*action)(void *), + void *data) 
{ return devm_add_action_or_reset(dev, action, data); } -void *rust_helper_dev_get_drvdata(const struct device *dev) +__rust_helper void *rust_helper_dev_get_drvdata(const struct device *dev) { return dev_get_drvdata(dev); } -void rust_helper_dev_set_drvdata(struct device *dev, void *data) +__rust_helper void rust_helper_dev_set_drvdata(struct device *dev, void *data) { dev_set_drvdata(dev, data); } diff --git a/rust/helpers/dma.c b/rust/helpers/dma.c index 6e741c197242..9fbeb507b08c 100644 --- a/rust/helpers/dma.c +++ b/rust/helpers/dma.c @@ -2,20 +2,50 @@ #include <linux/dma-mapping.h> -void *rust_helper_dma_alloc_attrs(struct device *dev, size_t size, - dma_addr_t *dma_handle, gfp_t flag, - unsigned long attrs) +__rust_helper void *rust_helper_dma_alloc_attrs(struct device *dev, size_t size, + dma_addr_t *dma_handle, + gfp_t flag, unsigned long attrs) { return dma_alloc_attrs(dev, size, dma_handle, flag, attrs); } -void rust_helper_dma_free_attrs(struct device *dev, size_t size, void *cpu_addr, - dma_addr_t dma_handle, unsigned long attrs) +__rust_helper void rust_helper_dma_free_attrs(struct device *dev, size_t size, + void *cpu_addr, + dma_addr_t dma_handle, + unsigned long attrs) { dma_free_attrs(dev, size, cpu_addr, dma_handle, attrs); } -int rust_helper_dma_set_mask_and_coherent(struct device *dev, u64 mask) +__rust_helper int rust_helper_dma_set_mask_and_coherent(struct device *dev, + u64 mask) { return dma_set_mask_and_coherent(dev, mask); } + +__rust_helper int rust_helper_dma_set_mask(struct device *dev, u64 mask) +{ + return dma_set_mask(dev, mask); +} + +__rust_helper int rust_helper_dma_set_coherent_mask(struct device *dev, u64 mask) +{ + return dma_set_coherent_mask(dev, mask); +} + +__rust_helper int rust_helper_dma_map_sgtable(struct device *dev, struct sg_table *sgt, + enum dma_data_direction dir, unsigned long attrs) +{ + return dma_map_sgtable(dev, sgt, dir, attrs); +} + +__rust_helper size_t rust_helper_dma_max_mapping_size(struct device *dev) +{ + return dma_max_mapping_size(dev); +} + +__rust_helper void rust_helper_dma_set_max_seg_size(struct device *dev, + unsigned int size) +{ + dma_set_max_seg_size(dev, size); +} diff --git a/rust/helpers/drm.c b/rust/helpers/drm.c index 450b406c6f27..fe226f7b53ef 100644 --- a/rust/helpers/drm.c +++ b/rust/helpers/drm.c @@ -5,17 +5,18 @@ #ifdef CONFIG_DRM -void rust_helper_drm_gem_object_get(struct drm_gem_object *obj) +__rust_helper void rust_helper_drm_gem_object_get(struct drm_gem_object *obj) { drm_gem_object_get(obj); } -void rust_helper_drm_gem_object_put(struct drm_gem_object *obj) +__rust_helper void rust_helper_drm_gem_object_put(struct drm_gem_object *obj) { drm_gem_object_put(obj); } -__u64 rust_helper_drm_vma_node_offset_addr(struct drm_vma_offset_node *node) +__rust_helper __u64 +rust_helper_drm_vma_node_offset_addr(struct drm_vma_offset_node *node) { return drm_vma_node_offset_addr(node); } diff --git a/rust/helpers/err.c b/rust/helpers/err.c index 544c7cb86632..2872158e3793 100644 --- a/rust/helpers/err.c +++ b/rust/helpers/err.c @@ -2,17 +2,17 @@ #include <linux/err.h> -__force void *rust_helper_ERR_PTR(long err) +__rust_helper __force void *rust_helper_ERR_PTR(long err) { return ERR_PTR(err); } -bool rust_helper_IS_ERR(__force const void *ptr) +__rust_helper bool rust_helper_IS_ERR(__force const void *ptr) { return IS_ERR(ptr); } -long rust_helper_PTR_ERR(__force const void *ptr) +__rust_helper long rust_helper_PTR_ERR(__force const void *ptr) { return PTR_ERR(ptr); } diff --git a/rust/helpers/fs.c 
b/rust/helpers/fs.c index a75c96763372..789d60fb8908 100644 --- a/rust/helpers/fs.c +++ b/rust/helpers/fs.c @@ -6,7 +6,7 @@ #include <linux/fs.h> -struct file *rust_helper_get_file(struct file *f) +__rust_helper struct file *rust_helper_get_file(struct file *f) { return get_file(f); } diff --git a/rust/helpers/helpers.c b/rust/helpers/helpers.c index 79c72762ad9c..a3c42e51f00a 100644 --- a/rust/helpers/helpers.c +++ b/rust/helpers/helpers.c @@ -7,7 +7,10 @@ * Sorted alphabetically. */ +#define __rust_helper + #include "atomic.c" +#include "atomic_ext.c" #include "auxiliary.c" #include "barrier.c" #include "binder.c" diff --git a/rust/helpers/io.c b/rust/helpers/io.c index c475913c69e6..397810864a24 100644 --- a/rust/helpers/io.c +++ b/rust/helpers/io.c @@ -3,140 +3,144 @@ #include <linux/io.h> #include <linux/ioport.h> -void __iomem *rust_helper_ioremap(phys_addr_t offset, size_t size) +__rust_helper void __iomem *rust_helper_ioremap(phys_addr_t offset, size_t size) { return ioremap(offset, size); } -void __iomem *rust_helper_ioremap_np(phys_addr_t offset, size_t size) +__rust_helper void __iomem *rust_helper_ioremap_np(phys_addr_t offset, + size_t size) { return ioremap_np(offset, size); } -void rust_helper_iounmap(void __iomem *addr) +__rust_helper void rust_helper_iounmap(void __iomem *addr) { iounmap(addr); } -u8 rust_helper_readb(const void __iomem *addr) +__rust_helper u8 rust_helper_readb(const void __iomem *addr) { return readb(addr); } -u16 rust_helper_readw(const void __iomem *addr) +__rust_helper u16 rust_helper_readw(const void __iomem *addr) { return readw(addr); } -u32 rust_helper_readl(const void __iomem *addr) +__rust_helper u32 rust_helper_readl(const void __iomem *addr) { return readl(addr); } #ifdef CONFIG_64BIT -u64 rust_helper_readq(const void __iomem *addr) +__rust_helper u64 rust_helper_readq(const void __iomem *addr) { return readq(addr); } #endif -void rust_helper_writeb(u8 value, void __iomem *addr) +__rust_helper void rust_helper_writeb(u8 value, void __iomem *addr) { writeb(value, addr); } -void rust_helper_writew(u16 value, void __iomem *addr) +__rust_helper void rust_helper_writew(u16 value, void __iomem *addr) { writew(value, addr); } -void rust_helper_writel(u32 value, void __iomem *addr) +__rust_helper void rust_helper_writel(u32 value, void __iomem *addr) { writel(value, addr); } #ifdef CONFIG_64BIT -void rust_helper_writeq(u64 value, void __iomem *addr) +__rust_helper void rust_helper_writeq(u64 value, void __iomem *addr) { writeq(value, addr); } #endif -u8 rust_helper_readb_relaxed(const void __iomem *addr) +__rust_helper u8 rust_helper_readb_relaxed(const void __iomem *addr) { return readb_relaxed(addr); } -u16 rust_helper_readw_relaxed(const void __iomem *addr) +__rust_helper u16 rust_helper_readw_relaxed(const void __iomem *addr) { return readw_relaxed(addr); } -u32 rust_helper_readl_relaxed(const void __iomem *addr) +__rust_helper u32 rust_helper_readl_relaxed(const void __iomem *addr) { return readl_relaxed(addr); } #ifdef CONFIG_64BIT -u64 rust_helper_readq_relaxed(const void __iomem *addr) +__rust_helper u64 rust_helper_readq_relaxed(const void __iomem *addr) { return readq_relaxed(addr); } #endif -void rust_helper_writeb_relaxed(u8 value, void __iomem *addr) +__rust_helper void rust_helper_writeb_relaxed(u8 value, void __iomem *addr) { writeb_relaxed(value, addr); } -void rust_helper_writew_relaxed(u16 value, void __iomem *addr) +__rust_helper void rust_helper_writew_relaxed(u16 value, void __iomem *addr) { writew_relaxed(value, addr); } -void 
rust_helper_writel_relaxed(u32 value, void __iomem *addr) +__rust_helper void rust_helper_writel_relaxed(u32 value, void __iomem *addr) { writel_relaxed(value, addr); } #ifdef CONFIG_64BIT -void rust_helper_writeq_relaxed(u64 value, void __iomem *addr) +__rust_helper void rust_helper_writeq_relaxed(u64 value, void __iomem *addr) { writeq_relaxed(value, addr); } #endif -resource_size_t rust_helper_resource_size(struct resource *res) +__rust_helper resource_size_t rust_helper_resource_size(struct resource *res) { return resource_size(res); } -struct resource *rust_helper_request_mem_region(resource_size_t start, - resource_size_t n, - const char *name) +__rust_helper struct resource * +rust_helper_request_mem_region(resource_size_t start, resource_size_t n, + const char *name) { return request_mem_region(start, n, name); } -void rust_helper_release_mem_region(resource_size_t start, resource_size_t n) +__rust_helper void rust_helper_release_mem_region(resource_size_t start, + resource_size_t n) { release_mem_region(start, n); } -struct resource *rust_helper_request_region(resource_size_t start, - resource_size_t n, const char *name) +__rust_helper struct resource *rust_helper_request_region(resource_size_t start, + resource_size_t n, + const char *name) { return request_region(start, n, name); } -struct resource *rust_helper_request_muxed_region(resource_size_t start, - resource_size_t n, - const char *name) +__rust_helper struct resource * +rust_helper_request_muxed_region(resource_size_t start, resource_size_t n, + const char *name) { return request_muxed_region(start, n, name); } -void rust_helper_release_region(resource_size_t start, resource_size_t n) +__rust_helper void rust_helper_release_region(resource_size_t start, + resource_size_t n) { release_region(start, n); } diff --git a/rust/helpers/irq.c b/rust/helpers/irq.c index 1faca428e2c0..a61fad333866 100644 --- a/rust/helpers/irq.c +++ b/rust/helpers/irq.c @@ -2,8 +2,10 @@ #include <linux/interrupt.h> -int rust_helper_request_irq(unsigned int irq, irq_handler_t handler, - unsigned long flags, const char *name, void *dev) +__rust_helper int rust_helper_request_irq(unsigned int irq, + irq_handler_t handler, + unsigned long flags, const char *name, + void *dev) { return request_irq(irq, handler, flags, name, dev); } diff --git a/rust/helpers/kunit.c b/rust/helpers/kunit.c index b85a4d394c11..cafb94b6776c 100644 --- a/rust/helpers/kunit.c +++ b/rust/helpers/kunit.c @@ -2,7 +2,7 @@ #include <kunit/test-bug.h> -struct kunit *rust_helper_kunit_get_current_test(void) +__rust_helper struct kunit *rust_helper_kunit_get_current_test(void) { return kunit_get_current_test(); } diff --git a/rust/helpers/maple_tree.c b/rust/helpers/maple_tree.c index 1dd9ac84a13f..5586486a76e0 100644 --- a/rust/helpers/maple_tree.c +++ b/rust/helpers/maple_tree.c @@ -2,7 +2,8 @@ #include <linux/maple_tree.h> -void rust_helper_mt_init_flags(struct maple_tree *mt, unsigned int flags) +__rust_helper void rust_helper_mt_init_flags(struct maple_tree *mt, + unsigned int flags) { mt_init_flags(mt, flags); } diff --git a/rust/helpers/mm.c b/rust/helpers/mm.c index 81b510c96fd2..b5540997bd20 100644 --- a/rust/helpers/mm.c +++ b/rust/helpers/mm.c @@ -3,48 +3,48 @@ #include <linux/mm.h> #include <linux/sched/mm.h> -void rust_helper_mmgrab(struct mm_struct *mm) +__rust_helper void rust_helper_mmgrab(struct mm_struct *mm) { mmgrab(mm); } -void rust_helper_mmdrop(struct mm_struct *mm) +__rust_helper void rust_helper_mmdrop(struct mm_struct *mm) { mmdrop(mm); } -void 
rust_helper_mmget(struct mm_struct *mm) +__rust_helper void rust_helper_mmget(struct mm_struct *mm) { mmget(mm); } -bool rust_helper_mmget_not_zero(struct mm_struct *mm) +__rust_helper bool rust_helper_mmget_not_zero(struct mm_struct *mm) { return mmget_not_zero(mm); } -void rust_helper_mmap_read_lock(struct mm_struct *mm) +__rust_helper void rust_helper_mmap_read_lock(struct mm_struct *mm) { mmap_read_lock(mm); } -bool rust_helper_mmap_read_trylock(struct mm_struct *mm) +__rust_helper bool rust_helper_mmap_read_trylock(struct mm_struct *mm) { return mmap_read_trylock(mm); } -void rust_helper_mmap_read_unlock(struct mm_struct *mm) +__rust_helper void rust_helper_mmap_read_unlock(struct mm_struct *mm) { mmap_read_unlock(mm); } -struct vm_area_struct *rust_helper_vma_lookup(struct mm_struct *mm, - unsigned long addr) +__rust_helper struct vm_area_struct * +rust_helper_vma_lookup(struct mm_struct *mm, unsigned long addr) { return vma_lookup(mm, addr); } -void rust_helper_vma_end_read(struct vm_area_struct *vma) +__rust_helper void rust_helper_vma_end_read(struct vm_area_struct *vma) { vma_end_read(vma); } diff --git a/rust/helpers/mutex.c b/rust/helpers/mutex.c index e487819125f0..1b07d6e64299 100644 --- a/rust/helpers/mutex.c +++ b/rust/helpers/mutex.c @@ -2,28 +2,29 @@ #include <linux/mutex.h> -void rust_helper_mutex_lock(struct mutex *lock) +__rust_helper void rust_helper_mutex_lock(struct mutex *lock) { mutex_lock(lock); } -int rust_helper_mutex_trylock(struct mutex *lock) +__rust_helper int rust_helper_mutex_trylock(struct mutex *lock) { return mutex_trylock(lock); } -void rust_helper___mutex_init(struct mutex *mutex, const char *name, - struct lock_class_key *key) +__rust_helper void rust_helper___mutex_init(struct mutex *mutex, + const char *name, + struct lock_class_key *key) { __mutex_init(mutex, name, key); } -void rust_helper_mutex_assert_is_held(struct mutex *mutex) +__rust_helper void rust_helper_mutex_assert_is_held(struct mutex *mutex) { lockdep_assert_held(mutex); } -void rust_helper_mutex_destroy(struct mutex *lock) +__rust_helper void rust_helper_mutex_destroy(struct mutex *lock) { mutex_destroy(lock); } diff --git a/rust/helpers/of.c b/rust/helpers/of.c index 86b51167c913..8f62ca69e8ba 100644 --- a/rust/helpers/of.c +++ b/rust/helpers/of.c @@ -2,7 +2,7 @@ #include <linux/of.h> -bool rust_helper_is_of_node(const struct fwnode_handle *fwnode) +__rust_helper bool rust_helper_is_of_node(const struct fwnode_handle *fwnode) { return is_of_node(fwnode); } diff --git a/rust/helpers/page.c b/rust/helpers/page.c index 7144de5a61db..f8463fbed2a2 100644 --- a/rust/helpers/page.c +++ b/rust/helpers/page.c @@ -4,23 +4,24 @@ #include <linux/highmem.h> #include <linux/mm.h> -struct page *rust_helper_alloc_pages(gfp_t gfp_mask, unsigned int order) +__rust_helper struct page *rust_helper_alloc_pages(gfp_t gfp_mask, + unsigned int order) { return alloc_pages(gfp_mask, order); } -void *rust_helper_kmap_local_page(struct page *page) +__rust_helper void *rust_helper_kmap_local_page(struct page *page) { return kmap_local_page(page); } -void rust_helper_kunmap_local(const void *addr) +__rust_helper void rust_helper_kunmap_local(const void *addr) { kunmap_local(addr); } #ifndef NODE_NOT_IN_PAGE_FLAGS -int rust_helper_page_to_nid(const struct page *page) +__rust_helper int rust_helper_page_to_nid(const struct page *page) { return page_to_nid(page); } diff --git a/rust/helpers/pci.c b/rust/helpers/pci.c index bf8173979c5e..e44905317d75 100644 --- a/rust/helpers/pci.c +++ b/rust/helpers/pci.c @@ -2,41 
+2,44 @@ #include <linux/pci.h> -u16 rust_helper_pci_dev_id(struct pci_dev *dev) +__rust_helper u16 rust_helper_pci_dev_id(struct pci_dev *dev) { return PCI_DEVID(dev->bus->number, dev->devfn); } -resource_size_t rust_helper_pci_resource_start(struct pci_dev *pdev, int bar) +__rust_helper resource_size_t +rust_helper_pci_resource_start(struct pci_dev *pdev, int bar) { return pci_resource_start(pdev, bar); } -resource_size_t rust_helper_pci_resource_len(struct pci_dev *pdev, int bar) +__rust_helper resource_size_t rust_helper_pci_resource_len(struct pci_dev *pdev, + int bar) { return pci_resource_len(pdev, bar); } -bool rust_helper_dev_is_pci(const struct device *dev) +__rust_helper bool rust_helper_dev_is_pci(const struct device *dev) { return dev_is_pci(dev); } #ifndef CONFIG_PCI_MSI -int rust_helper_pci_alloc_irq_vectors(struct pci_dev *dev, - unsigned int min_vecs, - unsigned int max_vecs, - unsigned int flags) +__rust_helper int rust_helper_pci_alloc_irq_vectors(struct pci_dev *dev, + unsigned int min_vecs, + unsigned int max_vecs, + unsigned int flags) { return pci_alloc_irq_vectors(dev, min_vecs, max_vecs, flags); } -void rust_helper_pci_free_irq_vectors(struct pci_dev *dev) +__rust_helper void rust_helper_pci_free_irq_vectors(struct pci_dev *dev) { pci_free_irq_vectors(dev); } -int rust_helper_pci_irq_vector(struct pci_dev *pdev, unsigned int nvec) +__rust_helper int rust_helper_pci_irq_vector(struct pci_dev *pdev, + unsigned int nvec) { return pci_irq_vector(pdev, nvec); } diff --git a/rust/helpers/pid_namespace.c b/rust/helpers/pid_namespace.c index f41482bdec9a..f46ab779b527 100644 --- a/rust/helpers/pid_namespace.c +++ b/rust/helpers/pid_namespace.c @@ -3,18 +3,20 @@ #include <linux/pid_namespace.h> #include <linux/cleanup.h> -struct pid_namespace *rust_helper_get_pid_ns(struct pid_namespace *ns) +__rust_helper struct pid_namespace * +rust_helper_get_pid_ns(struct pid_namespace *ns) { return get_pid_ns(ns); } -void rust_helper_put_pid_ns(struct pid_namespace *ns) +__rust_helper void rust_helper_put_pid_ns(struct pid_namespace *ns) { put_pid_ns(ns); } /* Get a reference on a task's pid namespace. 
*/ -struct pid_namespace *rust_helper_task_get_pid_ns(struct task_struct *task) +__rust_helper struct pid_namespace * +rust_helper_task_get_pid_ns(struct task_struct *task) { struct pid_namespace *pid_ns; diff --git a/rust/helpers/platform.c b/rust/helpers/platform.c index 1ce89c1a36f7..188b3240f32d 100644 --- a/rust/helpers/platform.c +++ b/rust/helpers/platform.c @@ -2,7 +2,7 @@ #include <linux/platform_device.h> -bool rust_helper_dev_is_platform(const struct device *dev) +__rust_helper bool rust_helper_dev_is_platform(const struct device *dev) { return dev_is_platform(dev); } diff --git a/rust/helpers/poll.c b/rust/helpers/poll.c index 7e5b1751c2d5..78b3839b50f0 100644 --- a/rust/helpers/poll.c +++ b/rust/helpers/poll.c @@ -3,8 +3,9 @@ #include <linux/export.h> #include <linux/poll.h> -void rust_helper_poll_wait(struct file *filp, wait_queue_head_t *wait_address, - poll_table *p) +__rust_helper void rust_helper_poll_wait(struct file *filp, + wait_queue_head_t *wait_address, + poll_table *p) { poll_wait(filp, wait_address, p); } diff --git a/rust/helpers/processor.c b/rust/helpers/processor.c index d41355e14d6e..76fadbb647c5 100644 --- a/rust/helpers/processor.c +++ b/rust/helpers/processor.c @@ -2,7 +2,7 @@ #include <linux/processor.h> -void rust_helper_cpu_relax(void) +__rust_helper void rust_helper_cpu_relax(void) { cpu_relax(); } diff --git a/rust/helpers/property.c b/rust/helpers/property.c index 08f68e2dac4a..8fb9900533ef 100644 --- a/rust/helpers/property.c +++ b/rust/helpers/property.c @@ -2,7 +2,7 @@ #include <linux/property.h> -void rust_helper_fwnode_handle_put(struct fwnode_handle *fwnode) +__rust_helper void rust_helper_fwnode_handle_put(struct fwnode_handle *fwnode) { fwnode_handle_put(fwnode); } diff --git a/rust/helpers/pwm.c b/rust/helpers/pwm.c index d75c58886368..eb24d2ea8e74 100644 --- a/rust/helpers/pwm.c +++ b/rust/helpers/pwm.c @@ -4,17 +4,17 @@ #include <linux/pwm.h> -struct device *rust_helper_pwmchip_parent(const struct pwm_chip *chip) +__rust_helper struct device *rust_helper_pwmchip_parent(const struct pwm_chip *chip) { return pwmchip_parent(chip); } -void *rust_helper_pwmchip_get_drvdata(struct pwm_chip *chip) +__rust_helper void *rust_helper_pwmchip_get_drvdata(struct pwm_chip *chip) { return pwmchip_get_drvdata(chip); } -void rust_helper_pwmchip_set_drvdata(struct pwm_chip *chip, void *data) +__rust_helper void rust_helper_pwmchip_set_drvdata(struct pwm_chip *chip, void *data) { pwmchip_set_drvdata(chip, data); } diff --git a/rust/helpers/rbtree.c b/rust/helpers/rbtree.c index 2a0eabbb4160..a85defb22ff7 100644 --- a/rust/helpers/rbtree.c +++ b/rust/helpers/rbtree.c @@ -2,18 +2,19 @@ #include <linux/rbtree.h> -void rust_helper_rb_link_node(struct rb_node *node, struct rb_node *parent, - struct rb_node **rb_link) +__rust_helper void rust_helper_rb_link_node(struct rb_node *node, + struct rb_node *parent, + struct rb_node **rb_link) { rb_link_node(node, parent, rb_link); } -struct rb_node *rust_helper_rb_first(const struct rb_root *root) +__rust_helper struct rb_node *rust_helper_rb_first(const struct rb_root *root) { return rb_first(root); } -struct rb_node *rust_helper_rb_last(const struct rb_root *root) +__rust_helper struct rb_node *rust_helper_rb_last(const struct rb_root *root) { return rb_last(root); } diff --git a/rust/helpers/rcu.c b/rust/helpers/rcu.c index f1cec6583513..481274c05857 100644 --- a/rust/helpers/rcu.c +++ b/rust/helpers/rcu.c @@ -2,12 +2,12 @@ #include <linux/rcupdate.h> -void rust_helper_rcu_read_lock(void) +__rust_helper void 
rust_helper_rcu_read_lock(void) { rcu_read_lock(); } -void rust_helper_rcu_read_unlock(void) +__rust_helper void rust_helper_rcu_read_unlock(void) { rcu_read_unlock(); } diff --git a/rust/helpers/refcount.c b/rust/helpers/refcount.c index d175898ad7b8..36334a674ee4 100644 --- a/rust/helpers/refcount.c +++ b/rust/helpers/refcount.c @@ -2,27 +2,27 @@ #include <linux/refcount.h> -refcount_t rust_helper_REFCOUNT_INIT(int n) +__rust_helper refcount_t rust_helper_REFCOUNT_INIT(int n) { return (refcount_t)REFCOUNT_INIT(n); } -void rust_helper_refcount_set(refcount_t *r, int n) +__rust_helper void rust_helper_refcount_set(refcount_t *r, int n) { refcount_set(r, n); } -void rust_helper_refcount_inc(refcount_t *r) +__rust_helper void rust_helper_refcount_inc(refcount_t *r) { refcount_inc(r); } -void rust_helper_refcount_dec(refcount_t *r) +__rust_helper void rust_helper_refcount_dec(refcount_t *r) { refcount_dec(r); } -bool rust_helper_refcount_dec_and_test(refcount_t *r) +__rust_helper bool rust_helper_refcount_dec_and_test(refcount_t *r) { return refcount_dec_and_test(r); } diff --git a/rust/helpers/regulator.c b/rust/helpers/regulator.c index 11bc332443bd..9ec5237f449b 100644 --- a/rust/helpers/regulator.c +++ b/rust/helpers/regulator.c @@ -4,48 +4,52 @@ #ifndef CONFIG_REGULATOR -void rust_helper_regulator_put(struct regulator *regulator) +__rust_helper void rust_helper_regulator_put(struct regulator *regulator) { regulator_put(regulator); } -int rust_helper_regulator_set_voltage(struct regulator *regulator, int min_uV, - int max_uV) +__rust_helper int rust_helper_regulator_set_voltage(struct regulator *regulator, + int min_uV, int max_uV) { return regulator_set_voltage(regulator, min_uV, max_uV); } -int rust_helper_regulator_get_voltage(struct regulator *regulator) +__rust_helper int rust_helper_regulator_get_voltage(struct regulator *regulator) { return regulator_get_voltage(regulator); } -struct regulator *rust_helper_regulator_get(struct device *dev, const char *id) +__rust_helper struct regulator *rust_helper_regulator_get(struct device *dev, + const char *id) { return regulator_get(dev, id); } -int rust_helper_regulator_enable(struct regulator *regulator) +__rust_helper int rust_helper_regulator_enable(struct regulator *regulator) { return regulator_enable(regulator); } -int rust_helper_regulator_disable(struct regulator *regulator) +__rust_helper int rust_helper_regulator_disable(struct regulator *regulator) { return regulator_disable(regulator); } -int rust_helper_regulator_is_enabled(struct regulator *regulator) +__rust_helper int rust_helper_regulator_is_enabled(struct regulator *regulator) { return regulator_is_enabled(regulator); } -int rust_helper_devm_regulator_get_enable(struct device *dev, const char *id) +__rust_helper int rust_helper_devm_regulator_get_enable(struct device *dev, + const char *id) { return devm_regulator_get_enable(dev, id); } -int rust_helper_devm_regulator_get_enable_optional(struct device *dev, const char *id) +__rust_helper int +rust_helper_devm_regulator_get_enable_optional(struct device *dev, + const char *id) { return devm_regulator_get_enable_optional(dev, id); } diff --git a/rust/helpers/scatterlist.c b/rust/helpers/scatterlist.c index 80c956ee09ab..f3c41ea5e201 100644 --- a/rust/helpers/scatterlist.c +++ b/rust/helpers/scatterlist.c @@ -2,23 +2,25 @@ #include <linux/dma-direction.h> -dma_addr_t rust_helper_sg_dma_address(struct scatterlist *sg) +__rust_helper dma_addr_t rust_helper_sg_dma_address(struct scatterlist *sg) { return sg_dma_address(sg); } 
-unsigned int rust_helper_sg_dma_len(struct scatterlist *sg) +__rust_helper unsigned int rust_helper_sg_dma_len(struct scatterlist *sg) { return sg_dma_len(sg); } -struct scatterlist *rust_helper_sg_next(struct scatterlist *sg) +__rust_helper struct scatterlist *rust_helper_sg_next(struct scatterlist *sg) { return sg_next(sg); } -void rust_helper_dma_unmap_sgtable(struct device *dev, struct sg_table *sgt, - enum dma_data_direction dir, unsigned long attrs) +__rust_helper void rust_helper_dma_unmap_sgtable(struct device *dev, + struct sg_table *sgt, + enum dma_data_direction dir, + unsigned long attrs) { return dma_unmap_sgtable(dev, sgt, dir, attrs); } diff --git a/rust/helpers/security.c b/rust/helpers/security.c index ca22da09548d..8d0a25fcf931 100644 --- a/rust/helpers/security.c +++ b/rust/helpers/security.c @@ -3,41 +3,45 @@ #include <linux/security.h> #ifndef CONFIG_SECURITY -void rust_helper_security_cred_getsecid(const struct cred *c, u32 *secid) +__rust_helper void rust_helper_security_cred_getsecid(const struct cred *c, + u32 *secid) { security_cred_getsecid(c, secid); } -int rust_helper_security_secid_to_secctx(u32 secid, struct lsm_context *cp) +__rust_helper int rust_helper_security_secid_to_secctx(u32 secid, + struct lsm_context *cp) { return security_secid_to_secctx(secid, cp); } -void rust_helper_security_release_secctx(struct lsm_context *cp) +__rust_helper void rust_helper_security_release_secctx(struct lsm_context *cp) { security_release_secctx(cp); } -int rust_helper_security_binder_set_context_mgr(const struct cred *mgr) +__rust_helper int +rust_helper_security_binder_set_context_mgr(const struct cred *mgr) { return security_binder_set_context_mgr(mgr); } -int rust_helper_security_binder_transaction(const struct cred *from, - const struct cred *to) +__rust_helper int +rust_helper_security_binder_transaction(const struct cred *from, + const struct cred *to) { return security_binder_transaction(from, to); } -int rust_helper_security_binder_transfer_binder(const struct cred *from, - const struct cred *to) +__rust_helper int +rust_helper_security_binder_transfer_binder(const struct cred *from, + const struct cred *to) { return security_binder_transfer_binder(from, to); } -int rust_helper_security_binder_transfer_file(const struct cred *from, - const struct cred *to, - const struct file *file) +__rust_helper int rust_helper_security_binder_transfer_file( + const struct cred *from, const struct cred *to, const struct file *file) { return security_binder_transfer_file(from, to, file); } diff --git a/rust/helpers/signal.c b/rust/helpers/signal.c index 1a6bbe9438e2..85111186cf3d 100644 --- a/rust/helpers/signal.c +++ b/rust/helpers/signal.c @@ -2,7 +2,7 @@ #include <linux/sched/signal.h> -int rust_helper_signal_pending(struct task_struct *t) +__rust_helper int rust_helper_signal_pending(struct task_struct *t) { return signal_pending(t); } diff --git a/rust/helpers/slab.c b/rust/helpers/slab.c index 7fac958907b0..9279f082467d 100644 --- a/rust/helpers/slab.c +++ b/rust/helpers/slab.c @@ -2,14 +2,14 @@ #include <linux/slab.h> -void * __must_check __realloc_size(2) +__rust_helper void *__must_check __realloc_size(2) rust_helper_krealloc_node_align(const void *objp, size_t new_size, unsigned long align, gfp_t flags, int node) { return krealloc_node_align(objp, new_size, align, flags, node); } -void * __must_check __realloc_size(2) +__rust_helper void *__must_check __realloc_size(2) rust_helper_kvrealloc_node_align(const void *p, size_t size, unsigned long align, gfp_t flags, int 
node) { diff --git a/rust/helpers/spinlock.c b/rust/helpers/spinlock.c index 42c4bf01a23e..4d13062cf253 100644 --- a/rust/helpers/spinlock.c +++ b/rust/helpers/spinlock.c @@ -2,8 +2,9 @@ #include <linux/spinlock.h> -void rust_helper___spin_lock_init(spinlock_t *lock, const char *name, - struct lock_class_key *key) +__rust_helper void rust_helper___spin_lock_init(spinlock_t *lock, + const char *name, + struct lock_class_key *key) { #ifdef CONFIG_DEBUG_SPINLOCK # if defined(CONFIG_PREEMPT_RT) @@ -16,22 +17,22 @@ void rust_helper___spin_lock_init(spinlock_t *lock, const char *name, #endif /* CONFIG_DEBUG_SPINLOCK */ } -void rust_helper_spin_lock(spinlock_t *lock) +__rust_helper void rust_helper_spin_lock(spinlock_t *lock) { spin_lock(lock); } -void rust_helper_spin_unlock(spinlock_t *lock) +__rust_helper void rust_helper_spin_unlock(spinlock_t *lock) { spin_unlock(lock); } -int rust_helper_spin_trylock(spinlock_t *lock) +__rust_helper int rust_helper_spin_trylock(spinlock_t *lock) { return spin_trylock(lock); } -void rust_helper_spin_assert_is_held(spinlock_t *lock) +__rust_helper void rust_helper_spin_assert_is_held(spinlock_t *lock) { lockdep_assert_held(lock); } diff --git a/rust/helpers/sync.c b/rust/helpers/sync.c index ff7e68b48810..82d6aff73b04 100644 --- a/rust/helpers/sync.c +++ b/rust/helpers/sync.c @@ -2,12 +2,12 @@ #include <linux/lockdep.h> -void rust_helper_lockdep_register_key(struct lock_class_key *k) +__rust_helper void rust_helper_lockdep_register_key(struct lock_class_key *k) { lockdep_register_key(k); } -void rust_helper_lockdep_unregister_key(struct lock_class_key *k) +__rust_helper void rust_helper_lockdep_unregister_key(struct lock_class_key *k) { lockdep_unregister_key(k); } diff --git a/rust/helpers/task.c b/rust/helpers/task.c index 2c85bbc2727e..c0e1a06ede78 100644 --- a/rust/helpers/task.c +++ b/rust/helpers/task.c @@ -3,60 +3,60 @@ #include <linux/kernel.h> #include <linux/sched/task.h> -void rust_helper_might_resched(void) +__rust_helper void rust_helper_might_resched(void) { might_resched(); } -struct task_struct *rust_helper_get_current(void) +__rust_helper struct task_struct *rust_helper_get_current(void) { return current; } -void rust_helper_get_task_struct(struct task_struct *t) +__rust_helper void rust_helper_get_task_struct(struct task_struct *t) { get_task_struct(t); } -void rust_helper_put_task_struct(struct task_struct *t) +__rust_helper void rust_helper_put_task_struct(struct task_struct *t) { put_task_struct(t); } -kuid_t rust_helper_task_uid(struct task_struct *task) +__rust_helper kuid_t rust_helper_task_uid(struct task_struct *task) { return task_uid(task); } -kuid_t rust_helper_task_euid(struct task_struct *task) +__rust_helper kuid_t rust_helper_task_euid(struct task_struct *task) { return task_euid(task); } #ifndef CONFIG_USER_NS -uid_t rust_helper_from_kuid(struct user_namespace *to, kuid_t uid) +__rust_helper uid_t rust_helper_from_kuid(struct user_namespace *to, kuid_t uid) { return from_kuid(to, uid); } #endif /* CONFIG_USER_NS */ -bool rust_helper_uid_eq(kuid_t left, kuid_t right) +__rust_helper bool rust_helper_uid_eq(kuid_t left, kuid_t right) { return uid_eq(left, right); } -kuid_t rust_helper_current_euid(void) +__rust_helper kuid_t rust_helper_current_euid(void) { return current_euid(); } -struct user_namespace *rust_helper_current_user_ns(void) +__rust_helper struct user_namespace *rust_helper_current_user_ns(void) { return current_user_ns(); } -pid_t rust_helper_task_tgid_nr_ns(struct task_struct *tsk, - struct pid_namespace *ns) 
+__rust_helper pid_t rust_helper_task_tgid_nr_ns(struct task_struct *tsk, + struct pid_namespace *ns) { return task_tgid_nr_ns(tsk, ns); } diff --git a/rust/helpers/time.c b/rust/helpers/time.c index 67a36ccc3ec4..32f495970493 100644 --- a/rust/helpers/time.c +++ b/rust/helpers/time.c @@ -4,37 +4,37 @@ #include <linux/ktime.h> #include <linux/timekeeping.h> -void rust_helper_fsleep(unsigned long usecs) +__rust_helper void rust_helper_fsleep(unsigned long usecs) { fsleep(usecs); } -ktime_t rust_helper_ktime_get_real(void) +__rust_helper ktime_t rust_helper_ktime_get_real(void) { return ktime_get_real(); } -ktime_t rust_helper_ktime_get_boottime(void) +__rust_helper ktime_t rust_helper_ktime_get_boottime(void) { return ktime_get_boottime(); } -ktime_t rust_helper_ktime_get_clocktai(void) +__rust_helper ktime_t rust_helper_ktime_get_clocktai(void) { return ktime_get_clocktai(); } -s64 rust_helper_ktime_to_us(const ktime_t kt) +__rust_helper s64 rust_helper_ktime_to_us(const ktime_t kt) { return ktime_to_us(kt); } -s64 rust_helper_ktime_to_ms(const ktime_t kt) +__rust_helper s64 rust_helper_ktime_to_ms(const ktime_t kt) { return ktime_to_ms(kt); } -void rust_helper_udelay(unsigned long usec) +__rust_helper void rust_helper_udelay(unsigned long usec) { udelay(usec); } diff --git a/rust/helpers/uaccess.c b/rust/helpers/uaccess.c index 4629b2d15529..d9625b9ee046 100644 --- a/rust/helpers/uaccess.c +++ b/rust/helpers/uaccess.c @@ -2,24 +2,26 @@ #include <linux/uaccess.h> -unsigned long rust_helper_copy_from_user(void *to, const void __user *from, - unsigned long n) +__rust_helper unsigned long +rust_helper_copy_from_user(void *to, const void __user *from, unsigned long n) { return copy_from_user(to, from, n); } -unsigned long rust_helper_copy_to_user(void __user *to, const void *from, - unsigned long n) +__rust_helper unsigned long +rust_helper_copy_to_user(void __user *to, const void *from, unsigned long n) { return copy_to_user(to, from, n); } #ifdef INLINE_COPY_FROM_USER +__rust_helper unsigned long rust_helper__copy_from_user(void *to, const void __user *from, unsigned long n) { return _inline_copy_from_user(to, from, n); } +__rust_helper unsigned long rust_helper__copy_to_user(void __user *to, const void *from, unsigned long n) { return _inline_copy_to_user(to, from, n); diff --git a/rust/helpers/usb.c b/rust/helpers/usb.c index fb2aad0cbf4d..eff1cf7be3c2 100644 --- a/rust/helpers/usb.c +++ b/rust/helpers/usb.c @@ -2,7 +2,8 @@ #include <linux/usb.h> -struct usb_device *rust_helper_interface_to_usbdev(struct usb_interface *intf) +__rust_helper struct usb_device * +rust_helper_interface_to_usbdev(struct usb_interface *intf) { return interface_to_usbdev(intf); } diff --git a/rust/helpers/vmalloc.c b/rust/helpers/vmalloc.c index 7d7f7336b3d2..326b030487a2 100644 --- a/rust/helpers/vmalloc.c +++ b/rust/helpers/vmalloc.c @@ -2,7 +2,7 @@ #include <linux/vmalloc.h> -void * __must_check __realloc_size(2) +__rust_helper void *__must_check __realloc_size(2) rust_helper_vrealloc_node_align(const void *p, size_t size, unsigned long align, gfp_t flags, int node) { diff --git a/rust/helpers/wait.c b/rust/helpers/wait.c index ae48e33d9da3..2dde1e451780 100644 --- a/rust/helpers/wait.c +++ b/rust/helpers/wait.c @@ -2,7 +2,7 @@ #include <linux/wait.h> -void rust_helper_init_wait(struct wait_queue_entry *wq_entry) +__rust_helper void rust_helper_init_wait(struct wait_queue_entry *wq_entry) { init_wait(wq_entry); } diff --git a/rust/helpers/workqueue.c b/rust/helpers/workqueue.c index b2b82753509b..ce1c3a5b2150 
100644 --- a/rust/helpers/workqueue.c +++ b/rust/helpers/workqueue.c @@ -2,9 +2,11 @@ #include <linux/workqueue.h> -void rust_helper_init_work_with_key(struct work_struct *work, work_func_t func, - bool onstack, const char *name, - struct lock_class_key *key) +__rust_helper void rust_helper_init_work_with_key(struct work_struct *work, + work_func_t func, + bool onstack, + const char *name, + struct lock_class_key *key) { __init_work(work, onstack); work->data = (atomic_long_t)WORK_DATA_INIT(); diff --git a/rust/helpers/xarray.c b/rust/helpers/xarray.c index 60b299f11451..08979b304341 100644 --- a/rust/helpers/xarray.c +++ b/rust/helpers/xarray.c @@ -2,27 +2,27 @@ #include <linux/xarray.h> -int rust_helper_xa_err(void *entry) +__rust_helper int rust_helper_xa_err(void *entry) { return xa_err(entry); } -void rust_helper_xa_init_flags(struct xarray *xa, gfp_t flags) +__rust_helper void rust_helper_xa_init_flags(struct xarray *xa, gfp_t flags) { return xa_init_flags(xa, flags); } -int rust_helper_xa_trylock(struct xarray *xa) +__rust_helper int rust_helper_xa_trylock(struct xarray *xa) { return xa_trylock(xa); } -void rust_helper_xa_lock(struct xarray *xa) +__rust_helper void rust_helper_xa_lock(struct xarray *xa) { return xa_lock(xa); } -void rust_helper_xa_unlock(struct xarray *xa) +__rust_helper void rust_helper_xa_unlock(struct xarray *xa) { return xa_unlock(xa); } diff --git a/rust/kernel/auxiliary.rs b/rust/kernel/auxiliary.rs index 56f3c180e8f6..93c0db1f6655 100644 --- a/rust/kernel/auxiliary.rs +++ b/rust/kernel/auxiliary.rs @@ -5,31 +5,51 @@ //! C header: [`include/linux/auxiliary_bus.h`](srctree/include/linux/auxiliary_bus.h) use crate::{ - bindings, container_of, device, - device_id::{RawDeviceId, RawDeviceIdIndex}, + bindings, + container_of, + device, + device_id::{ + RawDeviceId, + RawDeviceIdIndex, // + }, devres::Devres, driver, - error::{from_result, to_result, Result}, + error::{ + from_result, + to_result, // + }, prelude::*, types::Opaque, - ThisModule, + ThisModule, // }; use core::{ marker::PhantomData, mem::offset_of, - ptr::{addr_of_mut, NonNull}, + ptr::{ + addr_of_mut, + NonNull, // + }, }; /// An adapter for the registration of auxiliary drivers. pub struct Adapter<T: Driver>(T); -// SAFETY: A call to `unregister` for a given instance of `RegType` is guaranteed to be valid if +// SAFETY: +// - `bindings::auxiliary_driver` is a C type declared as `repr(C)`. +// - `T` is the type of the driver's device private data. +// - `struct auxiliary_driver` embeds a `struct device_driver`. +// - `DEVICE_DRIVER_OFFSET` is the correct byte offset to the embedded `struct device_driver`. +unsafe impl<T: Driver + 'static> driver::DriverLayout for Adapter<T> { + type DriverType = bindings::auxiliary_driver; + type DriverData = T; + const DEVICE_DRIVER_OFFSET: usize = core::mem::offset_of!(Self::DriverType, driver); +} + +// SAFETY: A call to `unregister` for a given instance of `DriverType` is guaranteed to be valid if // a preceding call to `register` has been successful. unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { - type RegType = bindings::auxiliary_driver; - unsafe fn register( - adrv: &Opaque<Self::RegType>, + adrv: &Opaque<Self::DriverType>, name: &'static CStr, module: &'static ThisModule, ) -> Result { @@ -41,14 +61,14 @@ unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { (*adrv.get()).id_table = T::ID_TABLE.as_ptr(); } - // SAFETY: `adrv` is guaranteed to be a valid `RegType`. 
+ // SAFETY: `adrv` is guaranteed to be a valid `DriverType`. to_result(unsafe { bindings::__auxiliary_driver_register(adrv.get(), module.0, name.as_char_ptr()) }) } - unsafe fn unregister(adrv: &Opaque<Self::RegType>) { - // SAFETY: `adrv` is guaranteed to be a valid `RegType`. + unsafe fn unregister(adrv: &Opaque<Self::DriverType>) { + // SAFETY: `adrv` is guaranteed to be a valid `DriverType`. unsafe { bindings::auxiliary_driver_unregister(adrv.get()) } } } @@ -81,13 +101,15 @@ impl<T: Driver + 'static> Adapter<T> { // SAFETY: The auxiliary bus only ever calls the probe callback with a valid pointer to a // `struct auxiliary_device`. // - // INVARIANT: `adev` is valid for the duration of `probe_callback()`. + // INVARIANT: `adev` is valid for the duration of `remove_callback()`. let adev = unsafe { &*adev.cast::<Device<device::CoreInternal>>() }; // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called // and stored a `Pin<KBox<T>>`. - drop(unsafe { adev.as_ref().drvdata_obtain::<T>() }); + let data = unsafe { adev.as_ref().drvdata_borrow::<T>() }; + + T::unbind(adev, data); } } @@ -110,12 +132,7 @@ impl DeviceId { let name = name.to_bytes_with_nul(); let modname = modname.to_bytes_with_nul(); - // TODO: Replace with `bindings::auxiliary_device_id::default()` once stabilized for - // `const`. - // - // SAFETY: FFI type is valid to be zero-initialized. - let mut id: bindings::auxiliary_device_id = unsafe { core::mem::zeroed() }; - + let mut id: bindings::auxiliary_device_id = pin_init::zeroed(); let mut i = 0; while i < modname.len() { id.name[i] = modname[i]; @@ -187,6 +204,20 @@ pub trait Driver { /// /// Called when an auxiliary device is matches a corresponding driver. fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> impl PinInit<Self, Error>; + + /// Auxiliary driver unbind. + /// + /// Called when a [`Device`] is unbound from its bound [`Driver`]. Implementing this callback + /// is optional. + /// + /// This callback serves as a place for drivers to perform teardown operations that require a + /// `&Device<Core>` or `&Device<Bound>` reference. For instance, drivers may try to perform I/O + /// operations to gracefully tear down the device. + /// + /// Otherwise, release operations for driver resources should be performed in `Self::drop`. + fn unbind(dev: &Device<device::Core>, this: Pin<&Self>) { + let _ = (dev, this); + } } /// The auxiliary device representation. diff --git a/rust/kernel/bits.rs b/rust/kernel/bits.rs index 553d50265883..2daead125626 100644 --- a/rust/kernel/bits.rs +++ b/rust/kernel/bits.rs @@ -27,7 +27,8 @@ macro_rules! impl_bit_fn { /// /// This version is the default and should be used if `n` is known at /// compile time. - #[inline] + // Always inline to optimize out error path of `build_assert`. + #[inline(always)] pub const fn [<bit_ $ty>](n: u32) -> $ty { build_assert!(n < <$ty>::BITS); (1 as $ty) << n @@ -75,7 +76,8 @@ macro_rules! impl_genmask_fn { /// This version is the default and should be used if the range is known /// at compile time. $(#[$genmask_ex])* - #[inline] + // Always inline to optimize out error path of `build_assert`. 
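// Editor's sketch, not part of the patch: the rule behind the `#[inline(always)]` change above,
// as also documented in the build_assert.rs hunk further down. A helper whose `build_assert!`
// depends on one of its arguments must be annotated `#[inline(always)]` so the error path can be
// optimized out at every call site (assumes `kernel::prelude::*` is in scope).
#[inline(always)]
fn check_width(n: u32) {
    // Compile-time check; `n` must be constant-foldable at each call site.
    build_assert!(n <= u32::BITS);
}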
+ #[inline(always)] pub const fn [<genmask_ $ty>](range: RangeInclusive<u32>) -> $ty { let start = *range.start(); let end = *range.end(); diff --git a/rust/kernel/block/mq/gen_disk.rs b/rust/kernel/block/mq/gen_disk.rs index 1ce815c8cdab..c8b0ecb17082 100644 --- a/rust/kernel/block/mq/gen_disk.rs +++ b/rust/kernel/block/mq/gen_disk.rs @@ -107,8 +107,7 @@ impl GenDiskBuilder { drop(unsafe { T::QueueData::from_foreign(data) }); }); - // SAFETY: `bindings::queue_limits` contain only fields that are valid when zeroed. - let mut lim: bindings::queue_limits = unsafe { core::mem::zeroed() }; + let mut lim: bindings::queue_limits = pin_init::zeroed(); lim.logical_block_size = self.logical_block_size; lim.physical_block_size = self.physical_block_size; diff --git a/rust/kernel/block/mq/tag_set.rs b/rust/kernel/block/mq/tag_set.rs index c3cf56d52bee..dae9df408a86 100644 --- a/rust/kernel/block/mq/tag_set.rs +++ b/rust/kernel/block/mq/tag_set.rs @@ -38,9 +38,7 @@ impl<T: Operations> TagSet<T> { num_tags: u32, num_maps: u32, ) -> impl PinInit<Self, error::Error> { - // SAFETY: `blk_mq_tag_set` only contains integers and pointers, which - // all are allowed to be 0. - let tag_set: bindings::blk_mq_tag_set = unsafe { core::mem::zeroed() }; + let tag_set: bindings::blk_mq_tag_set = pin_init::zeroed(); let tag_set: Result<_> = core::mem::size_of::<RequestDataWrapper>() .try_into() .map(|cmd_size| { diff --git a/rust/kernel/bug.rs b/rust/kernel/bug.rs index 36aef43e5ebe..ed943960f851 100644 --- a/rust/kernel/bug.rs +++ b/rust/kernel/bug.rs @@ -11,9 +11,9 @@ #[cfg(all(CONFIG_BUG, not(CONFIG_UML), not(CONFIG_LOONGARCH), not(CONFIG_ARM)))] #[cfg(CONFIG_DEBUG_BUGVERBOSE)] macro_rules! warn_flags { - ($flags:expr) => { + ($file:expr, $flags:expr) => { const FLAGS: u32 = $crate::bindings::BUGFLAG_WARNING | $flags; - const _FILE: &[u8] = file!().as_bytes(); + const _FILE: &[u8] = $file.as_bytes(); // Plus one for null-terminator. static FILE: [u8; _FILE.len() + 1] = { let mut bytes = [0; _FILE.len() + 1]; @@ -50,7 +50,7 @@ macro_rules! warn_flags { #[cfg(all(CONFIG_BUG, not(CONFIG_UML), not(CONFIG_LOONGARCH), not(CONFIG_ARM)))] #[cfg(not(CONFIG_DEBUG_BUGVERBOSE))] macro_rules! warn_flags { - ($flags:expr) => { + ($file:expr, $flags:expr) => { const FLAGS: u32 = $crate::bindings::BUGFLAG_WARNING | $flags; // SAFETY: @@ -75,7 +75,7 @@ macro_rules! warn_flags { #[doc(hidden)] #[cfg(all(CONFIG_BUG, CONFIG_UML))] macro_rules! warn_flags { - ($flags:expr) => { + ($file:expr, $flags:expr) => { // SAFETY: It is always safe to call `warn_slowpath_fmt()` // with a valid null-terminated string. unsafe { @@ -93,7 +93,7 @@ macro_rules! warn_flags { #[doc(hidden)] #[cfg(all(CONFIG_BUG, any(CONFIG_LOONGARCH, CONFIG_ARM)))] macro_rules! warn_flags { - ($flags:expr) => { + ($file:expr, $flags:expr) => { // SAFETY: It is always safe to call `WARN_ON()`. unsafe { $crate::bindings::WARN_ON(true) } }; @@ -103,7 +103,7 @@ macro_rules! warn_flags { #[doc(hidden)] #[cfg(not(CONFIG_BUG))] macro_rules! warn_flags { - ($flags:expr) => {}; + ($file:expr, $flags:expr) => {}; } #[doc(hidden)] @@ -116,10 +116,16 @@ pub const fn bugflag_taint(value: u32) -> u32 { macro_rules! 
warn_on { ($cond:expr) => {{ let cond = $cond; + + #[cfg(CONFIG_DEBUG_BUGVERBOSE_DETAILED)] + const _COND_STR: &str = concat!("[", stringify!($cond), "] ", file!()); + #[cfg(not(CONFIG_DEBUG_BUGVERBOSE_DETAILED))] + const _COND_STR: &str = file!(); + if cond { const WARN_ON_FLAGS: u32 = $crate::bug::bugflag_taint($crate::bindings::TAINT_WARN); - $crate::warn_flags!(WARN_ON_FLAGS); + $crate::warn_flags!(_COND_STR, WARN_ON_FLAGS); } cond }}; diff --git a/rust/kernel/build_assert.rs b/rust/kernel/build_assert.rs index 6331b15d7c4d..f8124dbc663f 100644 --- a/rust/kernel/build_assert.rs +++ b/rust/kernel/build_assert.rs @@ -61,8 +61,13 @@ macro_rules! build_error { /// build_assert!(N > 1); // Build-time check /// assert!(N > 1); // Run-time check /// } +/// ``` /// -/// #[inline] +/// When a condition depends on a function argument, the function must be annotated with +/// `#[inline(always)]`. Without this attribute, the compiler may choose to not inline the +/// function, preventing it from optimizing out the error path. +/// ``` +/// #[inline(always)] /// fn bar(n: usize) { /// // `static_assert!(n > 1);` is not allowed /// build_assert!(n > 1); // Build-time check diff --git a/rust/kernel/clk.rs b/rust/kernel/clk.rs index c1cfaeaa36a2..4059aff34d09 100644 --- a/rust/kernel/clk.rs +++ b/rust/kernel/clk.rs @@ -94,7 +94,7 @@ mod common_clk { /// # Invariants /// /// A [`Clk`] instance holds either a pointer to a valid [`struct clk`] created by the C - /// portion of the kernel or a NULL pointer. + /// portion of the kernel or a `NULL` pointer. /// /// Instances of this type are reference-counted. Calling [`Clk::get`] ensures that the /// allocation remains valid for the lifetime of the [`Clk`]. @@ -104,13 +104,12 @@ mod common_clk { /// The following example demonstrates how to obtain and configure a clock for a device. /// /// ``` - /// use kernel::c_str; /// use kernel::clk::{Clk, Hertz}; /// use kernel::device::Device; /// use kernel::error::Result; /// /// fn configure_clk(dev: &Device) -> Result { - /// let clk = Clk::get(dev, Some(c_str!("apb_clk")))?; + /// let clk = Clk::get(dev, Some(c"apb_clk"))?; /// /// clk.prepare_enable()?; /// @@ -272,13 +271,12 @@ mod common_clk { /// device. The code functions correctly whether or not the clock is available. /// /// ``` - /// use kernel::c_str; /// use kernel::clk::{OptionalClk, Hertz}; /// use kernel::device::Device; /// use kernel::error::Result; /// /// fn configure_clk(dev: &Device) -> Result { - /// let clk = OptionalClk::get(dev, Some(c_str!("apb_clk")))?; + /// let clk = OptionalClk::get(dev, Some(c"apb_clk"))?; /// /// clk.prepare_enable()?; /// diff --git a/rust/kernel/configfs.rs b/rust/kernel/configfs.rs index 466fb7f40762..2339c6467325 100644 --- a/rust/kernel/configfs.rs +++ b/rust/kernel/configfs.rs @@ -21,7 +21,6 @@ //! //! ```ignore //! use kernel::alloc::flags; -//! use kernel::c_str; //! use kernel::configfs_attrs; //! use kernel::configfs; //! use kernel::new_mutex; @@ -50,7 +49,7 @@ //! //! try_pin_init!(Self { //! config <- configfs::Subsystem::new( -//! c_str!("rust_configfs"), item_type, Configuration::new() +//! c"rust_configfs", item_type, Configuration::new() //! ), //! }) //! } @@ -66,7 +65,7 @@ //! impl Configuration { //! fn new() -> impl PinInit<Self, Error> { //! try_pin_init!(Self { -//! message: c_str!("Hello World\n"), +//! message: c"Hello World\n", //! bar <- new_mutex!((KBox::new([0; PAGE_SIZE], flags::GFP_KERNEL)?, 0)), //! }) //! } @@ -1000,7 +999,9 @@ macro_rules! 
configfs_attrs { static [< $data:upper _ $name:upper _ATTR >]: $crate::configfs::Attribute<$attr, $data, $data> = unsafe { - $crate::configfs::Attribute::new(c_str!(::core::stringify!($name))) + $crate::configfs::Attribute::new( + $crate::c_str!(::core::stringify!($name)), + ) }; )* diff --git a/rust/kernel/cpufreq.rs b/rust/kernel/cpufreq.rs index f968fbd22890..76faa1ac8501 100644 --- a/rust/kernel/cpufreq.rs +++ b/rust/kernel/cpufreq.rs @@ -840,7 +840,6 @@ pub trait Driver { /// ``` /// use kernel::{ /// cpufreq, -/// c_str, /// device::{Core, Device}, /// macros::vtable, /// of, platform, @@ -853,7 +852,7 @@ pub trait Driver { /// /// #[vtable] /// impl cpufreq::Driver for SampleDriver { -/// const NAME: &'static CStr = c_str!("cpufreq-sample"); +/// const NAME: &'static CStr = c"cpufreq-sample"; /// const FLAGS: u16 = cpufreq::flags::NEED_INITIAL_FREQ_CHECK | cpufreq::flags::IS_COOLING_DEV; /// const BOOST_ENABLED: bool = true; /// @@ -1015,6 +1014,8 @@ impl<T: Driver> Registration<T> { ..pin_init::zeroed() }; + // Always inline to optimize out error path of `build_assert`. + #[inline(always)] const fn copy_name(name: &'static CStr) -> [c_char; CPUFREQ_NAME_LEN] { let src = name.to_bytes_with_nul(); let mut dst = [0; CPUFREQ_NAME_LEN]; diff --git a/rust/kernel/cpumask.rs b/rust/kernel/cpumask.rs index c1d17826ae7b..44bb36636ee3 100644 --- a/rust/kernel/cpumask.rs +++ b/rust/kernel/cpumask.rs @@ -39,7 +39,7 @@ use core::ops::{Deref, DerefMut}; /// fn set_clear_cpu(ptr: *mut bindings::cpumask, set_cpu: CpuId, clear_cpu: CpuId) { /// // SAFETY: The `ptr` is valid for writing and remains valid for the lifetime of the /// // returned reference. -/// let mask = unsafe { Cpumask::as_mut_ref(ptr) }; +/// let mask = unsafe { Cpumask::from_raw_mut(ptr) }; /// /// mask.set(set_cpu); /// mask.clear(clear_cpu); @@ -49,13 +49,13 @@ use core::ops::{Deref, DerefMut}; pub struct Cpumask(Opaque<bindings::cpumask>); impl Cpumask { - /// Creates a mutable reference to an existing `struct cpumask` pointer. + /// Creates a mutable reference from an existing `struct cpumask` pointer. /// /// # Safety /// /// The caller must ensure that `ptr` is valid for writing and remains valid for the lifetime /// of the returned reference. - pub unsafe fn as_mut_ref<'a>(ptr: *mut bindings::cpumask) -> &'a mut Self { + pub unsafe fn from_raw_mut<'a>(ptr: *mut bindings::cpumask) -> &'a mut Self { // SAFETY: Guaranteed by the safety requirements of the function. // // INVARIANT: The caller ensures that `ptr` is valid for writing and remains valid for the @@ -63,13 +63,13 @@ impl Cpumask { unsafe { &mut *ptr.cast() } } - /// Creates a reference to an existing `struct cpumask` pointer. + /// Creates a reference from an existing `struct cpumask` pointer. /// /// # Safety /// /// The caller must ensure that `ptr` is valid for reading and remains valid for the lifetime /// of the returned reference. - pub unsafe fn as_ref<'a>(ptr: *const bindings::cpumask) -> &'a Self { + pub unsafe fn from_raw<'a>(ptr: *const bindings::cpumask) -> &'a Self { // SAFETY: Guaranteed by the safety requirements of the function. // // INVARIANT: The caller ensures that `ptr` is valid for reading and remains valid for the diff --git a/rust/kernel/debugfs.rs b/rust/kernel/debugfs.rs index facad81e8290..d7b8014a6474 100644 --- a/rust/kernel/debugfs.rs +++ b/rust/kernel/debugfs.rs @@ -8,28 +8,52 @@ // When DebugFS is disabled, many parameters are dead. Linting for this isn't helpful. 
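// Editor's sketch, not part of the patch: the renamed `Cpumask` accessors from the cpumask.rs
// hunk above, used as in its doc example (assumes `kernel::bindings`, `kernel::cpu::CpuId` and
// `kernel::cpumask::Cpumask` are in scope).
fn mark_cpu(ptr: *mut bindings::cpumask, cpu: CpuId) {
    // SAFETY: the caller guarantees `ptr` is valid for writing and outlives the returned borrow.
    let mask = unsafe { Cpumask::from_raw_mut(ptr) };
    mask.set(cpu);
}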
#![cfg_attr(not(CONFIG_DEBUG_FS), allow(unused_variables))] -use crate::fmt; -use crate::prelude::*; -use crate::str::CStr; #[cfg(CONFIG_DEBUG_FS)] use crate::sync::Arc; -use crate::uaccess::UserSliceReader; -use core::marker::PhantomData; -use core::marker::PhantomPinned; +use crate::{ + fmt, + prelude::*, + str::CStr, + uaccess::UserSliceReader, // +}; + #[cfg(CONFIG_DEBUG_FS)] use core::mem::ManuallyDrop; -use core::ops::Deref; +use core::{ + marker::{ + PhantomData, + PhantomPinned, // + }, + ops::Deref, +}; mod traits; -pub use traits::{BinaryReader, BinaryReaderMut, BinaryWriter, Reader, Writer}; +pub use traits::{ + BinaryReader, + BinaryReaderMut, + BinaryWriter, + Reader, + Writer, // +}; mod callback_adapters; -use callback_adapters::{FormatAdapter, NoWriter, WritableAdapter}; +use callback_adapters::{ + FormatAdapter, + NoWriter, + WritableAdapter, // +}; + mod file_ops; use file_ops::{ - BinaryReadFile, BinaryReadWriteFile, BinaryWriteFile, FileOps, ReadFile, ReadWriteFile, - WriteFile, + BinaryReadFile, + BinaryReadWriteFile, + BinaryWriteFile, + FileOps, + ReadFile, + ReadWriteFile, + WriteFile, // }; + #[cfg(CONFIG_DEBUG_FS)] mod entry; #[cfg(CONFIG_DEBUG_FS)] @@ -102,9 +126,8 @@ impl Dir { /// # Examples /// /// ``` - /// # use kernel::c_str; /// # use kernel::debugfs::Dir; - /// let debugfs = Dir::new(c_str!("parent")); + /// let debugfs = Dir::new(c"parent"); /// ``` pub fn new(name: &CStr) -> Self { Dir::create(name, None) @@ -115,10 +138,9 @@ impl Dir { /// # Examples /// /// ``` - /// # use kernel::c_str; /// # use kernel::debugfs::Dir; - /// let parent = Dir::new(c_str!("parent")); - /// let child = parent.subdir(c_str!("child")); + /// let parent = Dir::new(c"parent"); + /// let child = parent.subdir(c"child"); /// ``` pub fn subdir(&self, name: &CStr) -> Self { Dir::create(name, Some(self)) @@ -132,11 +154,10 @@ impl Dir { /// # Examples /// /// ``` - /// # use kernel::c_str; /// # use kernel::debugfs::Dir; /// # use kernel::prelude::*; - /// # let dir = Dir::new(c_str!("my_debugfs_dir")); - /// let file = KBox::pin_init(dir.read_only_file(c_str!("foo"), 200), GFP_KERNEL)?; + /// # let dir = Dir::new(c"my_debugfs_dir"); + /// let file = KBox::pin_init(dir.read_only_file(c"foo", 200), GFP_KERNEL)?; /// // "my_debugfs_dir/foo" now contains the number 200. /// // The file is removed when `file` is dropped. 
/// # Ok::<(), Error>(()) @@ -161,11 +182,10 @@ impl Dir { /// # Examples /// /// ``` - /// # use kernel::c_str; /// # use kernel::debugfs::Dir; /// # use kernel::prelude::*; - /// # let dir = Dir::new(c_str!("my_debugfs_dir")); - /// let file = KBox::pin_init(dir.read_binary_file(c_str!("foo"), [0x1, 0x2]), GFP_KERNEL)?; + /// # let dir = Dir::new(c"my_debugfs_dir"); + /// let file = KBox::pin_init(dir.read_binary_file(c"foo", [0x1, 0x2]), GFP_KERNEL)?; /// # Ok::<(), Error>(()) /// ``` pub fn read_binary_file<'a, T, E: 'a>( @@ -187,21 +207,25 @@ impl Dir { /// # Examples /// /// ``` - /// # use core::sync::atomic::{AtomicU32, Ordering}; - /// # use kernel::c_str; - /// # use kernel::debugfs::Dir; - /// # use kernel::prelude::*; - /// # let dir = Dir::new(c_str!("foo")); + /// # use kernel::{ + /// # debugfs::Dir, + /// # prelude::*, + /// # sync::atomic::{ + /// # Atomic, + /// # Relaxed, + /// # }, + /// # }; + /// # let dir = Dir::new(c"foo"); /// let file = KBox::pin_init( - /// dir.read_callback_file(c_str!("bar"), - /// AtomicU32::new(3), + /// dir.read_callback_file(c"bar", + /// Atomic::<u32>::new(3), /// &|val, f| { - /// let out = val.load(Ordering::Relaxed); + /// let out = val.load(Relaxed); /// writeln!(f, "{out:#010x}") /// }), /// GFP_KERNEL)?; /// // Reading "foo/bar" will show "0x00000003". - /// file.store(10, Ordering::Relaxed); + /// file.store(10, Relaxed); /// // Reading "foo/bar" will now show "0x0000000a". /// # Ok::<(), Error>(()) /// ``` diff --git a/rust/kernel/debugfs/callback_adapters.rs b/rust/kernel/debugfs/callback_adapters.rs index a260d8dee051..dee7d021e18c 100644 --- a/rust/kernel/debugfs/callback_adapters.rs +++ b/rust/kernel/debugfs/callback_adapters.rs @@ -4,12 +4,21 @@ //! Adapters which allow the user to supply a write or read implementation as a value rather //! than a trait implementation. If provided, it will override the trait implementation. -use super::{Reader, Writer}; -use crate::fmt; -use crate::prelude::*; -use crate::uaccess::UserSliceReader; -use core::marker::PhantomData; -use core::ops::Deref; +use super::{ + Reader, + Writer, // +}; + +use crate::{ + fmt, + prelude::*, + uaccess::UserSliceReader, // +}; + +use core::{ + marker::PhantomData, + ops::Deref, // +}; /// # Safety /// diff --git a/rust/kernel/debugfs/entry.rs b/rust/kernel/debugfs/entry.rs index 706cb7f73d6c..46aad64896ec 100644 --- a/rust/kernel/debugfs/entry.rs +++ b/rust/kernel/debugfs/entry.rs @@ -1,10 +1,16 @@ // SPDX-License-Identifier: GPL-2.0 // Copyright (C) 2025 Google LLC. -use crate::debugfs::file_ops::FileOps; -use crate::ffi::c_void; -use crate::str::{CStr, CStrExt as _}; -use crate::sync::Arc; +use crate::{ + debugfs::file_ops::FileOps, + prelude::*, + str::{ + CStr, + CStrExt as _, // + }, + sync::Arc, +}; + use core::marker::PhantomData; /// Owning handle to a DebugFS entry. @@ -148,7 +154,7 @@ impl Entry<'_> { /// # Guarantees /// /// Due to the type invariant, the value returned from this function will always be an error - /// code, NULL, or a live DebugFS directory. If it is live, it will remain live at least as + /// code, `NULL`, or a live DebugFS directory. If it is live, it will remain live at least as /// long as this entry lives. 
pub(crate) fn as_ptr(&self) -> *mut bindings::dentry { self.entry diff --git a/rust/kernel/debugfs/file_ops.rs b/rust/kernel/debugfs/file_ops.rs index 8a0442d6dd7a..f15908f71c4a 100644 --- a/rust/kernel/debugfs/file_ops.rs +++ b/rust/kernel/debugfs/file_ops.rs @@ -1,14 +1,23 @@ // SPDX-License-Identifier: GPL-2.0 // Copyright (C) 2025 Google LLC. -use super::{BinaryReader, BinaryWriter, Reader, Writer}; -use crate::debugfs::callback_adapters::Adapter; -use crate::fmt; -use crate::fs::file; -use crate::prelude::*; -use crate::seq_file::SeqFile; -use crate::seq_print; -use crate::uaccess::UserSlice; +use super::{ + BinaryReader, + BinaryWriter, + Reader, + Writer, // +}; + +use crate::{ + debugfs::callback_adapters::Adapter, + fmt, + fs::file, + prelude::*, + seq_file::SeqFile, + seq_print, + uaccess::UserSlice, // +}; + use core::marker::PhantomData; #[cfg(CONFIG_DEBUG_FS)] @@ -126,8 +135,7 @@ impl<T: Writer + Sync> ReadFile<T> for T { llseek: Some(bindings::seq_lseek), release: Some(bindings::single_release), open: Some(writer_open::<Self>), - // SAFETY: `file_operations` supports zeroes in all fields. - ..unsafe { core::mem::zeroed() } + ..pin_init::zeroed() }; // SAFETY: `operations` is all stock `seq_file` implementations except for `writer_open`. // `open`'s only requirement beyond what is provided to all open functions is that the @@ -179,8 +187,7 @@ impl<T: Writer + Reader + Sync> ReadWriteFile<T> for T { write: Some(write::<T>), llseek: Some(bindings::seq_lseek), release: Some(bindings::single_release), - // SAFETY: `file_operations` supports zeroes in all fields. - ..unsafe { core::mem::zeroed() } + ..pin_init::zeroed() }; // SAFETY: `operations` is all stock `seq_file` implementations except for `writer_open` // and `write`. @@ -235,8 +242,7 @@ impl<T: Reader + Sync> WriteFile<T> for T { open: Some(write_only_open), write: Some(write_only_write::<T>), llseek: Some(bindings::noop_llseek), - // SAFETY: `file_operations` supports zeroes in all fields. - ..unsafe { core::mem::zeroed() } + ..pin_init::zeroed() }; // SAFETY: // * `write_only_open` populates the file private data with the inode private data @@ -288,8 +294,7 @@ impl<T: BinaryWriter + Sync> BinaryReadFile<T> for T { read: Some(blob_read::<T>), llseek: Some(bindings::default_llseek), open: Some(bindings::simple_open), - // SAFETY: `file_operations` supports zeroes in all fields. - ..unsafe { core::mem::zeroed() } + ..pin_init::zeroed() }; // SAFETY: @@ -343,8 +348,7 @@ impl<T: BinaryReader + Sync> BinaryWriteFile<T> for T { write: Some(blob_write::<T>), llseek: Some(bindings::default_llseek), open: Some(bindings::simple_open), - // SAFETY: `file_operations` supports zeroes in all fields. - ..unsafe { core::mem::zeroed() } + ..pin_init::zeroed() }; // SAFETY: @@ -369,8 +373,7 @@ impl<T: BinaryWriter + BinaryReader + Sync> BinaryReadWriteFile<T> for T { write: Some(blob_write::<T>), llseek: Some(bindings::default_llseek), open: Some(bindings::simple_open), - // SAFETY: `file_operations` supports zeroes in all fields. - ..unsafe { core::mem::zeroed() } + ..pin_init::zeroed() }; // SAFETY: diff --git a/rust/kernel/debugfs/traits.rs b/rust/kernel/debugfs/traits.rs index 3eee60463fd5..8c39524b6a99 100644 --- a/rust/kernel/debugfs/traits.rs +++ b/rust/kernel/debugfs/traits.rs @@ -3,17 +3,38 @@ //! Traits for rendering or updating values exported to DebugFS. 
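// Editor's sketch, not part of the patch: the `..pin_init::zeroed()` idiom that replaces
// `unsafe { core::mem::zeroed() }` throughout this series, shown here for a zero-initialized
// `struct file_operations` as in the file_ops.rs hunk above (assumes `kernel::bindings` is in
// scope and, as the patch relies on, that the remaining fields are valid when zeroed).
fn stub_fops() -> bindings::file_operations {
    bindings::file_operations {
        open: Some(bindings::simple_open),
        llseek: Some(bindings::noop_llseek),
        ..pin_init::zeroed()
    }
}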
-use crate::alloc::Allocator; -use crate::fmt; -use crate::fs::file; -use crate::prelude::*; -use crate::sync::atomic::{Atomic, AtomicBasicOps, AtomicType, Relaxed}; -use crate::sync::Arc; -use crate::sync::Mutex; -use crate::transmute::{AsBytes, FromBytes}; -use crate::uaccess::{UserSliceReader, UserSliceWriter}; -use core::ops::{Deref, DerefMut}; -use core::str::FromStr; +use crate::{ + alloc::Allocator, + fmt, + fs::file, + prelude::*, + sync::{ + atomic::{ + Atomic, + AtomicBasicOps, + AtomicType, + Relaxed, // + }, + Arc, + Mutex, // + }, + transmute::{ + AsBytes, + FromBytes, // + }, + uaccess::{ + UserSliceReader, + UserSliceWriter, // + }, +}; + +use core::{ + ops::{ + Deref, + DerefMut, // + }, + str::FromStr, +}; /// A trait for types that can be written into a string. /// diff --git a/rust/kernel/device.rs b/rust/kernel/device.rs index c79be2e2bfe3..94e0548e7687 100644 --- a/rust/kernel/device.rs +++ b/rust/kernel/device.rs @@ -5,16 +5,20 @@ //! C header: [`include/linux/device.h`](srctree/include/linux/device.h) use crate::{ - bindings, fmt, + bindings, + fmt, prelude::*, sync::aref::ARef, - types::{ForeignOwnable, Opaque}, + types::{ + ForeignOwnable, + Opaque, // + }, // +}; +use core::{ + any::TypeId, + marker::PhantomData, + ptr, // }; -use core::{any::TypeId, marker::PhantomData, ptr}; - -#[cfg(CONFIG_PRINTK)] -use crate::c_str; -use crate::str::CStrExt as _; pub mod property; @@ -67,8 +71,9 @@ static_assert!(core::mem::size_of::<bindings::driver_type>() >= core::mem::size_ /// /// # Implementing Bus Devices /// -/// This section provides a guideline to implement bus specific devices, such as [`pci::Device`] or -/// [`platform::Device`]. +/// This section provides a guideline to implement bus specific devices, such as: +#[cfg_attr(CONFIG_PCI, doc = "* [`pci::Device`](kernel::pci::Device)")] +/// * [`platform::Device`] /// /// A bus specific device should be defined as follows. /// @@ -158,9 +163,8 @@ static_assert!(core::mem::size_of::<bindings::driver_type>() >= core::mem::size_ /// `bindings::device::release` is valid to be called from any thread, hence `ARef<Device>` can be /// dropped from any thread. /// -/// [`AlwaysRefCounted`]: kernel::types::AlwaysRefCounted +/// [`AlwaysRefCounted`]: kernel::sync::aref::AlwaysRefCounted /// [`impl_device_context_deref`]: kernel::impl_device_context_deref -/// [`pci::Device`]: kernel::pci::Device /// [`platform::Device`]: kernel::platform::Device #[repr(transparent)] pub struct Device<Ctx: DeviceContext = Normal>(Opaque<bindings::device>, PhantomData<Ctx>); @@ -233,30 +237,32 @@ impl Device<CoreInternal> { /// /// # Safety /// - /// - Must only be called once after a preceding call to [`Device::set_drvdata`]. /// - The type `T` must match the type of the `ForeignOwnable` previously stored by /// [`Device::set_drvdata`]. - pub unsafe fn drvdata_obtain<T: 'static>(&self) -> Pin<KBox<T>> { + pub(crate) unsafe fn drvdata_obtain<T: 'static>(&self) -> Option<Pin<KBox<T>>> { // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. let ptr = unsafe { bindings::dev_get_drvdata(self.as_raw()) }; // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. unsafe { bindings::dev_set_drvdata(self.as_raw(), core::ptr::null_mut()) }; + if ptr.is_null() { + return None; + } + // SAFETY: - // - By the safety requirements of this function, `ptr` comes from a previous call to - // `into_foreign()`. + // - If `ptr` is not NULL, it comes from a previous call to `into_foreign()`. 
// - `dev_get_drvdata()` guarantees to return the same pointer given to `dev_set_drvdata()` // in `into_foreign()`. - unsafe { Pin::<KBox<T>>::from_foreign(ptr.cast()) } + Some(unsafe { Pin::<KBox<T>>::from_foreign(ptr.cast()) }) } /// Borrow the driver's private data bound to this [`Device`]. /// /// # Safety /// - /// - Must only be called after a preceding call to [`Device::set_drvdata`] and before - /// [`Device::drvdata_obtain`]. + /// - Must only be called after a preceding call to [`Device::set_drvdata`] and before the + /// device is fully unbound. /// - The type `T` must match the type of the `ForeignOwnable` previously stored by /// [`Device::set_drvdata`]. pub unsafe fn drvdata_borrow<T: 'static>(&self) -> Pin<&T> { @@ -272,7 +278,7 @@ impl Device<Bound> { /// # Safety /// /// - Must only be called after a preceding call to [`Device::set_drvdata`] and before - /// [`Device::drvdata_obtain`]. + /// the device is fully unbound. /// - The type `T` must match the type of the `ForeignOwnable` previously stored by /// [`Device::set_drvdata`]. unsafe fn drvdata_unchecked<T: 'static>(&self) -> Pin<&T> { @@ -321,7 +327,7 @@ impl Device<Bound> { // SAFETY: // - The above check of `dev_get_drvdata()` guarantees that we are called after - // `set_drvdata()` and before `drvdata_obtain()`. + // `set_drvdata()`. // - We've just checked that the type of the driver's private data is in fact `T`. Ok(unsafe { self.drvdata_unchecked() }) } @@ -463,7 +469,7 @@ impl<Ctx: DeviceContext> Device<Ctx> { bindings::_dev_printk( klevel.as_ptr().cast::<crate::ffi::c_char>(), self.as_raw(), - c_str!("%pA").as_char_ptr(), + c"%pA".as_char_ptr(), core::ptr::from_ref(&msg).cast::<crate::ffi::c_void>(), ) }; @@ -540,7 +546,7 @@ pub trait DeviceContext: private::Sealed {} /// [`Device<Normal>`]. It is the only [`DeviceContext`] for which it is valid to implement /// [`AlwaysRefCounted`] for. /// -/// [`AlwaysRefCounted`]: kernel::types::AlwaysRefCounted +/// [`AlwaysRefCounted`]: kernel::sync::aref::AlwaysRefCounted pub struct Normal; /// The [`Core`] context is the context of a bus specific device when it appears as argument of @@ -595,6 +601,13 @@ impl DeviceContext for Core {} impl DeviceContext for CoreInternal {} impl DeviceContext for Normal {} +impl<Ctx: DeviceContext> AsRef<Device<Ctx>> for Device<Ctx> { + #[inline] + fn as_ref(&self) -> &Device<Ctx> { + self + } +} + /// Convert device references to bus device references. /// /// Bus devices can implement this trait to allow abstractions to provide the bus device in @@ -714,7 +727,7 @@ macro_rules! impl_device_context_into_aref { macro_rules! dev_printk { ($method:ident, $dev:expr, $($f:tt)*) => { { - ($dev).$method($crate::prelude::fmt!($($f)*)); + $crate::device::Device::$method($dev.as_ref(), $crate::prelude::fmt!($($f)*)) } } } diff --git a/rust/kernel/device/property.rs b/rust/kernel/device/property.rs index 3a332a8c53a9..5aead835fbbc 100644 --- a/rust/kernel/device/property.rs +++ b/rust/kernel/device/property.rs @@ -14,7 +14,8 @@ use crate::{ fmt, prelude::*, str::{CStr, CString}, - types::{ARef, Opaque}, + sync::aref::ARef, + types::Opaque, }; /// A reference-counted fwnode_handle. 
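// Editor's sketch, not part of the patch: with the blanket `AsRef<Device<Ctx>>` impl and the
// reworked `dev_printk!` above, the `dev_*!` macros accept anything that is `AsRef<Device>`,
// including a plain device reference (function name and message are illustrative only).
fn report(dev: &kernel::device::Device) {
    kernel::dev_info!(dev, "probed successfully\n");
}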
@@ -178,11 +179,11 @@ impl FwNode { /// # Examples /// /// ``` - /// # use kernel::{c_str, device::{Device, property::FwNode}, str::CString}; + /// # use kernel::{device::{Device, property::FwNode}, str::CString}; /// fn examples(dev: &Device) -> Result { /// let fwnode = dev.fwnode().ok_or(ENOENT)?; - /// let b: u32 = fwnode.property_read(c_str!("some-number")).required_by(dev)?; - /// if let Some(s) = fwnode.property_read::<CString>(c_str!("some-str")).optional() { + /// let b: u32 = fwnode.property_read(c"some-number").required_by(dev)?; + /// if let Some(s) = fwnode.property_read::<CString>(c"some-str").optional() { /// // ... /// } /// Ok(()) @@ -359,7 +360,7 @@ impl fmt::Debug for FwNodeReferenceArgs { } // SAFETY: Instances of `FwNode` are always reference-counted. -unsafe impl crate::types::AlwaysRefCounted for FwNode { +unsafe impl crate::sync::aref::AlwaysRefCounted for FwNode { fn inc_ref(&self) { // SAFETY: The existence of a shared reference guarantees that the // refcount is non-zero. diff --git a/rust/kernel/device_id.rs b/rust/kernel/device_id.rs index 62c42da12e9d..8e9721446014 100644 --- a/rust/kernel/device_id.rs +++ b/rust/kernel/device_id.rs @@ -15,7 +15,7 @@ use core::mem::MaybeUninit; /// # Safety /// /// Implementers must ensure that `Self` is layout-compatible with [`RawDeviceId::RawType`]; -/// i.e. it's safe to transmute to `RawDeviceId`. +/// i.e. it's safe to transmute to `RawType`. /// /// This requirement is needed so `IdArray::new` can convert `Self` to `RawType` when building /// the ID table. diff --git a/rust/kernel/devres.rs b/rust/kernel/devres.rs index 835d9c11948e..6afe196be42c 100644 --- a/rust/kernel/devres.rs +++ b/rust/kernel/devres.rs @@ -8,30 +8,24 @@ use crate::{ alloc::Flags, bindings, - device::{Bound, Device}, - error::{to_result, Error, Result}, - ffi::c_void, + device::{ + Bound, + Device, // + }, + error::to_result, prelude::*, - revocable::{Revocable, RevocableGuard}, - sync::{aref::ARef, rcu, Completion}, - types::{ForeignOwnable, Opaque, ScopeGuard}, + revocable::{ + Revocable, + RevocableGuard, // + }, + sync::{ + aref::ARef, + rcu, + Arc, // + }, + types::ForeignOwnable, }; -use pin_init::Wrapper; - -/// [`Devres`] inner data accessed from [`Devres::callback`]. -#[pin_data] -struct Inner<T: Send> { - #[pin] - data: Revocable<T>, - /// Tracks whether [`Devres::callback`] has been completed. - #[pin] - devm: Completion, - /// Tracks whether revoking [`Self::data`] has been completed. - #[pin] - revoke: Completion, -} - /// This abstraction is meant to be used by subsystems to containerize [`Device`] bound resources to /// manage their lifetime. /// @@ -61,14 +55,17 @@ struct Inner<T: Send> { /// devres::Devres, /// io::{ /// Io, -/// IoRaw, -/// PhysAddr, +/// IoKnownSize, +/// Mmio, +/// MmioRaw, +/// PhysAddr, // /// }, +/// prelude::*, /// }; /// use core::ops::Deref; /// /// // See also [`pci::Bar`] for a real example. 
-/// struct IoMem<const SIZE: usize>(IoRaw<SIZE>); +/// struct IoMem<const SIZE: usize>(MmioRaw<SIZE>); /// /// impl<const SIZE: usize> IoMem<SIZE> { /// /// # Safety @@ -83,7 +80,7 @@ struct Inner<T: Send> { /// return Err(ENOMEM); /// } /// -/// Ok(IoMem(IoRaw::new(addr as usize, SIZE)?)) +/// Ok(IoMem(MmioRaw::new(addr as usize, SIZE)?)) /// } /// } /// @@ -95,28 +92,23 @@ struct Inner<T: Send> { /// } /// /// impl<const SIZE: usize> Deref for IoMem<SIZE> { -/// type Target = Io<SIZE>; +/// type Target = Mmio<SIZE>; /// /// fn deref(&self) -> &Self::Target { /// // SAFETY: The memory range stored in `self` has been properly mapped in `Self::new`. -/// unsafe { Io::from_raw(&self.0) } +/// unsafe { Mmio::from_raw(&self.0) } /// } /// } /// # fn no_run(dev: &Device<Bound>) -> Result<(), Error> { /// // SAFETY: Invalid usage for example purposes. /// let iomem = unsafe { IoMem::<{ core::mem::size_of::<u32>() }>::new(0xBAAAAAAD)? }; -/// let devres = KBox::pin_init(Devres::new(dev, iomem), GFP_KERNEL)?; +/// let devres = Devres::new(dev, iomem)?; /// /// let res = devres.try_access().ok_or(ENXIO)?; /// res.write8(0x42, 0x0); /// # Ok(()) /// # } /// ``` -/// -/// # Invariants -/// -/// `Self::inner` is guaranteed to be initialized and is always accessed read-only. -#[pin_data(PinnedDrop)] pub struct Devres<T: Send> { dev: ARef<Device>, /// Pointer to [`Self::devres_callback`]. @@ -124,14 +116,7 @@ pub struct Devres<T: Send> { /// Has to be stored, since Rust does not guarantee to always return the same address for a /// function. However, the C API uses the address as a key. callback: unsafe extern "C" fn(*mut c_void), - /// Contains all the fields shared with [`Self::callback`]. - // TODO: Replace with `UnsafePinned`, once available. - // - // Subsequently, the `drop_in_place()` in `Devres::drop` and `Devres::new` as well as the - // explicit `Send` and `Sync' impls can be removed. - #[pin] - inner: Opaque<Inner<T>>, - _add_action: (), + data: Arc<Revocable<T>>, } impl<T: Send> Devres<T> { @@ -139,74 +124,48 @@ impl<T: Send> Devres<T> { /// /// The `data` encapsulated within the returned `Devres` instance' `data` will be /// (revoked)[`Revocable`] once the device is detached. - pub fn new<'a, E>( - dev: &'a Device<Bound>, - data: impl PinInit<T, E> + 'a, - ) -> impl PinInit<Self, Error> + 'a + pub fn new<E>(dev: &Device<Bound>, data: impl PinInit<T, E>) -> Result<Self> where - T: 'a, Error: From<E>, { - try_pin_init!(&this in Self { - dev: dev.into(), - callback: Self::devres_callback, - // INVARIANT: `inner` is properly initialized. - inner <- Opaque::pin_init(try_pin_init!(Inner { - devm <- Completion::new(), - revoke <- Completion::new(), - data <- Revocable::new(data), - })), - // TODO: Replace with "initializer code blocks" [1] once available. - // - // [1] https://github.com/Rust-for-Linux/pin-init/pull/69 - _add_action: { - // SAFETY: `this` is a valid pointer to uninitialized memory. - let inner = unsafe { &raw mut (*this.as_ptr()).inner }; + let callback = Self::devres_callback; + let data = Arc::pin_init(Revocable::new(data), GFP_KERNEL)?; + let devres_data = data.clone(); - // SAFETY: - // - `dev.as_raw()` is a pointer to a valid bound device. - // - `inner` is guaranteed to be a valid for the duration of the lifetime of `Self`. - // - `devm_add_action()` is guaranteed not to call `callback` until `this` has been - // properly initialized, because we require `dev` (i.e. the *bound* device) to - // live at least as long as the returned `impl PinInit<Self, Error>`. 
- to_result(unsafe { - bindings::devm_add_action(dev.as_raw(), Some(*callback), inner.cast()) - }).inspect_err(|_| { - let inner = Opaque::cast_into(inner); + // SAFETY: + // - `dev.as_raw()` is a pointer to a valid bound device. + // - `data` is guaranteed to be a valid for the duration of the lifetime of `Self`. + // - `devm_add_action()` is guaranteed not to call `callback` for the entire lifetime of + // `dev`. + to_result(unsafe { + bindings::devm_add_action( + dev.as_raw(), + Some(callback), + Arc::as_ptr(&data).cast_mut().cast(), + ) + })?; - // SAFETY: `inner` is a valid pointer to an `Inner<T>` and valid for both reads - // and writes. - unsafe { core::ptr::drop_in_place(inner) }; - })?; - }, - }) - } + // `devm_add_action()` was successful and has consumed the reference count. + core::mem::forget(devres_data); - fn inner(&self) -> &Inner<T> { - // SAFETY: By the type invairants of `Self`, `inner` is properly initialized and always - // accessed read-only. - unsafe { &*self.inner.get() } + Ok(Self { + dev: dev.into(), + callback, + data, + }) } fn data(&self) -> &Revocable<T> { - &self.inner().data + &self.data } #[allow(clippy::missing_safety_doc)] unsafe extern "C" fn devres_callback(ptr: *mut kernel::ffi::c_void) { - // SAFETY: In `Self::new` we've passed a valid pointer to `Inner` to `devm_add_action()`, - // hence `ptr` must be a valid pointer to `Inner`. - let inner = unsafe { &*ptr.cast::<Inner<T>>() }; - - // Ensure that `inner` can't be used anymore after we signal completion of this callback. - let inner = ScopeGuard::new_with_data(inner, |inner| inner.devm.complete_all()); + // SAFETY: In `Self::new` we've passed a valid pointer of `Revocable<T>` to + // `devm_add_action()`, hence `ptr` must be a valid pointer to `Revocable<T>`. + let data = unsafe { Arc::from_raw(ptr.cast::<Revocable<T>>()) }; - if !inner.data.revoke() { - // If `revoke()` returns false, it means that `Devres::drop` already started revoking - // `data` for us. Hence we have to wait until `Devres::drop` signals that it - // completed revoking `data`. - inner.revoke.wait_for_completion(); - } + data.revoke(); } fn remove_action(&self) -> bool { @@ -218,7 +177,7 @@ impl<T: Send> Devres<T> { bindings::devm_remove_action_nowarn( self.dev.as_raw(), Some(self.callback), - core::ptr::from_ref(self.inner()).cast_mut().cast(), + core::ptr::from_ref(self.data()).cast_mut().cast(), ) } == 0) } @@ -241,8 +200,16 @@ impl<T: Send> Devres<T> { /// # Examples /// /// ```no_run - /// # #![cfg(CONFIG_PCI)] - /// # use kernel::{device::Core, devres::Devres, pci}; + /// #![cfg(CONFIG_PCI)] + /// use kernel::{ + /// device::Core, + /// devres::Devres, + /// io::{ + /// Io, + /// IoKnownSize, // + /// }, + /// pci, // + /// }; /// /// fn from_core(dev: &pci::Device<Core>, devres: Devres<pci::Bar<0x4>>) -> Result { /// let bar = devres.access(dev.as_ref())?; @@ -289,31 +256,19 @@ unsafe impl<T: Send> Send for Devres<T> {} // SAFETY: `Devres` can be shared with any task, if `T: Sync`. unsafe impl<T: Send + Sync> Sync for Devres<T> {} -#[pinned_drop] -impl<T: Send> PinnedDrop for Devres<T> { - fn drop(self: Pin<&mut Self>) { +impl<T: Send> Drop for Devres<T> { + fn drop(&mut self) { // SAFETY: When `drop` runs, it is guaranteed that nobody is accessing the revocable data // anymore, hence it is safe not to wait for the grace period to finish. if unsafe { self.data().revoke_nosync() } { // We revoked `self.data` before the devres action did, hence try to remove it. 
- if !self.remove_action() { - // We could not remove the devres action, which means that it now runs concurrently, - // hence signal that `self.data` has been revoked by us successfully. - self.inner().revoke.complete_all(); - - // Wait for `Self::devres_callback` to be done using this object. - self.inner().devm.wait_for_completion(); + if self.remove_action() { + // SAFETY: In `Self::new` we have taken an additional reference count of `self.data` + // for `devm_add_action()`. Since `remove_action()` was successful, we have to drop + // this additional reference count. + drop(unsafe { Arc::from_raw(Arc::as_ptr(&self.data)) }); } - } else { - // `Self::devres_callback` revokes `self.data` for us, hence wait for it to be done - // using this object. - self.inner().devm.wait_for_completion(); } - - // INVARIANT: At this point it is guaranteed that `inner` can't be accessed any more. - // - // SAFETY: `inner` is valid for dropping. - unsafe { core::ptr::drop_in_place(self.inner.get()) }; } } @@ -345,7 +300,13 @@ where /// # Examples /// /// ```no_run -/// use kernel::{device::{Bound, Device}, devres}; +/// use kernel::{ +/// device::{ +/// Bound, +/// Device, // +/// }, +/// devres, // +/// }; /// /// /// Registration of e.g. a class device, IRQ, etc. /// struct Registration; diff --git a/rust/kernel/dma.rs b/rust/kernel/dma.rs index 84d3c67269e8..909d56fd5118 100644 --- a/rust/kernel/dma.rs +++ b/rust/kernel/dma.rs @@ -27,8 +27,9 @@ pub type DmaAddress = bindings::dma_addr_t; /// Trait to be implemented by DMA capable bus devices. /// /// The [`dma::Device`](Device) trait should be implemented by bus specific device representations, -/// where the underlying bus is DMA capable, such as [`pci::Device`](::kernel::pci::Device) or -/// [`platform::Device`](::kernel::platform::Device). +/// where the underlying bus is DMA capable, such as: +#[cfg_attr(CONFIG_PCI, doc = "* [`pci::Device`](kernel::pci::Device)")] +/// * [`platform::Device`](::kernel::platform::Device) pub trait Device: AsRef<device::Device<Core>> { /// Set up the device's DMA streaming addressing capabilities. /// @@ -84,6 +85,23 @@ pub trait Device: AsRef<device::Device<Core>> { bindings::dma_set_mask_and_coherent(self.as_ref().as_raw(), mask.value()) }) } + + /// Set the maximum size of a single DMA segment the device may request. + /// + /// This method is usually called once from `probe()` as soon as the device capabilities are + /// known. + /// + /// # Safety + /// + /// This method must not be called concurrently with any DMA allocation or mapping primitives, + /// such as [`CoherentAllocation::alloc_attrs`]. + unsafe fn dma_set_max_seg_size(&self, size: u32) { + // SAFETY: + // - By the type invariant of `device::Device`, `self.as_ref().as_raw()` is valid. + // - The safety requirement of this function guarantees that there are no concurrent calls + // to DMA allocation and mapping primitives using this parameter. + unsafe { bindings::dma_set_max_seg_size(self.as_ref().as_raw(), size) } + } } /// A DMA mask that holds a bitmask with the lowest `n` bits set. @@ -532,8 +550,6 @@ impl<T: AsBytes + FromBytes> CoherentAllocation<T> { /// /// # Safety /// - /// * Callers must ensure that the device does not read/write to/from memory while the returned - /// slice is live. /// * Callers must ensure that this call does not race with a read or write to the same region /// that overlaps with this write. 
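// Editor's sketch, not part of the patch: probe-time use of the `dma_set_max_seg_size()` helper
// added in the dma.rs hunk above; the generic bound, the chosen limit and the probe-only timing
// are assumptions for illustration.
fn init_dma<D: kernel::dma::Device>(dev: &D) {
    // SAFETY: called from probe(), before any DMA allocation or mapping primitive is used.
    unsafe { dev.dma_set_max_seg_size(u32::MAX) };
}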
/// diff --git a/rust/kernel/driver.rs b/rust/kernel/driver.rs index 9beae2e3d57e..36de8098754d 100644 --- a/rust/kernel/driver.rs +++ b/rust/kernel/driver.rs @@ -33,7 +33,14 @@ //! } //! ``` //! -//! For specific examples see [`auxiliary::Driver`], [`pci::Driver`] and [`platform::Driver`]. +//! For specific examples see: +//! +//! * [`platform::Driver`](kernel::platform::Driver) +#" +)] +#")] //! //! The `probe()` callback should return a `impl PinInit<Self, Error>`, i.e. the driver's private //! data. The bus abstraction should store the pointer in the corresponding bus device. The generic @@ -79,7 +86,6 @@ //! //! For this purpose the generic infrastructure in [`device_id`] should be used. //! -//! [`auxiliary::Driver`]: kernel::auxiliary::Driver //! [`Core`]: device::Core //! [`Device`]: device::Device //! [`Device<Core>`]: device::Device<device::Core> @@ -87,31 +93,53 @@ //! [`DeviceContext`]: device::DeviceContext //! [`device_id`]: kernel::device_id //! [`module_driver`]: kernel::module_driver -//! [`pci::Driver`]: kernel::pci::Driver -//! [`platform::Driver`]: kernel::platform::Driver -use crate::error::{Error, Result}; -use crate::{acpi, device, of, str::CStr, try_pin_init, types::Opaque, ThisModule}; -use core::pin::Pin; -use pin_init::{pin_data, pinned_drop, PinInit}; +use crate::{ + acpi, + device, + of, + prelude::*, + types::Opaque, + ThisModule, // +}; + +/// Trait describing the layout of a specific device driver. +/// +/// This trait describes the layout of a specific driver structure, such as `struct pci_driver` or +/// `struct platform_driver`. +/// +/// # Safety +/// +/// Implementors must guarantee that: +/// - `DriverType` is `repr(C)`, +/// - `DriverData` is the type of the driver's device private data. +/// - `DriverType` embeds a valid `struct device_driver` at byte offset `DEVICE_DRIVER_OFFSET`. +pub unsafe trait DriverLayout { + /// The specific driver type embedding a `struct device_driver`. + type DriverType: Default; + + /// The type of the driver's device private data. + type DriverData; + + /// Byte offset of the embedded `struct device_driver` within `DriverType`. + /// + /// This must correspond exactly to the location of the embedded `struct device_driver` field. + const DEVICE_DRIVER_OFFSET: usize; +} /// The [`RegistrationOps`] trait serves as generic interface for subsystems (e.g., PCI, Platform, /// Amba, etc.) to provide the corresponding subsystem specific implementation to register / -/// unregister a driver of the particular type (`RegType`). +/// unregister a driver of the particular type (`DriverType`). /// -/// For instance, the PCI subsystem would set `RegType` to `bindings::pci_driver` and call +/// For instance, the PCI subsystem would set `DriverType` to `bindings::pci_driver` and call /// `bindings::__pci_register_driver` from `RegistrationOps::register` and /// `bindings::pci_unregister_driver` from `RegistrationOps::unregister`. /// /// # Safety /// -/// A call to [`RegistrationOps::unregister`] for a given instance of `RegType` is only valid if a -/// preceding call to [`RegistrationOps::register`] has been successful. -pub unsafe trait RegistrationOps { - /// The type that holds information about the registration. This is typically a struct defined - /// by the C portion of the kernel. - type RegType: Default; - +/// A call to [`RegistrationOps::unregister`] for a given instance of `DriverType` is only valid if +/// a preceding call to [`RegistrationOps::register`] has been successful. 
+pub unsafe trait RegistrationOps: DriverLayout { /// Registers a driver. /// /// # Safety @@ -119,7 +147,7 @@ pub unsafe trait RegistrationOps { /// On success, `reg` must remain pinned and valid until the matching call to /// [`RegistrationOps::unregister`]. unsafe fn register( - reg: &Opaque<Self::RegType>, + reg: &Opaque<Self::DriverType>, name: &'static CStr, module: &'static ThisModule, ) -> Result; @@ -130,7 +158,7 @@ pub unsafe trait RegistrationOps { /// /// Must only be called after a preceding successful call to [`RegistrationOps::register`] for /// the same `reg`. - unsafe fn unregister(reg: &Opaque<Self::RegType>); + unsafe fn unregister(reg: &Opaque<Self::DriverType>); } /// A [`Registration`] is a generic type that represents the registration of some driver type (e.g. @@ -142,7 +170,7 @@ pub unsafe trait RegistrationOps { #[pin_data(PinnedDrop)] pub struct Registration<T: RegistrationOps> { #[pin] - reg: Opaque<T::RegType>, + reg: Opaque<T::DriverType>, } // SAFETY: `Registration` has no fields or methods accessible via `&Registration`, so it is safe to @@ -153,17 +181,51 @@ unsafe impl<T: RegistrationOps> Sync for Registration<T> {} // any thread, so `Registration` is `Send`. unsafe impl<T: RegistrationOps> Send for Registration<T> {} -impl<T: RegistrationOps> Registration<T> { +impl<T: RegistrationOps + 'static> Registration<T> { + extern "C" fn post_unbind_callback(dev: *mut bindings::device) { + // SAFETY: The driver core only ever calls the post unbind callback with a valid pointer to + // a `struct device`. + // + // INVARIANT: `dev` is valid for the duration of the `post_unbind_callback()`. + let dev = unsafe { &*dev.cast::<device::Device<device::CoreInternal>>() }; + + // `remove()` and all devres callbacks have been completed at this point, hence drop the + // driver's device private data. + // + // SAFETY: By the safety requirements of the `Driver` trait, `T::DriverData` is the + // driver's device private data type. + drop(unsafe { dev.drvdata_obtain::<T::DriverData>() }); + } + + /// Attach generic `struct device_driver` callbacks. + fn callbacks_attach(drv: &Opaque<T::DriverType>) { + let ptr = drv.get().cast::<u8>(); + + // SAFETY: + // - `drv.get()` yields a valid pointer to `Self::DriverType`. + // - Adding `DEVICE_DRIVER_OFFSET` yields the address of the embedded `struct device_driver` + // as guaranteed by the safety requirements of the `Driver` trait. + let base = unsafe { ptr.add(T::DEVICE_DRIVER_OFFSET) }; + + // CAST: `base` points to the offset of the embedded `struct device_driver`. + let base = base.cast::<bindings::device_driver>(); + + // SAFETY: It is safe to set the fields of `struct device_driver` on initialization. + unsafe { (*base).p_cb.post_unbind_rust = Some(Self::post_unbind_callback) }; + } + /// Creates a new instance of the registration object. pub fn new(name: &'static CStr, module: &'static ThisModule) -> impl PinInit<Self, Error> { try_pin_init!(Self { - reg <- Opaque::try_ffi_init(|ptr: *mut T::RegType| { + reg <- Opaque::try_ffi_init(|ptr: *mut T::DriverType| { // SAFETY: `try_ffi_init` guarantees that `ptr` is valid for write. - unsafe { ptr.write(T::RegType::default()) }; + unsafe { ptr.write(T::DriverType::default()) }; // SAFETY: `try_ffi_init` guarantees that `ptr` is valid for write, and it has // just been initialised above, so it's also valid for read. 
- let drv = unsafe { &*(ptr as *const Opaque<T::RegType>) }; + let drv = unsafe { &*(ptr as *const Opaque<T::DriverType>) }; + + Self::callbacks_attach(drv); // SAFETY: `drv` is guaranteed to be pinned until `T::unregister`. unsafe { T::register(drv, name, module) } diff --git a/rust/kernel/drm/driver.rs b/rust/kernel/drm/driver.rs index f30ee4c6245c..e09f977b5b51 100644 --- a/rust/kernel/drm/driver.rs +++ b/rust/kernel/drm/driver.rs @@ -121,7 +121,6 @@ pub trait Driver { pub struct Registration<T: Driver>(ARef<drm::Device<T>>); impl<T: Driver> Registration<T> { - /// Creates a new [`Registration`] and registers it. fn new(drm: &drm::Device<T>, flags: usize) -> Result<Self> { // SAFETY: `drm.as_raw()` is valid by the invariants of `drm::Device`. to_result(unsafe { bindings::drm_dev_register(drm.as_raw(), flags) })?; @@ -129,8 +128,9 @@ impl<T: Driver> Registration<T> { Ok(Self(drm.into())) } - /// Same as [`Registration::new`}, but transfers ownership of the [`Registration`] to - /// [`devres::register`]. + /// Registers a new [`Device`](drm::Device) with userspace. + /// + /// Ownership of the [`Registration`] object is passed to [`devres::register`]. pub fn new_foreign_owned( drm: &drm::Device<T>, dev: &device::Device<device::Bound>, diff --git a/rust/kernel/drm/gem/mod.rs b/rust/kernel/drm/gem/mod.rs index a7f682e95c01..d49a9ba02635 100644 --- a/rust/kernel/drm/gem/mod.rs +++ b/rust/kernel/drm/gem/mod.rs @@ -210,7 +210,7 @@ impl<T: DriverObject> Object<T> { // SAFETY: The arguments are all valid per the type invariants. to_result(unsafe { bindings::drm_gem_object_init(dev.as_raw(), obj.obj.get(), size) })?; - // SAFETY: We never move out of `Self`. + // SAFETY: We will never move out of `Self` as `ARef<Self>` is always treated as pinned. let ptr = KBox::into_raw(unsafe { Pin::into_inner_unchecked(obj) }); // SAFETY: `ptr` comes from `KBox::into_raw` and hence can't be NULL. @@ -253,7 +253,7 @@ impl<T: DriverObject> Object<T> { } // SAFETY: Instances of `Object<T>` are always reference-counted. -unsafe impl<T: DriverObject> crate::types::AlwaysRefCounted for Object<T> { +unsafe impl<T: DriverObject> crate::sync::aref::AlwaysRefCounted for Object<T> { fn inc_ref(&self) { // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero. unsafe { bindings::drm_gem_object_get(self.as_raw()) }; @@ -293,9 +293,7 @@ impl<T: DriverObject> AllocImpl for Object<T> { } pub(super) const fn create_fops() -> bindings::file_operations { - // SAFETY: As by the type invariant, it is safe to initialize `bindings::file_operations` - // zeroed. - let mut fops: bindings::file_operations = unsafe { core::mem::zeroed() }; + let mut fops: bindings::file_operations = pin_init::zeroed(); fops.owner = core::ptr::null_mut(); fops.open = Some(bindings::drm_open); diff --git a/rust/kernel/faux.rs b/rust/kernel/faux.rs index 7fe2dd197e37..43b4974f48cd 100644 --- a/rust/kernel/faux.rs +++ b/rust/kernel/faux.rs @@ -6,8 +6,17 @@ //! //! C header: [`include/linux/device/faux.h`](srctree/include/linux/device/faux.h) -use crate::{bindings, device, error::code::*, prelude::*}; -use core::ptr::{addr_of_mut, null, null_mut, NonNull}; +use crate::{ + bindings, + device, + prelude::*, // +}; +use core::ptr::{ + addr_of_mut, + null, + null_mut, + NonNull, // +}; /// The registration of a faux device. 
/// diff --git a/rust/kernel/fmt.rs b/rust/kernel/fmt.rs index 84d634201d90..1e8725eb44ed 100644 --- a/rust/kernel/fmt.rs +++ b/rust/kernel/fmt.rs @@ -6,7 +6,7 @@ pub use core::fmt::{Arguments, Debug, Error, Formatter, Result, Write}; -/// Internal adapter used to route allow implementations of formatting traits for foreign types. +/// Internal adapter used to route and allow implementations of formatting traits for foreign types. /// /// It is inserted automatically by the [`fmt!`] macro and is not meant to be used directly. /// diff --git a/rust/kernel/i2c.rs b/rust/kernel/i2c.rs index 491e6cc25cf4..bb5b830f48c3 100644 --- a/rust/kernel/i2c.rs +++ b/rust/kernel/i2c.rs @@ -92,13 +92,22 @@ macro_rules! i2c_device_table { /// An adapter for the registration of I2C drivers. pub struct Adapter<T: Driver>(T); -// SAFETY: A call to `unregister` for a given instance of `RegType` is guaranteed to be valid if +// SAFETY: +// - `bindings::i2c_driver` is a C type declared as `repr(C)`. +// - `T` is the type of the driver's device private data. +// - `struct i2c_driver` embeds a `struct device_driver`. +// - `DEVICE_DRIVER_OFFSET` is the correct byte offset to the embedded `struct device_driver`. +unsafe impl<T: Driver + 'static> driver::DriverLayout for Adapter<T> { + type DriverType = bindings::i2c_driver; + type DriverData = T; + const DEVICE_DRIVER_OFFSET: usize = core::mem::offset_of!(Self::DriverType, driver); +} + +// SAFETY: A call to `unregister` for a given instance of `DriverType` is guaranteed to be valid if // a preceding call to `register` has been successful. unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { - type RegType = bindings::i2c_driver; - unsafe fn register( - idrv: &Opaque<Self::RegType>, + idrv: &Opaque<Self::DriverType>, name: &'static CStr, module: &'static ThisModule, ) -> Result { @@ -133,12 +142,12 @@ unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { (*idrv.get()).driver.acpi_match_table = acpi_table; } - // SAFETY: `idrv` is guaranteed to be a valid `RegType`. + // SAFETY: `idrv` is guaranteed to be a valid `DriverType`. to_result(unsafe { bindings::i2c_register_driver(module.0, idrv.get()) }) } - unsafe fn unregister(idrv: &Opaque<Self::RegType>) { - // SAFETY: `idrv` is guaranteed to be a valid `RegType`. + unsafe fn unregister(idrv: &Opaque<Self::DriverType>) { + // SAFETY: `idrv` is guaranteed to be a valid `DriverType`. unsafe { bindings::i2c_del_driver(idrv.get()) } } } @@ -169,9 +178,9 @@ impl<T: Driver + 'static> Adapter<T> { // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `I2cClient::set_drvdata()` has been called // and stored a `Pin<KBox<T>>`. - let data = unsafe { idev.as_ref().drvdata_obtain::<T>() }; + let data = unsafe { idev.as_ref().drvdata_borrow::<T>() }; - T::unbind(idev, data.as_ref()); + T::unbind(idev, data); } extern "C" fn shutdown_callback(idev: *mut bindings::i2c_client) { @@ -181,9 +190,9 @@ impl<T: Driver + 'static> Adapter<T> { // SAFETY: `shutdown_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called // and stored a `Pin<KBox<T>>`. - let data = unsafe { idev.as_ref().drvdata_obtain::<T>() }; + let data = unsafe { idev.as_ref().drvdata_borrow::<T>() }; - T::shutdown(idev, data.as_ref()); + T::shutdown(idev, data); } /// The [`i2c::IdTable`] of the corresponding driver. @@ -253,7 +262,7 @@ macro_rules! 
module_i2c_driver { /// # Example /// ///``` -/// # use kernel::{acpi, bindings, c_str, device::Core, i2c, of}; +/// # use kernel::{acpi, bindings, device::Core, i2c, of}; /// /// struct MyDriver; /// @@ -262,7 +271,7 @@ macro_rules! module_i2c_driver { /// MODULE_ACPI_TABLE, /// <MyDriver as i2c::Driver>::IdInfo, /// [ -/// (acpi::DeviceId::new(c_str!("LNUXBEEF")), ()) +/// (acpi::DeviceId::new(c"LNUXBEEF"), ()) /// ] /// ); /// @@ -271,7 +280,7 @@ macro_rules! module_i2c_driver { /// MODULE_I2C_TABLE, /// <MyDriver as i2c::Driver>::IdInfo, /// [ -/// (i2c::DeviceId::new(c_str!("rust_driver_i2c")), ()) +/// (i2c::DeviceId::new(c"rust_driver_i2c"), ()) /// ] /// ); /// @@ -280,7 +289,7 @@ macro_rules! module_i2c_driver { /// MODULE_OF_TABLE, /// <MyDriver as i2c::Driver>::IdInfo, /// [ -/// (of::DeviceId::new(c_str!("test,device")), ()) +/// (of::DeviceId::new(c"test,device"), ()) /// ] /// ); /// diff --git a/rust/kernel/impl_flags.rs b/rust/kernel/impl_flags.rs new file mode 100644 index 000000000000..e2bd7639da12 --- /dev/null +++ b/rust/kernel/impl_flags.rs @@ -0,0 +1,272 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Bitflag type generator. + +/// Common helper for declaring bitflag and bitmask types. +/// +/// This macro takes as input: +/// - A struct declaration representing a bitmask type +/// (e.g., `pub struct Permissions(u32)`). +/// - An enumeration declaration representing individual bit flags +/// (e.g., `pub enum Permission { ... }`). +/// +/// And generates: +/// - The struct and enum types with appropriate `#[repr]` attributes. +/// - Implementations of common bitflag operators +/// ([`::core::ops::BitOr`], [`::core::ops::BitAnd`], etc.). +/// - Utility methods such as `.contains()` to check flags. +/// +/// # Examples +/// +/// ``` +/// use kernel::impl_flags; +/// +/// impl_flags!( +/// /// Represents multiple permissions. +/// #[derive(Debug, Clone, Default, Copy, PartialEq, Eq)] +/// pub struct Permissions(u32); +/// +/// /// Represents a single permission. +/// #[derive(Debug, Clone, Copy, PartialEq, Eq)] +/// pub enum Permission { +/// /// Read permission. +/// Read = 1 << 0, +/// +/// /// Write permission. +/// Write = 1 << 1, +/// +/// /// Execute permission. +/// Execute = 1 << 2, +/// } +/// ); +/// +/// // Combine multiple permissions using the bitwise OR (`|`) operator. +/// let mut read_write: Permissions = Permission::Read | Permission::Write; +/// assert!(read_write.contains(Permission::Read)); +/// assert!(read_write.contains(Permission::Write)); +/// assert!(!read_write.contains(Permission::Execute)); +/// assert!(read_write.contains_any(Permission::Read | Permission::Execute)); +/// assert!(read_write.contains_all(Permission::Read | Permission::Write)); +/// +/// // Using the bitwise OR assignment (`|=`) operator. +/// read_write |= Permission::Execute; +/// assert!(read_write.contains(Permission::Execute)); +/// +/// // Masking a permission with the bitwise AND (`&`) operator. +/// let read_only: Permissions = read_write & Permission::Read; +/// assert!(read_only.contains(Permission::Read)); +/// assert!(!read_only.contains(Permission::Write)); +/// +/// // Toggling permissions with the bitwise XOR (`^`) operator. +/// let toggled: Permissions = read_only ^ Permission::Read; +/// assert!(!toggled.contains(Permission::Read)); +/// +/// // Inverting permissions with the bitwise NOT (`!`) operator. 
+/// let negated = !read_only; +/// assert!(negated.contains(Permission::Write)); +/// assert!(!negated.contains(Permission::Read)); +/// ``` +#[macro_export] +macro_rules! impl_flags { + ( + $(#[$outer_flags:meta])* + $vis_flags:vis struct $flags:ident($ty:ty); + + $(#[$outer_flag:meta])* + $vis_flag:vis enum $flag:ident { + $( + $(#[$inner_flag:meta])* + $name:ident = $value:expr + ),+ $( , )? + } + ) => { + $(#[$outer_flags])* + #[repr(transparent)] + $vis_flags struct $flags($ty); + + $(#[$outer_flag])* + #[repr($ty)] + $vis_flag enum $flag { + $( + $(#[$inner_flag])* + $name = $value + ),+ + } + + impl ::core::convert::From<$flag> for $flags { + #[inline] + fn from(value: $flag) -> Self { + Self(value as $ty) + } + } + + impl ::core::convert::From<$flags> for $ty { + #[inline] + fn from(value: $flags) -> Self { + value.0 + } + } + + impl ::core::ops::BitOr for $flags { + type Output = Self; + #[inline] + fn bitor(self, rhs: Self) -> Self::Output { + Self(self.0 | rhs.0) + } + } + + impl ::core::ops::BitOrAssign for $flags { + #[inline] + fn bitor_assign(&mut self, rhs: Self) { + *self = *self | rhs; + } + } + + impl ::core::ops::BitOr<$flag> for $flags { + type Output = Self; + #[inline] + fn bitor(self, rhs: $flag) -> Self::Output { + self | Self::from(rhs) + } + } + + impl ::core::ops::BitOrAssign<$flag> for $flags { + #[inline] + fn bitor_assign(&mut self, rhs: $flag) { + *self = *self | rhs; + } + } + + impl ::core::ops::BitAnd for $flags { + type Output = Self; + #[inline] + fn bitand(self, rhs: Self) -> Self::Output { + Self(self.0 & rhs.0) + } + } + + impl ::core::ops::BitAndAssign for $flags { + #[inline] + fn bitand_assign(&mut self, rhs: Self) { + *self = *self & rhs; + } + } + + impl ::core::ops::BitAnd<$flag> for $flags { + type Output = Self; + #[inline] + fn bitand(self, rhs: $flag) -> Self::Output { + self & Self::from(rhs) + } + } + + impl ::core::ops::BitAndAssign<$flag> for $flags { + #[inline] + fn bitand_assign(&mut self, rhs: $flag) { + *self = *self & rhs; + } + } + + impl ::core::ops::BitXor for $flags { + type Output = Self; + #[inline] + fn bitxor(self, rhs: Self) -> Self::Output { + Self((self.0 ^ rhs.0) & Self::all_bits()) + } + } + + impl ::core::ops::BitXorAssign for $flags { + #[inline] + fn bitxor_assign(&mut self, rhs: Self) { + *self = *self ^ rhs; + } + } + + impl ::core::ops::BitXor<$flag> for $flags { + type Output = Self; + #[inline] + fn bitxor(self, rhs: $flag) -> Self::Output { + self ^ Self::from(rhs) + } + } + + impl ::core::ops::BitXorAssign<$flag> for $flags { + #[inline] + fn bitxor_assign(&mut self, rhs: $flag) { + *self = *self ^ rhs; + } + } + + impl ::core::ops::Not for $flags { + type Output = Self; + #[inline] + fn not(self) -> Self::Output { + Self((!self.0) & Self::all_bits()) + } + } + + impl ::core::ops::BitOr for $flag { + type Output = $flags; + #[inline] + fn bitor(self, rhs: Self) -> Self::Output { + $flags(self as $ty | rhs as $ty) + } + } + + impl ::core::ops::BitAnd for $flag { + type Output = $flags; + #[inline] + fn bitand(self, rhs: Self) -> Self::Output { + $flags(self as $ty & rhs as $ty) + } + } + + impl ::core::ops::BitXor for $flag { + type Output = $flags; + #[inline] + fn bitxor(self, rhs: Self) -> Self::Output { + $flags((self as $ty ^ rhs as $ty) & $flags::all_bits()) + } + } + + impl ::core::ops::Not for $flag { + type Output = $flags; + #[inline] + fn not(self) -> Self::Output { + $flags((!(self as $ty)) & $flags::all_bits()) + } + } + + impl $flags { + /// Returns an empty instance where no flags are set. 
+ #[inline] + pub const fn empty() -> Self { + Self(0) + } + + /// Returns a mask containing all valid flag bits. + #[inline] + pub const fn all_bits() -> $ty { + 0 $( | $value )+ + } + + /// Checks if a specific flag is set. + #[inline] + pub fn contains(self, flag: $flag) -> bool { + (self.0 & flag as $ty) == flag as $ty + } + + /// Checks if at least one of the provided flags is set. + #[inline] + pub fn contains_any(self, flags: $flags) -> bool { + (self.0 & flags.0) != 0 + } + + /// Checks if all of the provided flags are set. + #[inline] + pub fn contains_all(self, flags: $flags) -> bool { + (self.0 & flags.0) == flags.0 + } + } + }; +} diff --git a/rust/kernel/init.rs b/rust/kernel/init.rs index 899b9a962762..7a0d4559d7b5 100644 --- a/rust/kernel/init.rs +++ b/rust/kernel/init.rs @@ -219,20 +219,12 @@ pub trait InPlaceInit<T>: Sized { /// [`Error`]: crate::error::Error #[macro_export] macro_rules! try_init { - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }) => { - ::pin_init::try_init!($(&$this in)? $t $(::<$($generics),*>)? { - $($fields)* - }? $crate::error::Error) - }; - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }? $err:ty) => { - ::pin_init::try_init!($(&$this in)? $t $(::<$($generics),*>)? { - $($fields)* - }? $err) - }; + ($($args:tt)*) => { + ::pin_init::init!( + #[default_error($crate::error::Error)] + $($args)* + ) + } } /// Construct an in-place, fallible pinned initializer for `struct`s. @@ -279,18 +271,10 @@ macro_rules! try_init { /// [`Error`]: crate::error::Error #[macro_export] macro_rules! try_pin_init { - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }) => { - ::pin_init::try_pin_init!($(&$this in)? $t $(::<$($generics),*>)? { - $($fields)* - }? $crate::error::Error) - }; - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }? $err:ty) => { - ::pin_init::try_pin_init!($(&$this in)? $t $(::<$($generics),*>)? { - $($fields)* - }? $err) - }; + ($($args:tt)*) => { + ::pin_init::pin_init!( + #[default_error($crate::error::Error)] + $($args)* + ) + } } diff --git a/rust/kernel/io.rs b/rust/kernel/io.rs index 98e8b84e68d1..c1cca7b438c3 100644 --- a/rust/kernel/io.rs +++ b/rust/kernel/io.rs @@ -32,16 +32,16 @@ pub type ResourceSize = bindings::resource_size_t; /// By itself, the existence of an instance of this structure does not provide any guarantees that /// the represented MMIO region does exist or is properly mapped. /// -/// Instead, the bus specific MMIO implementation must convert this raw representation into an `Io` -/// instance providing the actual memory accessors. Only by the conversion into an `Io` structure -/// any guarantees are given. -pub struct IoRaw<const SIZE: usize = 0> { +/// Instead, the bus specific MMIO implementation must convert this raw representation into an +/// `Mmio` instance providing the actual memory accessors. Only by the conversion into an `Mmio` +/// structure any guarantees are given. +pub struct MmioRaw<const SIZE: usize = 0> { addr: usize, maxsize: usize, } -impl<const SIZE: usize> IoRaw<SIZE> { - /// Returns a new `IoRaw` instance on success, an error otherwise. +impl<const SIZE: usize> MmioRaw<SIZE> { + /// Returns a new `MmioRaw` instance on success, an error otherwise. 
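With the rewrite of `try_init!`/`try_pin_init!` above, the error type now defaults to `kernel::error::Error` via `#[default_error]`, and the `? Error` suffix handling lives entirely in `pin_init`. A minimal usage sketch (the `Counter` type is hypothetical):

```
use kernel::prelude::*;
use kernel::sync::{new_mutex, Mutex};

#[pin_data]
struct Counter {
    #[pin]
    inner: Mutex<u64>,
}

impl Counter {
    fn new() -> impl PinInit<Self, Error> {
        // No `? Error` suffix needed: the macro now defaults the error type
        // to `kernel::error::Error`. A `? SomeOtherError` suffix can still be
        // appended to override it.
        kernel::try_pin_init!(Self {
            inner <- new_mutex!(0),
        })
    }
}
```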
pub fn new(addr: usize, maxsize: usize) -> Result<Self> { if maxsize < SIZE { return Err(EINVAL); @@ -81,14 +81,16 @@ impl<const SIZE: usize> IoRaw<SIZE> { /// ffi::c_void, /// io::{ /// Io, -/// IoRaw, +/// IoKnownSize, +/// Mmio, +/// MmioRaw, /// PhysAddr, /// }, /// }; /// use core::ops::Deref; /// -/// // See also [`pci::Bar`] for a real example. -/// struct IoMem<const SIZE: usize>(IoRaw<SIZE>); +/// // See also `pci::Bar` for a real example. +/// struct IoMem<const SIZE: usize>(MmioRaw<SIZE>); /// /// impl<const SIZE: usize> IoMem<SIZE> { /// /// # Safety @@ -103,7 +105,7 @@ impl<const SIZE: usize> IoRaw<SIZE> { /// return Err(ENOMEM); /// } /// -/// Ok(IoMem(IoRaw::new(addr as usize, SIZE)?)) +/// Ok(IoMem(MmioRaw::new(addr as usize, SIZE)?)) /// } /// } /// @@ -115,11 +117,11 @@ impl<const SIZE: usize> IoRaw<SIZE> { /// } /// /// impl<const SIZE: usize> Deref for IoMem<SIZE> { -/// type Target = Io<SIZE>; +/// type Target = Mmio<SIZE>; /// /// fn deref(&self) -> &Self::Target { /// // SAFETY: The memory range stored in `self` has been properly mapped in `Self::new`. -/// unsafe { Io::from_raw(&self.0) } +/// unsafe { Mmio::from_raw(&self.0) } /// } /// } /// @@ -133,104 +135,183 @@ impl<const SIZE: usize> IoRaw<SIZE> { /// # } /// ``` #[repr(transparent)] -pub struct Io<const SIZE: usize = 0>(IoRaw<SIZE>); +pub struct Mmio<const SIZE: usize = 0>(MmioRaw<SIZE>); + +/// Internal helper macros used to invoke C MMIO read functions. +/// +/// This macro is intended to be used by higher-level MMIO access macros (define_read) and provides +/// a unified expansion for infallible vs. fallible read semantics. It emits a direct call into the +/// corresponding C helper and performs the required cast to the Rust return type. +/// +/// # Parameters +/// +/// * `$c_fn` – The C function performing the MMIO read. +/// * `$self` – The I/O backend object. +/// * `$ty` – The type of the value to be read. +/// * `$addr` – The MMIO address to read. +/// +/// This macro does not perform any validation; all invariants must be upheld by the higher-level +/// abstraction invoking it. +macro_rules! call_mmio_read { + (infallible, $c_fn:ident, $self:ident, $type:ty, $addr:expr) => { + // SAFETY: By the type invariant `addr` is a valid address for MMIO operations. + unsafe { bindings::$c_fn($addr as *const c_void) as $type } + }; + + (fallible, $c_fn:ident, $self:ident, $type:ty, $addr:expr) => {{ + // SAFETY: By the type invariant `addr` is a valid address for MMIO operations. + Ok(unsafe { bindings::$c_fn($addr as *const c_void) as $type }) + }}; +} + +/// Internal helper macros used to invoke C MMIO write functions. +/// +/// This macro is intended to be used by higher-level MMIO access macros (define_write) and provides +/// a unified expansion for infallible vs. fallible write semantics. It emits a direct call into the +/// corresponding C helper and performs the required cast to the Rust return type. +/// +/// # Parameters +/// +/// * `$c_fn` – The C function performing the MMIO write. +/// * `$self` – The I/O backend object. +/// * `$ty` – The type of the written value. +/// * `$addr` – The MMIO address to write. +/// * `$value` – The value to write. +/// +/// This macro does not perform any validation; all invariants must be upheld by the higher-level +/// abstraction invoking it. +macro_rules! call_mmio_write { + (infallible, $c_fn:ident, $self:ident, $ty:ty, $addr:expr, $value:expr) => { + // SAFETY: By the type invariant `addr` is a valid address for MMIO operations. 
+ unsafe { bindings::$c_fn($value, $addr as *mut c_void) } + }; + + (fallible, $c_fn:ident, $self:ident, $ty:ty, $addr:expr, $value:expr) => {{ + // SAFETY: By the type invariant `addr` is a valid address for MMIO operations. + unsafe { bindings::$c_fn($value, $addr as *mut c_void) }; + Ok(()) + }}; +} macro_rules! define_read { - ($(#[$attr:meta])* $name:ident, $try_name:ident, $c_fn:ident -> $type_name:ty) => { + (infallible, $(#[$attr:meta])* $vis:vis $name:ident, $call_macro:ident($c_fn:ident) -> + $type_name:ty) => { /// Read IO data from a given offset known at compile time. /// /// Bound checks are performed on compile time, hence if the offset is not known at compile /// time, the build will fail. $(#[$attr])* - #[inline] - pub fn $name(&self, offset: usize) -> $type_name { + // Always inline to optimize out error path of `io_addr_assert`. + #[inline(always)] + $vis fn $name(&self, offset: usize) -> $type_name { let addr = self.io_addr_assert::<$type_name>(offset); - // SAFETY: By the type invariant `addr` is a valid address for MMIO operations. - unsafe { bindings::$c_fn(addr as *const c_void) } + // SAFETY: By the type invariant `addr` is a valid address for IO operations. + $call_macro!(infallible, $c_fn, self, $type_name, addr) } + }; + (fallible, $(#[$attr:meta])* $vis:vis $try_name:ident, $call_macro:ident($c_fn:ident) -> + $type_name:ty) => { /// Read IO data from a given offset. /// /// Bound checks are performed on runtime, it fails if the offset (plus the type size) is /// out of bounds. $(#[$attr])* - pub fn $try_name(&self, offset: usize) -> Result<$type_name> { + $vis fn $try_name(&self, offset: usize) -> Result<$type_name> { let addr = self.io_addr::<$type_name>(offset)?; - // SAFETY: By the type invariant `addr` is a valid address for MMIO operations. - Ok(unsafe { bindings::$c_fn(addr as *const c_void) }) + // SAFETY: By the type invariant `addr` is a valid address for IO operations. + $call_macro!(fallible, $c_fn, self, $type_name, addr) } }; } +pub(crate) use define_read; macro_rules! define_write { - ($(#[$attr:meta])* $name:ident, $try_name:ident, $c_fn:ident <- $type_name:ty) => { + (infallible, $(#[$attr:meta])* $vis:vis $name:ident, $call_macro:ident($c_fn:ident) <- + $type_name:ty) => { /// Write IO data from a given offset known at compile time. /// /// Bound checks are performed on compile time, hence if the offset is not known at compile /// time, the build will fail. $(#[$attr])* - #[inline] - pub fn $name(&self, value: $type_name, offset: usize) { + // Always inline to optimize out error path of `io_addr_assert`. + #[inline(always)] + $vis fn $name(&self, value: $type_name, offset: usize) { let addr = self.io_addr_assert::<$type_name>(offset); - // SAFETY: By the type invariant `addr` is a valid address for MMIO operations. - unsafe { bindings::$c_fn(value, addr as *mut c_void) } + $call_macro!(infallible, $c_fn, self, $type_name, addr, value); } + }; + (fallible, $(#[$attr:meta])* $vis:vis $try_name:ident, $call_macro:ident($c_fn:ident) <- + $type_name:ty) => { /// Write IO data from a given offset. /// /// Bound checks are performed on runtime, it fails if the offset (plus the type size) is /// out of bounds. $(#[$attr])* - pub fn $try_name(&self, value: $type_name, offset: usize) -> Result { + $vis fn $try_name(&self, value: $type_name, offset: usize) -> Result { let addr = self.io_addr::<$type_name>(offset)?; - // SAFETY: By the type invariant `addr` is a valid address for MMIO operations. 
- unsafe { bindings::$c_fn(value, addr as *mut c_void) } - Ok(()) + $call_macro!(fallible, $c_fn, self, $type_name, addr, value) } }; } - -impl<const SIZE: usize> Io<SIZE> { - /// Converts an `IoRaw` into an `Io` instance, providing the accessors to the MMIO mapping. - /// - /// # Safety - /// - /// Callers must ensure that `addr` is the start of a valid I/O mapped memory region of size - /// `maxsize`. - pub unsafe fn from_raw(raw: &IoRaw<SIZE>) -> &Self { - // SAFETY: `Io` is a transparent wrapper around `IoRaw`. - unsafe { &*core::ptr::from_ref(raw).cast() } +pub(crate) use define_write; + +/// Checks whether an access of type `U` at the given `offset` +/// is valid within this region. +#[inline] +const fn offset_valid<U>(offset: usize, size: usize) -> bool { + let type_size = core::mem::size_of::<U>(); + if let Some(end) = offset.checked_add(type_size) { + end <= size && offset % type_size == 0 + } else { + false } +} + +/// Marker trait indicating that an I/O backend supports operations of a certain type. +/// +/// Different I/O backends can implement this trait to expose only the operations they support. +/// +/// For example, a PCI configuration space may implement `IoCapable<u8>`, `IoCapable<u16>`, +/// and `IoCapable<u32>`, but not `IoCapable<u64>`, while an MMIO region on a 64-bit +/// system might implement all four. +pub trait IoCapable<T> {} +/// Types implementing this trait (e.g. MMIO BARs or PCI config regions) +/// can perform I/O operations on regions of memory. +/// +/// This is an abstract representation to be implemented by arbitrary I/O +/// backends (e.g. MMIO, PCI config space, etc.). +/// +/// The [`Io`] trait provides: +/// - Base address and size information +/// - Helper methods for offset validation and address calculation +/// - Fallible (runtime checked) accessors for different data widths +/// +/// Which I/O methods are available depends on which [`IoCapable<T>`] traits +/// are implemented for the type. +/// +/// # Examples +/// +/// For MMIO regions, all widths (u8, u16, u32, and u64 on 64-bit systems) are typically +/// supported. For PCI configuration space, u8, u16, and u32 are supported but u64 is not. +pub trait Io { /// Returns the base address of this mapping. - #[inline] - pub fn addr(&self) -> usize { - self.0.addr() - } + fn addr(&self) -> usize; /// Returns the maximum size of this mapping. - #[inline] - pub fn maxsize(&self) -> usize { - self.0.maxsize() - } - - #[inline] - const fn offset_valid<U>(offset: usize, size: usize) -> bool { - let type_size = core::mem::size_of::<U>(); - if let Some(end) = offset.checked_add(type_size) { - end <= size && offset % type_size == 0 - } else { - false - } - } + fn maxsize(&self) -> usize; + /// Returns the absolute I/O address for a given `offset`, + /// performing runtime bound checks. #[inline] fn io_addr<U>(&self, offset: usize) -> Result<usize> { - if !Self::offset_valid::<U>(offset, self.maxsize()) { + if !offset_valid::<U>(offset, self.maxsize()) { return Err(EINVAL); } @@ -239,50 +320,289 @@ impl<const SIZE: usize> Io<SIZE> { self.addr().checked_add(offset).ok_or(EINVAL) } - #[inline] + /// Fallible 8-bit read with runtime bounds check. + #[inline(always)] + fn try_read8(&self, _offset: usize) -> Result<u8> + where + Self: IoCapable<u8>, + { + build_error!("Backend does not support fallible 8-bit read") + } + + /// Fallible 16-bit read with runtime bounds check. 
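The runtime accessors above all funnel through the private `offset_valid` helper, which accepts an access only if it fits entirely within the region and is naturally aligned. A self-contained restatement of that rule for illustration (the helper itself is crate-internal, so the function below is a local copy, not the kernel API):

```
// Local reimplementation of the bounds/alignment rule, for illustration only.
const fn valid<U>(offset: usize, size: usize) -> bool {
    let n = core::mem::size_of::<U>();
    match offset.checked_add(n) {
        Some(end) => end <= size && offset % n == 0,
        None => false,
    }
}

const _: () = {
    assert!(valid::<u32>(12, 16)); // 12..16 fits a 16-byte region and is 4-byte aligned.
    assert!(!valid::<u32>(14, 16)); // 14..18 runs past the end of the region.
    assert!(!valid::<u64>(4, 16)); // Offset 4 is not 8-byte aligned.
};
```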
+ #[inline(always)] + fn try_read16(&self, _offset: usize) -> Result<u16> + where + Self: IoCapable<u16>, + { + build_error!("Backend does not support fallible 16-bit read") + } + + /// Fallible 32-bit read with runtime bounds check. + #[inline(always)] + fn try_read32(&self, _offset: usize) -> Result<u32> + where + Self: IoCapable<u32>, + { + build_error!("Backend does not support fallible 32-bit read") + } + + /// Fallible 64-bit read with runtime bounds check. + #[inline(always)] + fn try_read64(&self, _offset: usize) -> Result<u64> + where + Self: IoCapable<u64>, + { + build_error!("Backend does not support fallible 64-bit read") + } + + /// Fallible 8-bit write with runtime bounds check. + #[inline(always)] + fn try_write8(&self, _value: u8, _offset: usize) -> Result + where + Self: IoCapable<u8>, + { + build_error!("Backend does not support fallible 8-bit write") + } + + /// Fallible 16-bit write with runtime bounds check. + #[inline(always)] + fn try_write16(&self, _value: u16, _offset: usize) -> Result + where + Self: IoCapable<u16>, + { + build_error!("Backend does not support fallible 16-bit write") + } + + /// Fallible 32-bit write with runtime bounds check. + #[inline(always)] + fn try_write32(&self, _value: u32, _offset: usize) -> Result + where + Self: IoCapable<u32>, + { + build_error!("Backend does not support fallible 32-bit write") + } + + /// Fallible 64-bit write with runtime bounds check. + #[inline(always)] + fn try_write64(&self, _value: u64, _offset: usize) -> Result + where + Self: IoCapable<u64>, + { + build_error!("Backend does not support fallible 64-bit write") + } + + /// Infallible 8-bit read with compile-time bounds check. + #[inline(always)] + fn read8(&self, _offset: usize) -> u8 + where + Self: IoKnownSize + IoCapable<u8>, + { + build_error!("Backend does not support infallible 8-bit read") + } + + /// Infallible 16-bit read with compile-time bounds check. + #[inline(always)] + fn read16(&self, _offset: usize) -> u16 + where + Self: IoKnownSize + IoCapable<u16>, + { + build_error!("Backend does not support infallible 16-bit read") + } + + /// Infallible 32-bit read with compile-time bounds check. + #[inline(always)] + fn read32(&self, _offset: usize) -> u32 + where + Self: IoKnownSize + IoCapable<u32>, + { + build_error!("Backend does not support infallible 32-bit read") + } + + /// Infallible 64-bit read with compile-time bounds check. + #[inline(always)] + fn read64(&self, _offset: usize) -> u64 + where + Self: IoKnownSize + IoCapable<u64>, + { + build_error!("Backend does not support infallible 64-bit read") + } + + /// Infallible 8-bit write with compile-time bounds check. + #[inline(always)] + fn write8(&self, _value: u8, _offset: usize) + where + Self: IoKnownSize + IoCapable<u8>, + { + build_error!("Backend does not support infallible 8-bit write") + } + + /// Infallible 16-bit write with compile-time bounds check. + #[inline(always)] + fn write16(&self, _value: u16, _offset: usize) + where + Self: IoKnownSize + IoCapable<u16>, + { + build_error!("Backend does not support infallible 16-bit write") + } + + /// Infallible 32-bit write with compile-time bounds check. + #[inline(always)] + fn write32(&self, _value: u32, _offset: usize) + where + Self: IoKnownSize + IoCapable<u32>, + { + build_error!("Backend does not support infallible 32-bit write") + } + + /// Infallible 64-bit write with compile-time bounds check. 
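Because every accessor is gated on the `IoCapable<T>` marker, helpers can now be written once against any backend that advertises the required access width. A small sketch (the register offset and helper name are made up):

```
use kernel::io::{Io, IoCapable};
use kernel::prelude::*;

/// Reads a hypothetical 32-bit ID register from any I/O backend that supports
/// 32-bit accesses, e.g. an MMIO region or a PCI configuration space.
fn read_id<B: Io + IoCapable<u32>>(io: &B) -> Result<u32> {
    io.try_read32(0x0)
}
```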
+ #[inline(always)] + fn write64(&self, _value: u64, _offset: usize) + where + Self: IoKnownSize + IoCapable<u64>, + { + build_error!("Backend does not support infallible 64-bit write") + } +} + +/// Trait for types with a known size at compile time. +/// +/// This trait is implemented by I/O backends that have a compile-time known size, +/// enabling the use of infallible I/O accessors with compile-time bounds checking. +/// +/// Types implementing this trait can use the infallible methods in [`Io`] trait +/// (e.g., `read8`, `write32`), which require `Self: IoKnownSize` bound. +pub trait IoKnownSize: Io { + /// Minimum usable size of this region. + const MIN_SIZE: usize; + + /// Returns the absolute I/O address for a given `offset`, + /// performing compile-time bound checks. + // Always inline to optimize out error path of `build_assert`. + #[inline(always)] fn io_addr_assert<U>(&self, offset: usize) -> usize { - build_assert!(Self::offset_valid::<U>(offset, SIZE)); + build_assert!(offset_valid::<U>(offset, Self::MIN_SIZE)); self.addr() + offset } +} + +// MMIO regions support 8, 16, and 32-bit accesses. +impl<const SIZE: usize> IoCapable<u8> for Mmio<SIZE> {} +impl<const SIZE: usize> IoCapable<u16> for Mmio<SIZE> {} +impl<const SIZE: usize> IoCapable<u32> for Mmio<SIZE> {} + +// MMIO regions on 64-bit systems also support 64-bit accesses. +#[cfg(CONFIG_64BIT)] +impl<const SIZE: usize> IoCapable<u64> for Mmio<SIZE> {} + +impl<const SIZE: usize> Io for Mmio<SIZE> { + /// Returns the base address of this mapping. + #[inline] + fn addr(&self) -> usize { + self.0.addr() + } + + /// Returns the maximum size of this mapping. + #[inline] + fn maxsize(&self) -> usize { + self.0.maxsize() + } - define_read!(read8, try_read8, readb -> u8); - define_read!(read16, try_read16, readw -> u16); - define_read!(read32, try_read32, readl -> u32); + define_read!(fallible, try_read8, call_mmio_read(readb) -> u8); + define_read!(fallible, try_read16, call_mmio_read(readw) -> u16); + define_read!(fallible, try_read32, call_mmio_read(readl) -> u32); define_read!( + fallible, #[cfg(CONFIG_64BIT)] - read64, try_read64, - readq -> u64 + call_mmio_read(readq) -> u64 + ); + + define_write!(fallible, try_write8, call_mmio_write(writeb) <- u8); + define_write!(fallible, try_write16, call_mmio_write(writew) <- u16); + define_write!(fallible, try_write32, call_mmio_write(writel) <- u32); + define_write!( + fallible, + #[cfg(CONFIG_64BIT)] + try_write64, + call_mmio_write(writeq) <- u64 ); - define_read!(read8_relaxed, try_read8_relaxed, readb_relaxed -> u8); - define_read!(read16_relaxed, try_read16_relaxed, readw_relaxed -> u16); - define_read!(read32_relaxed, try_read32_relaxed, readl_relaxed -> u32); + define_read!(infallible, read8, call_mmio_read(readb) -> u8); + define_read!(infallible, read16, call_mmio_read(readw) -> u16); + define_read!(infallible, read32, call_mmio_read(readl) -> u32); define_read!( + infallible, #[cfg(CONFIG_64BIT)] - read64_relaxed, - try_read64_relaxed, - readq_relaxed -> u64 + read64, + call_mmio_read(readq) -> u64 ); - define_write!(write8, try_write8, writeb <- u8); - define_write!(write16, try_write16, writew <- u16); - define_write!(write32, try_write32, writel <- u32); + define_write!(infallible, write8, call_mmio_write(writeb) <- u8); + define_write!(infallible, write16, call_mmio_write(writew) <- u16); + define_write!(infallible, write32, call_mmio_write(writel) <- u32); define_write!( + infallible, #[cfg(CONFIG_64BIT)] write64, - try_write64, - writeq <- u64 + 
call_mmio_write(writeq) <- u64 + ); +} + +impl<const SIZE: usize> IoKnownSize for Mmio<SIZE> { + const MIN_SIZE: usize = SIZE; +} + +impl<const SIZE: usize> Mmio<SIZE> { + /// Converts an `MmioRaw` into an `Mmio` instance, providing the accessors to the MMIO mapping. + /// + /// # Safety + /// + /// Callers must ensure that `addr` is the start of a valid I/O mapped memory region of size + /// `maxsize`. + pub unsafe fn from_raw(raw: &MmioRaw<SIZE>) -> &Self { + // SAFETY: `Mmio` is a transparent wrapper around `MmioRaw`. + unsafe { &*core::ptr::from_ref(raw).cast() } + } + + define_read!(infallible, pub read8_relaxed, call_mmio_read(readb_relaxed) -> u8); + define_read!(infallible, pub read16_relaxed, call_mmio_read(readw_relaxed) -> u16); + define_read!(infallible, pub read32_relaxed, call_mmio_read(readl_relaxed) -> u32); + define_read!( + infallible, + #[cfg(CONFIG_64BIT)] + pub read64_relaxed, + call_mmio_read(readq_relaxed) -> u64 + ); + + define_read!(fallible, pub try_read8_relaxed, call_mmio_read(readb_relaxed) -> u8); + define_read!(fallible, pub try_read16_relaxed, call_mmio_read(readw_relaxed) -> u16); + define_read!(fallible, pub try_read32_relaxed, call_mmio_read(readl_relaxed) -> u32); + define_read!( + fallible, + #[cfg(CONFIG_64BIT)] + pub try_read64_relaxed, + call_mmio_read(readq_relaxed) -> u64 + ); + + define_write!(infallible, pub write8_relaxed, call_mmio_write(writeb_relaxed) <- u8); + define_write!(infallible, pub write16_relaxed, call_mmio_write(writew_relaxed) <- u16); + define_write!(infallible, pub write32_relaxed, call_mmio_write(writel_relaxed) <- u32); + define_write!( + infallible, + #[cfg(CONFIG_64BIT)] + pub write64_relaxed, + call_mmio_write(writeq_relaxed) <- u64 ); - define_write!(write8_relaxed, try_write8_relaxed, writeb_relaxed <- u8); - define_write!(write16_relaxed, try_write16_relaxed, writew_relaxed <- u16); - define_write!(write32_relaxed, try_write32_relaxed, writel_relaxed <- u32); + define_write!(fallible, pub try_write8_relaxed, call_mmio_write(writeb_relaxed) <- u8); + define_write!(fallible, pub try_write16_relaxed, call_mmio_write(writew_relaxed) <- u16); + define_write!(fallible, pub try_write32_relaxed, call_mmio_write(writel_relaxed) <- u32); define_write!( + fallible, #[cfg(CONFIG_64BIT)] - write64_relaxed, - try_write64_relaxed, - writeq_relaxed <- u64 + pub try_write64_relaxed, + call_mmio_write(writeq_relaxed) <- u64 ); } diff --git a/rust/kernel/io/mem.rs b/rust/kernel/io/mem.rs index b03b82cd531b..620022cff401 100644 --- a/rust/kernel/io/mem.rs +++ b/rust/kernel/io/mem.rs @@ -5,7 +5,6 @@ use core::ops::Deref; use crate::{ - c_str, device::{ Bound, Device, // @@ -17,8 +16,8 @@ use crate::{ Region, Resource, // }, - Io, - IoRaw, // + Mmio, + MmioRaw, // }, prelude::*, }; @@ -52,7 +51,12 @@ impl<'a> IoRequest<'a> { /// illustration purposes. /// /// ```no_run - /// use kernel::{bindings, c_str, platform, of, device::Core}; + /// use kernel::{ + /// bindings, + /// device::Core, + /// of, + /// platform, + /// }; /// struct SampleDriver; /// /// impl platform::Driver for SampleDriver { @@ -110,7 +114,12 @@ impl<'a> IoRequest<'a> { /// illustration purposes. 
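A short sketch contrasting the two accessor families on the concrete `Mmio` type above; the 256-byte window and register offsets are illustrative. `read32` is bounds-checked at compile time against the const `SIZE` (via `IoKnownSize`), while `try_read32` defers the check to run time:

```
use kernel::io::{Io, Mmio};
use kernel::prelude::*;

fn read_version(regs: &Mmio<256>) -> Result<u32> {
    // Offset provably inside the 256-byte window: checked at build time.
    let _id = regs.read32(0x0);

    // Offset checked at run time; returns `EINVAL` if out of bounds or
    // misaligned.
    regs.try_read32(0x4)
}
```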
/// /// ```no_run - /// use kernel::{bindings, c_str, platform, of, device::Core}; + /// use kernel::{ + /// bindings, + /// device::Core, + /// of, + /// platform, + /// }; /// struct SampleDriver; /// /// impl platform::Driver for SampleDriver { @@ -172,7 +181,7 @@ impl<const SIZE: usize> ExclusiveIoMem<SIZE> { fn ioremap(resource: &Resource) -> Result<Self> { let start = resource.start(); let size = resource.size(); - let name = resource.name().unwrap_or(c_str!("")); + let name = resource.name().unwrap_or_default(); let region = resource .request_region( @@ -203,7 +212,7 @@ impl<const SIZE: usize> ExclusiveIoMem<SIZE> { } impl<const SIZE: usize> Deref for ExclusiveIoMem<SIZE> { - type Target = Io<SIZE>; + type Target = Mmio<SIZE>; fn deref(&self) -> &Self::Target { &self.iomem @@ -217,10 +226,10 @@ impl<const SIZE: usize> Deref for ExclusiveIoMem<SIZE> { /// /// # Invariants /// -/// [`IoMem`] always holds an [`IoRaw`] instance that holds a valid pointer to the +/// [`IoMem`] always holds an [`MmioRaw`] instance that holds a valid pointer to the /// start of the I/O memory mapped region. pub struct IoMem<const SIZE: usize = 0> { - io: IoRaw<SIZE>, + io: MmioRaw<SIZE>, } impl<const SIZE: usize> IoMem<SIZE> { @@ -255,7 +264,7 @@ impl<const SIZE: usize> IoMem<SIZE> { return Err(ENOMEM); } - let io = IoRaw::new(addr as usize, size)?; + let io = MmioRaw::new(addr as usize, size)?; let io = IoMem { io }; Ok(io) @@ -278,10 +287,10 @@ impl<const SIZE: usize> Drop for IoMem<SIZE> { } impl<const SIZE: usize> Deref for IoMem<SIZE> { - type Target = Io<SIZE>; + type Target = Mmio<SIZE>; fn deref(&self) -> &Self::Target { // SAFETY: Safe as by the invariant of `IoMem`. - unsafe { Io::from_raw(&self.io) } + unsafe { Mmio::from_raw(&self.io) } } } diff --git a/rust/kernel/io/poll.rs b/rust/kernel/io/poll.rs index b1a2570364f4..75d1b3e8596c 100644 --- a/rust/kernel/io/poll.rs +++ b/rust/kernel/io/poll.rs @@ -45,12 +45,16 @@ use crate::{ /// # Examples /// /// ```no_run -/// use kernel::io::{Io, poll::read_poll_timeout}; +/// use kernel::io::{ +/// Io, +/// Mmio, +/// poll::read_poll_timeout, // +/// }; /// use kernel::time::Delta; /// /// const HW_READY: u16 = 0x01; /// -/// fn wait_for_hardware<const SIZE: usize>(io: &Io<SIZE>) -> Result { +/// fn wait_for_hardware<const SIZE: usize>(io: &Mmio<SIZE>) -> Result { /// read_poll_timeout( /// // The `op` closure reads the value of a specific status register. /// || io.try_read16(0x1000), @@ -128,12 +132,16 @@ where /// # Examples /// /// ```no_run -/// use kernel::io::{poll::read_poll_timeout_atomic, Io}; +/// use kernel::io::{ +/// Io, +/// Mmio, +/// poll::read_poll_timeout_atomic, // +/// }; /// use kernel::time::Delta; /// /// const HW_READY: u16 = 0x01; /// -/// fn wait_for_hardware<const SIZE: usize>(io: &Io<SIZE>) -> Result { +/// fn wait_for_hardware<const SIZE: usize>(io: &Mmio<SIZE>) -> Result { /// read_poll_timeout_atomic( /// // The `op` closure reads the value of a specific status register. /// || io.try_read16(0x1000), diff --git a/rust/kernel/io/resource.rs b/rust/kernel/io/resource.rs index 56cfde97ce87..b7ac9faf141d 100644 --- a/rust/kernel/io/resource.rs +++ b/rust/kernel/io/resource.rs @@ -226,6 +226,8 @@ impl Flags { /// Resource represents a memory region that must be ioremaped using `ioremap_np`. pub const IORESOURCE_MEM_NONPOSTED: Flags = Flags::new(bindings::IORESOURCE_MEM_NONPOSTED); + // Always inline to optimize out error path of `build_assert`. 
+ #[inline(always)] const fn new(value: u32) -> Self { crate::build_assert!(value as u64 <= c_ulong::MAX as u64); Flags(value as c_ulong) diff --git a/rust/kernel/iommu/mod.rs b/rust/kernel/iommu/mod.rs new file mode 100644 index 000000000000..1423d7b19b57 --- /dev/null +++ b/rust/kernel/iommu/mod.rs @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Rust support related to IOMMU. + +pub mod pgtable; diff --git a/rust/kernel/iommu/pgtable.rs b/rust/kernel/iommu/pgtable.rs new file mode 100644 index 000000000000..c88e38fd938a --- /dev/null +++ b/rust/kernel/iommu/pgtable.rs @@ -0,0 +1,279 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! IOMMU page table management. +//! +//! C header: [`include/linux/io-pgtable.h`](srctree/include/linux/io-pgtable.h) + +use core::{ + marker::PhantomData, + ptr::NonNull, // +}; + +use crate::{ + alloc, + bindings, + device::{ + Bound, + Device, // + }, + devres::Devres, + error::to_result, + io::PhysAddr, + prelude::*, // +}; + +use bindings::io_pgtable_fmt; + +/// Protection flags used with IOMMU mappings. +pub mod prot { + /// Read access. + pub const READ: u32 = bindings::IOMMU_READ; + /// Write access. + pub const WRITE: u32 = bindings::IOMMU_WRITE; + /// Request cache coherency. + pub const CACHE: u32 = bindings::IOMMU_CACHE; + /// Request no-execute permission. + pub const NOEXEC: u32 = bindings::IOMMU_NOEXEC; + /// MMIO peripheral mapping. + pub const MMIO: u32 = bindings::IOMMU_MMIO; + /// Privileged mapping. + pub const PRIVILEGED: u32 = bindings::IOMMU_PRIV; +} + +/// Represents a requested `io_pgtable` configuration. +pub struct Config { + /// Quirk bitmask (type-specific). + pub quirks: usize, + /// Valid page sizes, as a bitmask of powers of two. + pub pgsize_bitmap: usize, + /// Input address space size in bits. + pub ias: u32, + /// Output address space size in bits. + pub oas: u32, + /// IOMMU uses coherent accesses for page table walks. + pub coherent_walk: bool, +} + +/// An io page table using a specific format. +/// +/// # Invariants +/// +/// The pointer references a valid io page table. +pub struct IoPageTable<F: IoPageTableFmt> { + ptr: NonNull<bindings::io_pgtable_ops>, + _marker: PhantomData<F>, +} + +// SAFETY: `struct io_pgtable_ops` is not restricted to a single thread. +unsafe impl<F: IoPageTableFmt> Send for IoPageTable<F> {} +// SAFETY: `struct io_pgtable_ops` may be accessed concurrently. +unsafe impl<F: IoPageTableFmt> Sync for IoPageTable<F> {} + +/// The format used by this page table. +pub trait IoPageTableFmt: 'static { + /// The value representing this format. + const FORMAT: io_pgtable_fmt; +} + +impl<F: IoPageTableFmt> IoPageTable<F> { + /// Create a new `IoPageTable` as a device resource. + #[inline] + pub fn new( + dev: &Device<Bound>, + config: Config, + ) -> impl PinInit<Devres<IoPageTable<F>>, Error> + '_ { + // SAFETY: Devres ensures that the value is dropped during device unbind. + Devres::new(dev, unsafe { Self::new_raw(dev, config) }) + } + + /// Create a new `IoPageTable`. + /// + /// # Safety + /// + /// If successful, then the returned `IoPageTable` must be dropped before the device is + /// unbound. 
+ #[inline] + pub unsafe fn new_raw(dev: &Device<Bound>, config: Config) -> Result<IoPageTable<F>> { + let mut raw_cfg = bindings::io_pgtable_cfg { + quirks: config.quirks, + pgsize_bitmap: config.pgsize_bitmap, + ias: config.ias, + oas: config.oas, + coherent_walk: config.coherent_walk, + tlb: &raw const NOOP_FLUSH_OPS, + iommu_dev: dev.as_raw(), + // SAFETY: All zeroes is a valid value for `struct io_pgtable_cfg`. + ..unsafe { core::mem::zeroed() } + }; + + // SAFETY: + // * The raw_cfg pointer is valid for the duration of this call. + // * The provided `FLUSH_OPS` contains valid function pointers that accept a null pointer + // as cookie. + // * The caller ensures that the io pgtable does not outlive the device. + let ops = unsafe { + bindings::alloc_io_pgtable_ops(F::FORMAT, &mut raw_cfg, core::ptr::null_mut()) + }; + + // INVARIANT: We successfully created a valid page table. + Ok(IoPageTable { + ptr: NonNull::new(ops).ok_or(ENOMEM)?, + _marker: PhantomData, + }) + } + + /// Obtain a raw pointer to the underlying `struct io_pgtable_ops`. + #[inline] + pub fn raw_ops(&self) -> *mut bindings::io_pgtable_ops { + self.ptr.as_ptr() + } + + /// Obtain a raw pointer to the underlying `struct io_pgtable`. + #[inline] + pub fn raw_pgtable(&self) -> *mut bindings::io_pgtable { + // SAFETY: The io_pgtable_ops of an io-pgtable is always the ops field of a io_pgtable. + unsafe { kernel::container_of!(self.raw_ops(), bindings::io_pgtable, ops) } + } + + /// Obtain a raw pointer to the underlying `struct io_pgtable_cfg`. + #[inline] + pub fn raw_cfg(&self) -> *mut bindings::io_pgtable_cfg { + // SAFETY: The `raw_pgtable()` method returns a valid pointer. + unsafe { &raw mut (*self.raw_pgtable()).cfg } + } + + /// Map a physically contiguous range of pages of the same size. + /// + /// Even if successful, this operation may not map the entire range. In that case, only a + /// prefix of the range is mapped, and the returned integer indicates its length in bytes. In + /// this case, the caller will usually call `map_pages` again for the remaining range. + /// + /// The returned [`Result`] indicates whether an error was encountered while mapping pages. + /// Note that this may return a non-zero length even if an error was encountered. The caller + /// will usually [unmap the relevant pages](Self::unmap_pages) on error. + /// + /// The caller must flush the TLB before using the pgtable to access the newly created mapping. + /// + /// # Safety + /// + /// * No other io-pgtable operation may access the range `iova .. iova+pgsize*pgcount` while + /// this `map_pages` operation executes. + /// * This page table must not contain any mapping that overlaps with the mapping created by + /// this call. + /// * If this page table is live, then the caller must ensure that it's okay to access the + /// physical address being mapped for the duration in which it is mapped. + #[inline] + pub unsafe fn map_pages( + &self, + iova: usize, + paddr: PhysAddr, + pgsize: usize, + pgcount: usize, + prot: u32, + flags: alloc::Flags, + ) -> (usize, Result) { + let mut mapped: usize = 0; + + // SAFETY: The `map_pages` function in `io_pgtable_ops` is never null. + let map_pages = unsafe { (*self.raw_ops()).map_pages.unwrap_unchecked() }; + + // SAFETY: The safety requirements of this method are sufficient to call `map_pages`. 
+ let ret = to_result(unsafe { + (map_pages)( + self.raw_ops(), + iova, + paddr, + pgsize, + pgcount, + prot as i32, + flags.as_raw(), + &mut mapped, + ) + }); + + (mapped, ret) + } + + /// Unmap a range of virtually contiguous pages of the same size. + /// + /// This may not unmap the entire range, and returns the length of the unmapped prefix in + /// bytes. + /// + /// # Safety + /// + /// * No other io-pgtable operation may access the range `iova .. iova+pgsize*pgcount` while + /// this `unmap_pages` operation executes. + /// * This page table must contain one or more consecutive mappings starting at `iova` whose + /// total size is `pgcount * pgsize`. + #[inline] + #[must_use] + pub unsafe fn unmap_pages(&self, iova: usize, pgsize: usize, pgcount: usize) -> usize { + // SAFETY: The `unmap_pages` function in `io_pgtable_ops` is never null. + let unmap_pages = unsafe { (*self.raw_ops()).unmap_pages.unwrap_unchecked() }; + + // SAFETY: The safety requirements of this method are sufficient to call `unmap_pages`. + unsafe { (unmap_pages)(self.raw_ops(), iova, pgsize, pgcount, core::ptr::null_mut()) } + } +} + +// For the initial users of these rust bindings, the GPU FW is managing the IOTLB and performs all +// required invalidations using a range. There is no need for it get ARM style invalidation +// instructions from the page table code. +// +// Support for flushing the TLB with ARM style invalidation instructions may be added in the +// future. +static NOOP_FLUSH_OPS: bindings::iommu_flush_ops = bindings::iommu_flush_ops { + tlb_flush_all: Some(rust_tlb_flush_all_noop), + tlb_flush_walk: Some(rust_tlb_flush_walk_noop), + tlb_add_page: None, +}; + +#[no_mangle] +extern "C" fn rust_tlb_flush_all_noop(_cookie: *mut core::ffi::c_void) {} + +#[no_mangle] +extern "C" fn rust_tlb_flush_walk_noop( + _iova: usize, + _size: usize, + _granule: usize, + _cookie: *mut core::ffi::c_void, +) { +} + +impl<F: IoPageTableFmt> Drop for IoPageTable<F> { + fn drop(&mut self) { + // SAFETY: The caller of `Self::ttbr()` promised that the page table is not live when this + // destructor runs. + unsafe { bindings::free_io_pgtable_ops(self.raw_ops()) }; + } +} + +/// The `ARM_64_LPAE_S1` page table format. +pub enum ARM64LPAES1 {} + +impl IoPageTableFmt for ARM64LPAES1 { + const FORMAT: io_pgtable_fmt = bindings::io_pgtable_fmt_ARM_64_LPAE_S1 as io_pgtable_fmt; +} + +impl IoPageTable<ARM64LPAES1> { + /// Access the `ttbr` field of the configuration. + /// + /// This is the physical address of the page table, which may be passed to the device that + /// needs to use it. + /// + /// # Safety + /// + /// The caller must ensure that the device stops using the page table before dropping it. + #[inline] + pub unsafe fn ttbr(&self) -> u64 { + // SAFETY: `arm_lpae_s1_cfg` is the right cfg type for `ARM64LPAES1`. + unsafe { (*self.raw_cfg()).__bindgen_anon_1.arm_lpae_s1_cfg.ttbr } + } + + /// Access the `mair` field of the configuration. + #[inline] + pub fn mair(&self) -> u64 { + // SAFETY: `arm_lpae_s1_cfg` is the right cfg type for `ARM64LPAES1`. + unsafe { (*self.raw_cfg()).__bindgen_anon_1.arm_lpae_s1_cfg.mair } + } +} diff --git a/rust/kernel/irq/flags.rs b/rust/kernel/irq/flags.rs index adfde96ec47c..d26e25af06ee 100644 --- a/rust/kernel/irq/flags.rs +++ b/rust/kernel/irq/flags.rs @@ -96,6 +96,8 @@ impl Flags { self.0 } + // Always inline to optimize out error path of `build_assert`. 
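A hedged sketch of how a driver might wire the new `IoPageTable` into its pinned driver data at probe time; the struct name, address widths, and page-size bitmap below are illustrative, and the `unsafe` map/unmap calls are omitted since their contracts are documented above:

```
use kernel::{
    device::{Bound, Device},
    devres::Devres,
    iommu::pgtable::{ARM64LPAES1, Config, IoPageTable},
    prelude::*,
};

#[pin_data]
struct MyDriverData {
    #[pin]
    pgtable: Devres<IoPageTable<ARM64LPAES1>>,
}

impl MyDriverData {
    fn new(dev: &Device<Bound>) -> impl PinInit<Self, Error> + '_ {
        kernel::try_pin_init!(Self {
            // `IoPageTable::new()` wraps the table in `Devres`, so it is torn
            // down automatically when `dev` is unbound.
            pgtable <- IoPageTable::<ARM64LPAES1>::new(dev, Config {
                quirks: 0,
                // Illustrative: allow only CPU-page-sized mappings.
                pgsize_bitmap: kernel::page::PAGE_SIZE,
                ias: 40,
                oas: 40,
                coherent_walk: true,
            }),
        })
    }
}
```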
+ #[inline(always)] const fn new(value: u32) -> Self { build_assert!(value as u64 <= c_ulong::MAX as u64); Self(value as c_ulong) diff --git a/rust/kernel/irq/request.rs b/rust/kernel/irq/request.rs index b150563fdef8..67769800117c 100644 --- a/rust/kernel/irq/request.rs +++ b/rust/kernel/irq/request.rs @@ -139,7 +139,6 @@ impl<'a> IrqRequest<'a> { /// [`Completion::wait_for_completion()`]: kernel::sync::Completion::wait_for_completion /// /// ``` -/// use kernel::c_str; /// use kernel::device::{Bound, Device}; /// use kernel::irq::{self, Flags, IrqRequest, IrqReturn, Registration}; /// use kernel::prelude::*; @@ -167,7 +166,7 @@ impl<'a> IrqRequest<'a> { /// handler: impl PinInit<Data, Error>, /// request: IrqRequest<'_>, /// ) -> Result<Arc<Registration<Data>>> { -/// let registration = Registration::new(request, Flags::SHARED, c_str!("my_device"), handler); +/// let registration = Registration::new(request, Flags::SHARED, c"my_device", handler); /// /// let registration = Arc::pin_init(registration, GFP_KERNEL)?; /// @@ -340,7 +339,6 @@ impl<T: ?Sized + ThreadedHandler, A: Allocator> ThreadedHandler for Box<T, A> { /// [`Mutex`](kernel::sync::Mutex) to provide interior mutability. /// /// ``` -/// use kernel::c_str; /// use kernel::device::{Bound, Device}; /// use kernel::irq::{ /// self, Flags, IrqRequest, IrqReturn, ThreadedHandler, ThreadedIrqReturn, @@ -381,7 +379,7 @@ impl<T: ?Sized + ThreadedHandler, A: Allocator> ThreadedHandler for Box<T, A> { /// request: IrqRequest<'_>, /// ) -> Result<Arc<ThreadedRegistration<Data>>> { /// let registration = -/// ThreadedRegistration::new(request, Flags::SHARED, c_str!("my_device"), handler); +/// ThreadedRegistration::new(request, Flags::SHARED, c"my_device", handler); /// /// let registration = Arc::pin_init(registration, GFP_KERNEL)?; /// diff --git a/rust/kernel/kunit.rs b/rust/kernel/kunit.rs index 79436509dd73..f93f24a60bdd 100644 --- a/rust/kernel/kunit.rs +++ b/rust/kernel/kunit.rs @@ -9,9 +9,6 @@ use crate::fmt; use crate::prelude::*; -#[cfg(CONFIG_PRINTK)] -use crate::c_str; - /// Prints a KUnit error-level message. /// /// Public but hidden since it should only be used from KUnit generated code. @@ -22,7 +19,7 @@ pub fn err(args: fmt::Arguments<'_>) { #[cfg(CONFIG_PRINTK)] unsafe { bindings::_printk( - c_str!("\x013%pA").as_char_ptr(), + c"\x013%pA".as_char_ptr(), core::ptr::from_ref(&args).cast::<c_void>(), ); } @@ -38,7 +35,7 @@ pub fn info(args: fmt::Arguments<'_>) { #[cfg(CONFIG_PRINTK)] unsafe { bindings::_printk( - c_str!("\x016%pA").as_char_ptr(), + c"\x016%pA".as_char_ptr(), core::ptr::from_ref(&args).cast::<c_void>(), ); } @@ -60,7 +57,7 @@ macro_rules! kunit_assert { break 'out; } - static FILE: &'static $crate::str::CStr = $crate::c_str!($file); + static FILE: &'static $crate::str::CStr = $file; static LINE: i32 = ::core::line!() as i32 - $diff; static CONDITION: &'static $crate::str::CStr = $crate::c_str!(stringify!($condition)); @@ -192,9 +189,6 @@ pub fn is_test_result_ok(t: impl TestResult) -> bool { } /// Represents an individual test case. -/// -/// The [`kunit_unsafe_test_suite!`] macro expects a NULL-terminated list of valid test cases. -/// Use [`kunit_case_null`] to generate such a delimiter. #[doc(hidden)] pub const fn kunit_case( name: &'static kernel::str::CStr, @@ -215,32 +209,11 @@ pub const fn kunit_case( } } -/// Represents the NULL test case delimiter. -/// -/// The [`kunit_unsafe_test_suite!`] macro expects a NULL-terminated list of test cases. This -/// function returns such a delimiter. 
-#[doc(hidden)] -pub const fn kunit_case_null() -> kernel::bindings::kunit_case { - kernel::bindings::kunit_case { - run_case: None, - name: core::ptr::null_mut(), - generate_params: None, - attr: kernel::bindings::kunit_attributes { - speed: kernel::bindings::kunit_speed_KUNIT_SPEED_NORMAL, - }, - status: kernel::bindings::kunit_status_KUNIT_SUCCESS, - module_name: core::ptr::null_mut(), - log: core::ptr::null_mut(), - param_init: None, - param_exit: None, - } -} - /// Registers a KUnit test suite. /// /// # Safety /// -/// `test_cases` must be a NULL terminated array of valid test cases, +/// `test_cases` must be a `NULL` terminated array of valid test cases, /// whose lifetime is at least that of the test suite (i.e., static). /// /// # Examples @@ -253,8 +226,8 @@ pub const fn kunit_case_null() -> kernel::bindings::kunit_case { /// } /// /// static mut KUNIT_TEST_CASES: [kernel::bindings::kunit_case; 2] = [ -/// kernel::kunit::kunit_case(kernel::c_str!("name"), test_fn), -/// kernel::kunit::kunit_case_null(), +/// kernel::kunit::kunit_case(c"name", test_fn), +/// pin_init::zeroed(), /// ]; /// kernel::kunit_unsafe_test_suite!(suite_name, KUNIT_TEST_CASES); /// ``` diff --git a/rust/kernel/lib.rs b/rust/kernel/lib.rs index f812cf120042..3da92f18f4ee 100644 --- a/rust/kernel/lib.rs +++ b/rust/kernel/lib.rs @@ -100,9 +100,12 @@ pub mod fs; #[cfg(CONFIG_I2C = "y")] pub mod i2c; pub mod id_pool; +#[doc(hidden)] +pub mod impl_flags; pub mod init; pub mod io; pub mod ioctl; +pub mod iommu; pub mod iov; pub mod irq; pub mod jump_label; @@ -133,11 +136,14 @@ pub mod pwm; pub mod rbtree; pub mod regulator; pub mod revocable; +pub mod safety; pub mod scatterlist; pub mod security; pub mod seq_file; pub mod sizes; pub mod slice; +#[cfg(CONFIG_SOC_BUS)] +pub mod soc; mod static_assert; #[doc(hidden)] pub mod std_vendor; diff --git a/rust/kernel/list/arc.rs b/rust/kernel/list/arc.rs index d92bcf665c89..2282f33913ee 100644 --- a/rust/kernel/list/arc.rs +++ b/rust/kernel/list/arc.rs @@ -6,11 +6,11 @@ use crate::alloc::{AllocError, Flags}; use crate::prelude::*; +use crate::sync::atomic::{ordering, Atomic}; use crate::sync::{Arc, ArcBorrow, UniqueArc}; use core::marker::PhantomPinned; use core::ops::Deref; use core::pin::Pin; -use core::sync::atomic::{AtomicBool, Ordering}; /// Declares that this type has some way to ensure that there is exactly one `ListArc` instance for /// this id. @@ -469,7 +469,7 @@ where /// If the boolean is `false`, then there is no [`ListArc`] for this value. #[repr(transparent)] pub struct AtomicTracker<const ID: u64 = 0> { - inner: AtomicBool, + inner: Atomic<bool>, // This value needs to be pinned to justify the INVARIANT: comment in `AtomicTracker::new`. _pin: PhantomPinned, } @@ -480,12 +480,12 @@ impl<const ID: u64> AtomicTracker<ID> { // INVARIANT: Pin-init initializers can't be used on an existing `Arc`, so this value will // not be constructed in an `Arc` that already has a `ListArc`. Self { - inner: AtomicBool::new(false), + inner: Atomic::new(false), _pin: PhantomPinned, } } - fn project_inner(self: Pin<&mut Self>) -> &mut AtomicBool { + fn project_inner(self: Pin<&mut Self>) -> &mut Atomic<bool> { // SAFETY: The `inner` field is not structurally pinned, so we may obtain a mutable // reference to it even if we only have a pinned reference to `self`. 
unsafe { &mut Pin::into_inner_unchecked(self).inner } @@ -500,7 +500,7 @@ impl<const ID: u64> ListArcSafe<ID> for AtomicTracker<ID> { unsafe fn on_drop_list_arc(&self) { // INVARIANT: We just dropped a ListArc, so the boolean should be false. - self.inner.store(false, Ordering::Release); + self.inner.store(false, ordering::Release); } } @@ -514,8 +514,6 @@ unsafe impl<const ID: u64> TryNewListArc<ID> for AtomicTracker<ID> { fn try_new_list_arc(&self) -> bool { // INVARIANT: If this method returns true, then the boolean used to be false, and is no // longer false, so it is okay for the caller to create a new [`ListArc`]. - self.inner - .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) - .is_ok() + self.inner.cmpxchg(false, true, ordering::Acquire).is_ok() } } diff --git a/rust/kernel/maple_tree.rs b/rust/kernel/maple_tree.rs index e72eec56bf57..265d6396a78a 100644 --- a/rust/kernel/maple_tree.rs +++ b/rust/kernel/maple_tree.rs @@ -265,7 +265,16 @@ impl<T: ForeignOwnable> MapleTree<T> { loop { // This uses the raw accessor because we're destroying pointers without removing them // from the maple tree, which is only valid because this is the destructor. - let ptr = ma_state.mas_find_raw(usize::MAX); + // + // Take the rcu lock because mas_find_raw() requires that you hold either the spinlock + // or the rcu read lock. This is only really required if memory reclaim might + // reallocate entries in the tree, as we otherwise have exclusive access. That feature + // doesn't exist yet, so for now, taking the rcu lock only serves the purpose of + // silencing lockdep. + let ptr = { + let _rcu = kernel::sync::rcu::Guard::new(); + ma_state.mas_find_raw(usize::MAX) + }; if ptr.is_null() { break; } diff --git a/rust/kernel/miscdevice.rs b/rust/kernel/miscdevice.rs index d698cddcb4a5..c3c2052c9206 100644 --- a/rust/kernel/miscdevice.rs +++ b/rust/kernel/miscdevice.rs @@ -20,7 +20,7 @@ use crate::{ seq_file::SeqFile, types::{ForeignOwnable, Opaque}, }; -use core::{marker::PhantomData, mem::MaybeUninit, pin::Pin}; +use core::{marker::PhantomData, pin::Pin}; /// Options for creating a misc device. #[derive(Copy, Clone)] @@ -32,8 +32,7 @@ pub struct MiscDeviceOptions { impl MiscDeviceOptions { /// Create a raw `struct miscdev` ready for registration. pub const fn into_raw<T: MiscDevice>(self) -> bindings::miscdevice { - // SAFETY: All zeros is valid for this C type. - let mut result: bindings::miscdevice = unsafe { MaybeUninit::zeroed().assume_init() }; + let mut result: bindings::miscdevice = pin_init::zeroed(); result.minor = bindings::MISC_DYNAMIC_MINOR as ffi::c_int; result.name = crate::str::as_char_ptr_in_const_context(self.name); result.fops = MiscdeviceVTable::<T>::build(); @@ -411,7 +410,7 @@ impl<T: MiscDevice> MiscdeviceVTable<T> { compat_ioctl: if T::HAS_COMPAT_IOCTL { Some(Self::compat_ioctl) } else if T::HAS_IOCTL { - Some(bindings::compat_ptr_ioctl) + bindings::compat_ptr_ioctl } else { None }, @@ -420,8 +419,7 @@ impl<T: MiscDevice> MiscdeviceVTable<T> { } else { None }, - // SAFETY: All zeros is a valid value for `bindings::file_operations`. 
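The `AtomicTracker` conversion above swaps `core::sync::atomic::AtomicBool` for the kernel's own `Atomic<bool>` and its `cmpxchg()` API. The claim-once pattern it relies on looks roughly like this (the helper name is made up):

```
use kernel::sync::atomic::{ordering, Atomic};

/// Returns `true` only for the first caller that flips `flag` from `false` to
/// `true`; later callers see the flag already set and get `false`.
fn try_claim(flag: &Atomic<bool>) -> bool {
    flag.cmpxchg(false, true, ordering::Acquire).is_ok()
}
```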
- ..unsafe { MaybeUninit::zeroed().assume_init() } + ..pin_init::zeroed() }; const fn build() -> &'static bindings::file_operations { diff --git a/rust/kernel/net/phy.rs b/rust/kernel/net/phy.rs index bf6272d87a7b..3ca99db5cccf 100644 --- a/rust/kernel/net/phy.rs +++ b/rust/kernel/net/phy.rs @@ -777,7 +777,6 @@ impl DeviceMask { /// /// ``` /// # mod module_phy_driver_sample { -/// use kernel::c_str; /// use kernel::net::phy::{self, DeviceId}; /// use kernel::prelude::*; /// @@ -796,7 +795,7 @@ impl DeviceMask { /// /// #[vtable] /// impl phy::Driver for PhySample { -/// const NAME: &'static CStr = c_str!("PhySample"); +/// const NAME: &'static CStr = c"PhySample"; /// const PHY_DEVICE_ID: phy::DeviceId = phy::DeviceId::new_with_exact_mask(0x00000001); /// } /// # } @@ -805,7 +804,6 @@ impl DeviceMask { /// This expands to the following code: /// /// ```ignore -/// use kernel::c_str; /// use kernel::net::phy::{self, DeviceId}; /// use kernel::prelude::*; /// @@ -825,7 +823,7 @@ impl DeviceMask { /// /// #[vtable] /// impl phy::Driver for PhySample { -/// const NAME: &'static CStr = c_str!("PhySample"); +/// const NAME: &'static CStr = c"PhySample"; /// const PHY_DEVICE_ID: phy::DeviceId = phy::DeviceId::new_with_exact_mask(0x00000001); /// } /// diff --git a/rust/kernel/num/bounded.rs b/rust/kernel/num/bounded.rs index f870080af8ac..fa81acbdc8c2 100644 --- a/rust/kernel/num/bounded.rs +++ b/rust/kernel/num/bounded.rs @@ -40,11 +40,11 @@ fn fits_within<T: Integer>(value: T, num_bits: u32) -> bool { fits_within!(value, T, num_bits) } -/// An integer value that requires only the `N` less significant bits of the wrapped type to be +/// An integer value that requires only the `N` least significant bits of the wrapped type to be /// encoded. /// /// This limits the number of usable bits in the wrapped integer type, and thus the stored value to -/// a narrower range, which provides guarantees that can be useful when working with in e.g. +/// a narrower range, which provides guarantees that can be useful when working within e.g. /// bitfields. /// /// # Invariants @@ -56,7 +56,7 @@ fn fits_within<T: Integer>(value: T, num_bits: u32) -> bool { /// # Examples /// /// The preferred way to create values is through constants and the [`Bounded::new`] family of -/// constructors, as they trigger a build error if the type invariants cannot be withheld. +/// constructors, as they trigger a build error if the type invariants cannot be upheld. /// /// ``` /// use kernel::num::Bounded; @@ -82,7 +82,7 @@ fn fits_within<T: Integer>(value: T, num_bits: u32) -> bool { /// ``` /// use kernel::num::Bounded; /// -/// // This succeeds because `15` can be represented with 4 unsigned bits. +/// // This succeeds because `15` can be represented with 4 unsigned bits. /// assert!(Bounded::<u8, 4>::try_new(15).is_some()); /// /// // This fails because `16` cannot be represented with 4 unsigned bits. @@ -221,7 +221,7 @@ fn fits_within<T: Integer>(value: T, num_bits: u32) -> bool { /// let v: Option<Bounded<u16, 8>> = 128u32.try_into_bounded(); /// assert_eq!(v.as_deref().copied(), Some(128)); /// -/// // Fails because `128` doesn't fits into 6 bits. +/// // Fails because `128` doesn't fit into 6 bits. /// let v: Option<Bounded<u16, 6>> = 128u32.try_into_bounded(); /// assert_eq!(v, None); /// ``` @@ -259,9 +259,9 @@ macro_rules! 
impl_const_new { assert!(fits_within!(VALUE, $type, N)); } - // INVARIANT: `fits_within` confirmed that `VALUE` can be represented within + // SAFETY: `fits_within` confirmed that `VALUE` can be represented within // `N` bits. - Self::__new(VALUE) + unsafe { Self::__new(VALUE) } } } )* @@ -282,9 +282,10 @@ where /// All instances of [`Bounded`] must be created through this method as it enforces most of the /// type invariants. /// - /// The caller remains responsible for checking, either statically or dynamically, that `value` - /// can be represented as a `T` using at most `N` bits. - const fn __new(value: T) -> Self { + /// # Safety + /// + /// The caller must ensure that `value` can be represented within `N` bits. + const unsafe fn __new(value: T) -> Self { // Enforce the type invariants. const { // `N` cannot be zero. @@ -293,6 +294,7 @@ where assert!(N <= T::BITS); } + // INVARIANT: The caller ensures `value` fits within `N` bits. Self(value) } @@ -328,8 +330,8 @@ where /// ``` pub fn try_new(value: T) -> Option<Self> { fits_within(value, N).then(|| { - // INVARIANT: `fits_within` confirmed that `value` can be represented within `N` bits. - Self::__new(value) + // SAFETY: `fits_within` confirmed that `value` can be represented within `N` bits. + unsafe { Self::__new(value) } }) } @@ -363,6 +365,7 @@ where /// assert_eq!(Bounded::<u8, 1>::from_expr(1).get(), 1); /// assert_eq!(Bounded::<u16, 8>::from_expr(0xff).get(), 0xff); /// ``` + // Always inline to optimize out error path of `build_assert`. #[inline(always)] pub fn from_expr(expr: T) -> Self { crate::build_assert!( @@ -370,8 +373,8 @@ where "Requested value larger than maximal representable value." ); - // INVARIANT: `fits_within` confirmed that `expr` can be represented within `N` bits. - Self::__new(expr) + // SAFETY: `fits_within` confirmed that `expr` can be represented within `N` bits. + unsafe { Self::__new(expr) } } /// Returns the wrapped value as the backing type. @@ -410,9 +413,9 @@ where ); } - // INVARIANT: The value did fit within `N` bits, so it will all the more fit within + // SAFETY: The value did fit within `N` bits, so it will all the more fit within // the larger `M` bits. - Bounded::__new(self.0) + unsafe { Bounded::__new(self.0) } } /// Attempts to shrink the number of bits usable for `self`. @@ -466,9 +469,9 @@ where // `U` and `T` have the same sign, hence this conversion cannot fail. let value = unsafe { U::try_from(self.get()).unwrap_unchecked() }; - // INVARIANT: Although the backing type has changed, the value is still represented within + // SAFETY: Although the backing type has changed, the value is still represented within // `N` bits, and with the same signedness. - Bounded::__new(value) + unsafe { Bounded::__new(value) } } } @@ -501,7 +504,7 @@ where /// let v: Option<Bounded<u16, 8>> = 128u32.try_into_bounded(); /// assert_eq!(v.as_deref().copied(), Some(128)); /// -/// // Fails because `128` doesn't fits into 6 bits. +/// // Fails because `128` doesn't fit into 6 bits. /// let v: Option<Bounded<u16, 6>> = 128u32.try_into_bounded(); /// assert_eq!(v, None); /// ``` @@ -944,9 +947,9 @@ macro_rules! impl_from_primitive { Self: AtLeastXBits<{ <$type as Integer>::BITS as usize }>, { fn from(value: $type) -> Self { - // INVARIANT: The trait bound on `Self` guarantees that `N` bits is + // SAFETY: The trait bound on `Self` guarantees that `N` bits is // enough to hold any value of the source type. 
- Self::__new(T::from(value)) + unsafe { Self::__new(T::from(value)) } } } )* @@ -1051,8 +1054,8 @@ where T: Integer + From<bool>, { fn from(value: bool) -> Self { - // INVARIANT: A boolean can be represented using a single bit, and thus fits within any + // SAFETY: A boolean can be represented using a single bit, and thus fits within any // integer type for any `N` > 0. - Self::__new(T::from(value)) + unsafe { Self::__new(T::from(value)) } } } diff --git a/rust/kernel/page.rs b/rust/kernel/page.rs index 432fc0297d4a..adecb200c654 100644 --- a/rust/kernel/page.rs +++ b/rust/kernel/page.rs @@ -25,14 +25,36 @@ pub const PAGE_SIZE: usize = bindings::PAGE_SIZE; /// A bitmask that gives the page containing a given address. pub const PAGE_MASK: usize = !(PAGE_SIZE - 1); -/// Round up the given number to the next multiple of [`PAGE_SIZE`]. +/// Rounds up to the next multiple of [`PAGE_SIZE`]. /// -/// It is incorrect to pass an address where the next multiple of [`PAGE_SIZE`] doesn't fit in a -/// [`usize`]. -pub const fn page_align(addr: usize) -> usize { - // Parentheses around `PAGE_SIZE - 1` to avoid triggering overflow sanitizers in the wrong - // cases. - (addr + (PAGE_SIZE - 1)) & PAGE_MASK +/// Returns [`None`] on integer overflow. +/// +/// # Examples +/// +/// ``` +/// use kernel::page::{ +/// page_align, +/// PAGE_SIZE, +/// }; +/// +/// // Requested address is already aligned. +/// assert_eq!(page_align(0x0), Some(0x0)); +/// assert_eq!(page_align(PAGE_SIZE), Some(PAGE_SIZE)); +/// +/// // Requested address needs alignment up. +/// assert_eq!(page_align(0x1), Some(PAGE_SIZE)); +/// assert_eq!(page_align(PAGE_SIZE + 1), Some(2 * PAGE_SIZE)); +/// +/// // Requested address causes overflow (returns `None`). +/// let overflow_addr = usize::MAX - (PAGE_SIZE / 2); +/// assert_eq!(page_align(overflow_addr), None); +/// ``` +#[inline(always)] +pub const fn page_align(addr: usize) -> Option<usize> { + let Some(sum) = addr.checked_add(PAGE_SIZE - 1) else { + return None; + }; + Some(sum & PAGE_MASK) } /// Representation of a non-owning reference to a [`Page`]. diff --git a/rust/kernel/pci.rs b/rust/kernel/pci.rs index 82e128431f08..af74ddff6114 100644 --- a/rust/kernel/pci.rs +++ b/rust/kernel/pci.rs @@ -40,7 +40,14 @@ pub use self::id::{ ClassMask, Vendor, // }; -pub use self::io::Bar; +pub use self::io::{ + Bar, + ConfigSpace, + ConfigSpaceKind, + ConfigSpaceSize, + Extended, + Normal, // +}; pub use self::irq::{ IrqType, IrqTypes, @@ -50,13 +57,22 @@ pub use self::irq::{ /// An adapter for the registration of PCI drivers. pub struct Adapter<T: Driver>(T); -// SAFETY: A call to `unregister` for a given instance of `RegType` is guaranteed to be valid if +// SAFETY: +// - `bindings::pci_driver` is a C type declared as `repr(C)`. +// - `T` is the type of the driver's device private data. +// - `struct pci_driver` embeds a `struct device_driver`. +// - `DEVICE_DRIVER_OFFSET` is the correct byte offset to the embedded `struct device_driver`. +unsafe impl<T: Driver + 'static> driver::DriverLayout for Adapter<T> { + type DriverType = bindings::pci_driver; + type DriverData = T; + const DEVICE_DRIVER_OFFSET: usize = core::mem::offset_of!(Self::DriverType, driver); +} + +// SAFETY: A call to `unregister` for a given instance of `DriverType` is guaranteed to be valid if // a preceding call to `register` has been successful. 
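The reworked `page_align` above now reports overflow instead of silently wrapping. A stand-alone sketch of the same logic, with `PAGE_SIZE` hard-coded to 4096 here (in the kernel it comes from `bindings::PAGE_SIZE`):

```
const PAGE_SIZE: usize = 4096;
const PAGE_MASK: usize = !(PAGE_SIZE - 1);

const fn page_align(addr: usize) -> Option<usize> {
    // `checked_add` turns the old silent wrap-around into an explicit `None`.
    match addr.checked_add(PAGE_SIZE - 1) {
        Some(sum) => Some(sum & PAGE_MASK),
        None => None,
    }
}

fn main() {
    assert_eq!(page_align(0), Some(0));
    assert_eq!(page_align(1), Some(PAGE_SIZE));
    assert_eq!(page_align(usize::MAX), None);
}
```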
unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { - type RegType = bindings::pci_driver; - unsafe fn register( - pdrv: &Opaque<Self::RegType>, + pdrv: &Opaque<Self::DriverType>, name: &'static CStr, module: &'static ThisModule, ) -> Result { @@ -68,14 +84,14 @@ unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { (*pdrv.get()).id_table = T::ID_TABLE.as_ptr(); } - // SAFETY: `pdrv` is guaranteed to be a valid `RegType`. + // SAFETY: `pdrv` is guaranteed to be a valid `DriverType`. to_result(unsafe { bindings::__pci_register_driver(pdrv.get(), module.0, name.as_char_ptr()) }) } - unsafe fn unregister(pdrv: &Opaque<Self::RegType>) { - // SAFETY: `pdrv` is guaranteed to be a valid `RegType`. + unsafe fn unregister(pdrv: &Opaque<Self::DriverType>) { + // SAFETY: `pdrv` is guaranteed to be a valid `DriverType`. unsafe { bindings::pci_unregister_driver(pdrv.get()) } } } @@ -114,9 +130,9 @@ impl<T: Driver + 'static> Adapter<T> { // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called // and stored a `Pin<KBox<T>>`. - let data = unsafe { pdev.as_ref().drvdata_obtain::<T>() }; + let data = unsafe { pdev.as_ref().drvdata_borrow::<T>() }; - T::unbind(pdev, data.as_ref()); + T::unbind(pdev, data); } } @@ -342,7 +358,7 @@ impl Device { /// // Get an instance of `Vendor`. /// let vendor = pdev.vendor_id(); /// dev_info!( - /// pdev.as_ref(), + /// pdev, /// "Device: Vendor={}, Device=0x{:x}\n", /// vendor, /// pdev.device_id() diff --git a/rust/kernel/pci/id.rs b/rust/kernel/pci/id.rs index c09125946d9e..50005d176561 100644 --- a/rust/kernel/pci/id.rs +++ b/rust/kernel/pci/id.rs @@ -22,7 +22,7 @@ use crate::{ /// fn probe_device(pdev: &pci::Device<Core>) -> Result { /// let pci_class = pdev.pci_class(); /// dev_info!( -/// pdev.as_ref(), +/// pdev, /// "Detected PCI class: {}\n", /// pci_class /// ); @@ -416,7 +416,6 @@ define_all_pci_vendors! { MICROSEMI = bindings::PCI_VENDOR_ID_MICROSEMI, // 0x11f8 RP = bindings::PCI_VENDOR_ID_RP, // 0x11fe CYCLADES = bindings::PCI_VENDOR_ID_CYCLADES, // 0x120e - ESSENTIAL = bindings::PCI_VENDOR_ID_ESSENTIAL, // 0x120f O2 = bindings::PCI_VENDOR_ID_O2, // 0x1217 THREEDX = bindings::PCI_VENDOR_ID_3DFX, // 0x121a AVM = bindings::PCI_VENDOR_ID_AVM, // 0x1244 diff --git a/rust/kernel/pci/io.rs b/rust/kernel/pci/io.rs index 0d55c3139b6f..6ca4cf75594c 100644 --- a/rust/kernel/pci/io.rs +++ b/rust/kernel/pci/io.rs @@ -8,23 +8,186 @@ use crate::{ device, devres::Devres, io::{ + define_read, + define_write, Io, - IoRaw, // + IoCapable, + IoKnownSize, + Mmio, + MmioRaw, // }, prelude::*, sync::aref::ARef, // }; -use core::ops::Deref; +use core::{ + marker::PhantomData, + ops::Deref, // +}; + +/// Represents the size of a PCI configuration space. +/// +/// PCI devices can have either a *normal* (legacy) configuration space of 256 bytes, +/// or an *extended* configuration space of 4096 bytes as defined in the PCI Express +/// specification. +#[repr(usize)] +#[derive(Eq, PartialEq)] +pub enum ConfigSpaceSize { + /// 256-byte legacy PCI configuration space. + Normal = 256, + + /// 4096-byte PCIe extended configuration space. + Extended = 4096, +} + +impl ConfigSpaceSize { + /// Get the raw value of this enum. + #[inline(always)] + pub const fn into_raw(self) -> usize { + // CAST: PCI configuration space size is at most 4096 bytes, so the value always fits + // within `usize` without truncation or sign change. 
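The `into_raw` conversion that follows relies on `#[repr(usize)]` making the `as` cast return the declared discriminant. A minimal stand-alone illustration with a toy enum carrying the same values:

```
#[repr(usize)]
#[derive(Clone, Copy, PartialEq, Eq)]
enum Size {
    Normal = 256,
    Extended = 4096,
}

impl Size {
    const fn into_raw(self) -> usize {
        // With `#[repr(usize)]`, the cast yields the declared discriminant.
        self as usize
    }
}

fn main() {
    assert_eq!(Size::Normal.into_raw(), 256);
    assert_eq!(Size::Extended.into_raw(), 4096);
}
```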
+ self as usize + } +} + +/// Marker type for normal (256-byte) PCI configuration space. +pub struct Normal; + +/// Marker type for extended (4096-byte) PCIe configuration space. +pub struct Extended; + +/// Trait for PCI configuration space size markers. +/// +/// This trait is implemented by [`Normal`] and [`Extended`] to provide +/// compile-time knowledge of the configuration space size. +pub trait ConfigSpaceKind { + /// The size of this configuration space in bytes. + const SIZE: usize; +} + +impl ConfigSpaceKind for Normal { + const SIZE: usize = 256; +} + +impl ConfigSpaceKind for Extended { + const SIZE: usize = 4096; +} + +/// The PCI configuration space of a device. +/// +/// Provides typed read and write accessors for configuration registers +/// using the standard `pci_read_config_*` and `pci_write_config_*` helpers. +/// +/// The generic parameter `S` indicates the maximum size of the configuration space. +/// Use [`Normal`] for 256-byte legacy configuration space or [`Extended`] for +/// 4096-byte PCIe extended configuration space (default). +pub struct ConfigSpace<'a, S: ConfigSpaceKind = Extended> { + pub(crate) pdev: &'a Device<device::Bound>, + _marker: PhantomData<S>, +} + +/// Internal helper macros used to invoke C PCI configuration space read functions. +/// +/// This macro is intended to be used by higher-level PCI configuration space access macros +/// (define_read) and provides a unified expansion for infallible vs. fallible read semantics. It +/// emits a direct call into the corresponding C helper and performs the required cast to the Rust +/// return type. +/// +/// # Parameters +/// +/// * `$c_fn` – The C function performing the PCI configuration space write. +/// * `$self` – The I/O backend object. +/// * `$ty` – The type of the value to read. +/// * `$addr` – The PCI configuration space offset to read. +/// +/// This macro does not perform any validation; all invariants must be upheld by the higher-level +/// abstraction invoking it. +macro_rules! call_config_read { + (infallible, $c_fn:ident, $self:ident, $ty:ty, $addr:expr) => {{ + let mut val: $ty = 0; + // SAFETY: By the type invariant `$self.pdev` is a valid address. + // CAST: The offset is cast to `i32` because the C functions expect a 32-bit signed offset + // parameter. PCI configuration space size is at most 4096 bytes, so the value always fits + // within `i32` without truncation or sign change. + // Return value from C function is ignored in infallible accessors. + let _ret = unsafe { bindings::$c_fn($self.pdev.as_raw(), $addr as i32, &mut val) }; + val + }}; +} + +/// Internal helper macros used to invoke C PCI configuration space write functions. +/// +/// This macro is intended to be used by higher-level PCI configuration space access macros +/// (define_write) and provides a unified expansion for infallible vs. fallible read semantics. It +/// emits a direct call into the corresponding C helper and performs the required cast to the Rust +/// return type. +/// +/// # Parameters +/// +/// * `$c_fn` – The C function performing the PCI configuration space write. +/// * `$self` – The I/O backend object. +/// * `$ty` – The type of the written value. +/// * `$addr` – The configuration space offset to write. +/// * `$value` – The value to write. +/// +/// This macro does not perform any validation; all invariants must be upheld by the higher-level +/// abstraction invoking it. +macro_rules! 
call_config_write { + (infallible, $c_fn:ident, $self:ident, $ty:ty, $addr:expr, $value:expr) => { + // SAFETY: By the type invariant `$self.pdev` is a valid address. + // CAST: The offset is cast to `i32` because the C functions expect a 32-bit signed offset + // parameter. PCI configuration space size is at most 4096 bytes, so the value always fits + // within `i32` without truncation or sign change. + // Return value from C function is ignored in infallible accessors. + let _ret = unsafe { bindings::$c_fn($self.pdev.as_raw(), $addr as i32, $value) }; + }; +} + +// PCI configuration space supports 8, 16, and 32-bit accesses. +impl<'a, S: ConfigSpaceKind> IoCapable<u8> for ConfigSpace<'a, S> {} +impl<'a, S: ConfigSpaceKind> IoCapable<u16> for ConfigSpace<'a, S> {} +impl<'a, S: ConfigSpaceKind> IoCapable<u32> for ConfigSpace<'a, S> {} + +impl<'a, S: ConfigSpaceKind> Io for ConfigSpace<'a, S> { + /// Returns the base address of the I/O region. It is always 0 for configuration space. + #[inline] + fn addr(&self) -> usize { + 0 + } + + /// Returns the maximum size of the configuration space. + #[inline] + fn maxsize(&self) -> usize { + self.pdev.cfg_size().into_raw() + } + + // PCI configuration space does not support fallible operations. + // The default implementations from the Io trait are not used. + + define_read!(infallible, read8, call_config_read(pci_read_config_byte) -> u8); + define_read!(infallible, read16, call_config_read(pci_read_config_word) -> u16); + define_read!(infallible, read32, call_config_read(pci_read_config_dword) -> u32); + + define_write!(infallible, write8, call_config_write(pci_write_config_byte) <- u8); + define_write!(infallible, write16, call_config_write(pci_write_config_word) <- u16); + define_write!(infallible, write32, call_config_write(pci_write_config_dword) <- u32); +} + +impl<'a, S: ConfigSpaceKind> IoKnownSize for ConfigSpace<'a, S> { + const MIN_SIZE: usize = S::SIZE; +} /// A PCI BAR to perform I/O-Operations on. /// +/// I/O backend assumes that the device is little-endian and will automatically +/// convert from little-endian to CPU endianness. +/// /// # Invariants /// -/// `Bar` always holds an `IoRaw` inststance that holds a valid pointer to the start of the I/O +/// `Bar` always holds an `IoRaw` instance that holds a valid pointer to the start of the I/O /// memory mapped PCI BAR and its size. pub struct Bar<const SIZE: usize = 0> { pdev: ARef<Device>, - io: IoRaw<SIZE>, + io: MmioRaw<SIZE>, num: i32, } @@ -54,13 +217,13 @@ impl<const SIZE: usize> Bar<SIZE> { let ioptr: usize = unsafe { bindings::pci_iomap(pdev.as_raw(), num, 0) } as usize; if ioptr == 0 { // SAFETY: - // `pdev` valid by the invariants of `Device`. + // `pdev` is valid by the invariants of `Device`. // `num` is checked for validity by a previous call to `Device::resource_len`. unsafe { bindings::pci_release_region(pdev.as_raw(), num) }; return Err(ENOMEM); } - let io = match IoRaw::new(ioptr, len as usize) { + let io = match MmioRaw::new(ioptr, len as usize) { Ok(io) => io, Err(err) => { // SAFETY: @@ -114,11 +277,11 @@ impl<const SIZE: usize> Drop for Bar<SIZE> { } impl<const SIZE: usize> Deref for Bar<SIZE> { - type Target = Io<SIZE>; + type Target = Mmio<SIZE>; fn deref(&self) -> &Self::Target { // SAFETY: By the type invariant of `Self`, the MMIO range in `self.io` is properly mapped. 
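Illustrative use of the new config-space accessors, in kernel doctest style; this assumes the generated `read16` keeps the existing `readN(offset)` shape, and uses the standard PCI header offsets for the vendor and device ID words:

```
use kernel::{device, pci, prelude::*};

fn log_ids(pdev: &pci::Device<device::Bound>) {
    let cfg = pdev.config_space();
    // Offsets 0x00 and 0x02 are the vendor and device ID words of the
    // standard configuration header.
    let vendor = cfg.read16(0x00);
    let device = cfg.read16(0x02);
    kernel::dev_info!(pdev, "vendor={:#06x} device={:#06x}\n", vendor, device);
}
```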
- unsafe { Io::from_raw(&self.io) } + unsafe { Mmio::from_raw(&self.io) } } } @@ -141,4 +304,39 @@ impl Device<device::Bound> { ) -> impl PinInit<Devres<Bar>, Error> + 'a { self.iomap_region_sized::<0>(bar, name) } + + /// Returns the size of configuration space. + pub fn cfg_size(&self) -> ConfigSpaceSize { + // SAFETY: `self.as_raw` is a valid pointer to a `struct pci_dev`. + let size = unsafe { (*self.as_raw()).cfg_size }; + match size { + 256 => ConfigSpaceSize::Normal, + 4096 => ConfigSpaceSize::Extended, + _ => { + // PANIC: The PCI subsystem only ever reports the configuration space size as either + // `ConfigSpaceSize::Normal` or `ConfigSpaceSize::Extended`. + unreachable!(); + } + } + } + + /// Return an initialized normal (256-byte) config space object. + pub fn config_space<'a>(&'a self) -> ConfigSpace<'a, Normal> { + ConfigSpace { + pdev: self, + _marker: PhantomData, + } + } + + /// Return an initialized extended (4096-byte) config space object. + pub fn config_space_extended<'a>(&'a self) -> Result<ConfigSpace<'a, Extended>> { + if self.cfg_size() != ConfigSpaceSize::Extended { + return Err(EINVAL); + } + + Ok(ConfigSpace { + pdev: self, + _marker: PhantomData, + }) + } } diff --git a/rust/kernel/platform.rs b/rust/kernel/platform.rs index ed889f079cab..8917d4ee499f 100644 --- a/rust/kernel/platform.rs +++ b/rust/kernel/platform.rs @@ -5,34 +5,60 @@ //! C header: [`include/linux/platform_device.h`](srctree/include/linux/platform_device.h) use crate::{ - acpi, bindings, container_of, - device::{self, Bound}, + acpi, + bindings, + container_of, + device::{ + self, + Bound, // + }, driver, - error::{from_result, to_result, Result}, - io::{mem::IoRequest, Resource}, - irq::{self, IrqRequest}, + error::{ + from_result, + to_result, // + }, + io::{ + mem::IoRequest, + Resource, // + }, + irq::{ + self, + IrqRequest, // + }, of, prelude::*, types::Opaque, - ThisModule, + ThisModule, // }; use core::{ marker::PhantomData, mem::offset_of, - ptr::{addr_of_mut, NonNull}, + ptr::{ + addr_of_mut, + NonNull, // + }, }; /// An adapter for the registration of platform drivers. pub struct Adapter<T: Driver>(T); -// SAFETY: A call to `unregister` for a given instance of `RegType` is guaranteed to be valid if +// SAFETY: +// - `bindings::platform_driver` is a C type declared as `repr(C)`. +// - `T` is the type of the driver's device private data. +// - `struct platform_driver` embeds a `struct device_driver`. +// - `DEVICE_DRIVER_OFFSET` is the correct byte offset to the embedded `struct device_driver`. +unsafe impl<T: Driver + 'static> driver::DriverLayout for Adapter<T> { + type DriverType = bindings::platform_driver; + type DriverData = T; + const DEVICE_DRIVER_OFFSET: usize = core::mem::offset_of!(Self::DriverType, driver); +} + +// SAFETY: A call to `unregister` for a given instance of `DriverType` is guaranteed to be valid if // a preceding call to `register` has been successful. unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { - type RegType = bindings::platform_driver; - unsafe fn register( - pdrv: &Opaque<Self::RegType>, + pdrv: &Opaque<Self::DriverType>, name: &'static CStr, module: &'static ThisModule, ) -> Result { @@ -55,12 +81,12 @@ unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { (*pdrv.get()).driver.acpi_match_table = acpi_table; } - // SAFETY: `pdrv` is guaranteed to be a valid `RegType`. + // SAFETY: `pdrv` is guaranteed to be a valid `DriverType`. 
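Both bus adapters in this series describe their layout through `DEVICE_DRIVER_OFFSET`. The core of that is just `core::mem::offset_of!` applied to a `repr(C)` struct; a stand-alone sketch with toy types (not the real bindings):

```
use core::mem::offset_of;

#[repr(C)]
struct DeviceDriver {
    name: *const u8,
}

#[repr(C)]
struct PlatformDriver {
    probe: Option<fn()>,
    driver: DeviceDriver,
}

fn main() {
    // The byte offset of the embedded member lets a container pointer be
    // recovered from a pointer to the embedded `DeviceDriver`.
    let off = offset_of!(PlatformDriver, driver);
    println!("embedded `driver` lives {off} bytes into the container");
}
```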
to_result(unsafe { bindings::__platform_driver_register(pdrv.get(), module.0) }) } - unsafe fn unregister(pdrv: &Opaque<Self::RegType>) { - // SAFETY: `pdrv` is guaranteed to be a valid `RegType`. + unsafe fn unregister(pdrv: &Opaque<Self::DriverType>) { + // SAFETY: `pdrv` is guaranteed to be a valid `DriverType`. unsafe { bindings::platform_driver_unregister(pdrv.get()) }; } } @@ -86,15 +112,15 @@ impl<T: Driver + 'static> Adapter<T> { // SAFETY: The platform bus only ever calls the remove callback with a valid pointer to a // `struct platform_device`. // - // INVARIANT: `pdev` is valid for the duration of `probe_callback()`. + // INVARIANT: `pdev` is valid for the duration of `remove_callback()`. let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal>>() }; // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called // and stored a `Pin<KBox<T>>`. - let data = unsafe { pdev.as_ref().drvdata_obtain::<T>() }; + let data = unsafe { pdev.as_ref().drvdata_borrow::<T>() }; - T::unbind(pdev, data.as_ref()); + T::unbind(pdev, data); } } @@ -137,8 +163,13 @@ macro_rules! module_platform_driver { /// # Examples /// ///``` -/// # use kernel::{acpi, bindings, c_str, device::Core, of, platform}; -/// +/// # use kernel::{ +/// # acpi, +/// # bindings, +/// # device::Core, +/// # of, +/// # platform, +/// # }; /// struct MyDriver; /// /// kernel::of_device_table!( @@ -146,7 +177,7 @@ macro_rules! module_platform_driver { /// MODULE_OF_TABLE, /// <MyDriver as platform::Driver>::IdInfo, /// [ -/// (of::DeviceId::new(c_str!("test,device")), ()) +/// (of::DeviceId::new(c"test,device"), ()) /// ] /// ); /// @@ -155,7 +186,7 @@ macro_rules! module_platform_driver { /// MODULE_ACPI_TABLE, /// <MyDriver as platform::Driver>::IdInfo, /// [ -/// (acpi::DeviceId::new(c_str!("LNUXBEEF")), ()) +/// (acpi::DeviceId::new(c"LNUXBEEF"), ()) /// ] /// ); /// diff --git a/rust/kernel/print.rs b/rust/kernel/print.rs index 2d743d78d220..6fd84389a858 100644 --- a/rust/kernel/print.rs +++ b/rust/kernel/print.rs @@ -11,6 +11,11 @@ use crate::{ fmt, prelude::*, str::RawFormatter, + sync::atomic::{ + Atomic, + AtomicType, + Relaxed, // + }, }; // Called from `vsprintf` with format specifier `%pA`. @@ -423,3 +428,151 @@ macro_rules! pr_cont ( $crate::print_macro!($crate::print::format_strings::CONT, true, $($arg)*) ) ); + +/// A lightweight `call_once` primitive. +/// +/// This structure provides the Rust equivalent of the kernel's `DO_ONCE_LITE` macro. +/// While it would be possible to implement the feature entirely as a Rust macro, +/// the functionality that can be implemented as regular functions has been +/// extracted and implemented as the `OnceLite` struct for better code maintainability. +pub struct OnceLite(Atomic<State>); + +#[derive(Clone, Copy, PartialEq, Eq)] +#[repr(i32)] +enum State { + Incomplete = 0, + Complete = 1, +} + +// SAFETY: `State` and `i32` has the same size and alignment, and it's round-trip +// transmutable to `i32`. +unsafe impl AtomicType for State { + type Repr = i32; +} + +impl OnceLite { + /// Creates a new [`OnceLite`] in the incomplete state. + #[inline(always)] + #[allow(clippy::new_without_default)] + pub const fn new() -> Self { + OnceLite(Atomic::new(State::Incomplete)) + } + + /// Calls the provided function exactly once. 
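A portable sketch of the semantics `OnceLite::call_once` documents below, written with `core`'s atomics rather than the kernel's (`swap` standing in for `xchg`): exactly one caller runs the closure, and losers return `false` without any further ordering guarantee.

```
use core::sync::atomic::{AtomicBool, Ordering::Relaxed};

struct OnceLite(AtomicBool);

impl OnceLite {
    const fn new() -> Self {
        Self(AtomicBool::new(false))
    }

    fn call_once<F: FnOnce()>(&self, f: F) -> bool {
        // Fast path: already completed, skip the read-modify-write entirely.
        if self.0.load(Relaxed) {
            return false;
        }
        // Exactly one caller observes `false` from the swap and runs `f`.
        if self.0.swap(true, Relaxed) {
            return false;
        }
        f();
        true
    }
}

fn main() {
    let once = OnceLite::new();
    assert!(once.call_once(|| println!("runs once")));
    assert!(!once.call_once(|| println!("never runs")));
}
```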
+ /// + /// There is no other synchronization between two `call_once()`s + /// except that only one will execute `f`, in other words, callers + /// should not use a failed `call_once()` as a proof that another + /// `call_once()` has already finished and the effect is observable + /// to this thread. + pub fn call_once<F>(&self, f: F) -> bool + where + F: FnOnce(), + { + // Avoid expensive cmpxchg if already completed. + // ORDERING: `Relaxed` is used here since no synchronization is required. + let old = self.0.load(Relaxed); + if old == State::Complete { + return false; + } + + // ORDERING: `Relaxed` is used here since no synchronization is required. + let old = self.0.xchg(State::Complete, Relaxed); + if old == State::Complete { + return false; + } + + f(); + true + } +} + +/// Run the given function exactly once. +/// +/// This is equivalent to the kernel's `DO_ONCE_LITE` macro. +/// +/// # Examples +/// +/// ``` +/// kernel::do_once_lite! { +/// kernel::pr_info!("This will be printed only once\n"); +/// }; +/// ``` +#[macro_export] +macro_rules! do_once_lite { + { $($e:tt)* } => {{ + #[link_section = ".data..once"] + static ONCE: $crate::print::OnceLite = $crate::print::OnceLite::new(); + ONCE.call_once(|| { $($e)* }); + }}; +} + +/// Prints an emergency-level message (level 0) only once. +/// +/// Equivalent to the kernel's `pr_emerg_once` macro. +#[macro_export] +macro_rules! pr_emerg_once ( + ($($arg:tt)*) => ( + $crate::do_once_lite! { $crate::pr_emerg!($($arg)*) } + ) +); + +/// Prints an alert-level message (level 1) only once. +/// +/// Equivalent to the kernel's `pr_alert_once` macro. +#[macro_export] +macro_rules! pr_alert_once ( + ($($arg:tt)*) => ( + $crate::do_once_lite! { $crate::pr_alert!($($arg)*) } + ) +); + +/// Prints a critical-level message (level 2) only once. +/// +/// Equivalent to the kernel's `pr_crit_once` macro. +#[macro_export] +macro_rules! pr_crit_once ( + ($($arg:tt)*) => ( + $crate::do_once_lite! { $crate::pr_crit!($($arg)*) } + ) +); + +/// Prints an error-level message (level 3) only once. +/// +/// Equivalent to the kernel's `pr_err_once` macro. +#[macro_export] +macro_rules! pr_err_once ( + ($($arg:tt)*) => ( + $crate::do_once_lite! { $crate::pr_err!($($arg)*) } + ) +); + +/// Prints a warning-level message (level 4) only once. +/// +/// Equivalent to the kernel's `pr_warn_once` macro. +#[macro_export] +macro_rules! pr_warn_once ( + ($($arg:tt)*) => ( + $crate::do_once_lite! { $crate::pr_warn!($($arg)*) } + ) +); + +/// Prints a notice-level message (level 5) only once. +/// +/// Equivalent to the kernel's `pr_notice_once` macro. +#[macro_export] +macro_rules! pr_notice_once ( + ($($arg:tt)*) => ( + $crate::do_once_lite! { $crate::pr_notice!($($arg)*) } + ) +); + +/// Prints an info-level message (level 6) only once. +/// +/// Equivalent to the kernel's `pr_info_once` macro. +#[macro_export] +macro_rules! pr_info_once ( + ($($arg:tt)*) => ( + $crate::do_once_lite! { $crate::pr_info!($($arg)*) } + ) +); diff --git a/rust/kernel/ptr.rs b/rust/kernel/ptr.rs index e3893ed04049..5b6a382637fe 100644 --- a/rust/kernel/ptr.rs +++ b/rust/kernel/ptr.rs @@ -5,8 +5,6 @@ use core::mem::align_of; use core::num::NonZero; -use crate::build_assert; - /// Type representing an alignment, which is always a power of two. 
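The `Alignment::new` hunk below swaps `build_assert!` for an inline `const` block, turning the check into a plain compile-time error. A stand-alone sketch of that pattern:

```
// The const block is evaluated once per monomorphization, so an invalid
// `ALIGN` fails the build instead of relying on the optimizer.
const fn check_align<const ALIGN: usize>() -> usize {
    const {
        assert!(ALIGN.is_power_of_two(), "ALIGN must be a power of two");
    }
    ALIGN
}

fn main() {
    assert_eq!(check_align::<8>(), 8);
    // `check_align::<6>()` would fail to compile: the const block panics.
}
```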
/// /// It is used to validate that a given value is a valid alignment, and to perform masking and @@ -40,10 +38,12 @@ impl Alignment { /// ``` #[inline(always)] pub const fn new<const ALIGN: usize>() -> Self { - build_assert!( - ALIGN.is_power_of_two(), - "Provided alignment is not a power of two." - ); + const { + assert!( + ALIGN.is_power_of_two(), + "Provided alignment is not a power of two." + ); + } // INVARIANT: `align` is a power of two. // SAFETY: `align` is a power of two, and thus non-zero. diff --git a/rust/kernel/pwm.rs b/rust/kernel/pwm.rs index cb00f8a8765c..6c9d667009ef 100644 --- a/rust/kernel/pwm.rs +++ b/rust/kernel/pwm.rs @@ -13,9 +13,14 @@ use crate::{ devres, error::{self, to_result}, prelude::*, - types::{ARef, AlwaysRefCounted, Opaque}, // + sync::aref::{ARef, AlwaysRefCounted}, + types::Opaque, // +}; +use core::{ + marker::PhantomData, + ops::Deref, + ptr::NonNull, // }; -use core::{marker::PhantomData, ptr::NonNull}; /// Represents a PWM waveform configuration. /// Mirrors struct [`struct pwm_waveform`](srctree/include/linux/pwm.h). @@ -124,8 +129,7 @@ impl Device { // SAFETY: `self.as_raw()` provides a valid `*mut pwm_device` pointer. // `&c_wf` is a valid pointer to a `pwm_waveform` struct. The C function // handles all necessary internal locking. - let ret = unsafe { bindings::pwm_set_waveform_might_sleep(self.as_raw(), &c_wf, exact) }; - to_result(ret) + to_result(unsafe { bindings::pwm_set_waveform_might_sleep(self.as_raw(), &c_wf, exact) }) } /// Queries the hardware for the configuration it would apply for a given @@ -155,9 +159,7 @@ impl Device { // SAFETY: `self.as_raw()` is a valid pointer. We provide a valid pointer // to a stack-allocated `pwm_waveform` struct for the kernel to fill. - let ret = unsafe { bindings::pwm_get_waveform_might_sleep(self.as_raw(), &mut c_wf) }; - - to_result(ret)?; + to_result(unsafe { bindings::pwm_get_waveform_might_sleep(self.as_raw(), &mut c_wf) })?; Ok(Waveform::from(c_wf)) } @@ -173,7 +175,7 @@ pub struct RoundedWaveform<WfHw> { } /// Trait defining the operations for a PWM driver. -pub trait PwmOps: 'static + Sized { +pub trait PwmOps: 'static + Send + Sync + Sized { /// The driver-specific hardware representation of a waveform. /// /// This type must be [`Copy`], [`Default`], and fit within `PWM_WFHWSIZE`. @@ -258,8 +260,8 @@ impl<T: PwmOps> Adapter<T> { core::ptr::from_ref::<T::WfHw>(wfhw).cast::<u8>(), wfhw_ptr.cast::<u8>(), size, - ); - } + ) + }; Ok(()) } @@ -279,8 +281,8 @@ impl<T: PwmOps> Adapter<T> { wfhw_ptr.cast::<u8>(), core::ptr::from_mut::<T::WfHw>(&mut wfhw).cast::<u8>(), size, - ); - } + ) + }; Ok(wfhw) } @@ -306,9 +308,7 @@ impl<T: PwmOps> Adapter<T> { // Now, call the original release function to free the `pwm_chip` itself. // SAFETY: `dev` is the valid pointer passed into this callback, which is // the expected argument for `pwmchip_release`. - unsafe { - bindings::pwmchip_release(dev); - } + unsafe { bindings::pwmchip_release(dev) }; } /// # Safety @@ -408,9 +408,7 @@ impl<T: PwmOps> Adapter<T> { match T::round_waveform_fromhw(chip, pwm, &wfhw, &mut rust_wf) { Ok(()) => { // SAFETY: `wf_ptr` is guaranteed valid by the C caller. - unsafe { - *wf_ptr = rust_wf.into(); - }; + unsafe { *wf_ptr = rust_wf.into() }; 0 } Err(e) => e.to_errno(), @@ -584,11 +582,12 @@ impl<T: PwmOps> Chip<T> { /// /// Returns an [`ARef<Chip>`] managing the chip's lifetime via refcounting /// on its embedded `struct device`. 
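The `pwm::Chip::new` change below returns an `UnregisteredChip` whose `register` takes `self` by value, so a chip cannot be registered twice. A toy stand-alone sketch of that consume-on-register shape (all names here are made up):

```
struct Chip {
    label: &'static str,
}

struct Unregistered(Chip);

impl Unregistered {
    // Taking `self` by value means the wrapper is gone after one call.
    fn register(self) -> Chip {
        println!("registering {}", self.0.label);
        self.0
    }
}

fn main() {
    let chip = Unregistered(Chip { label: "pwm0" }).register();
    println!("registered {}", chip.label);
    // A second `register()` on the same wrapper no longer type-checks.
}
```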
- pub fn new( - parent_dev: &device::Device, + #[allow(clippy::new_ret_no_self)] + pub fn new<'a>( + parent_dev: &'a device::Device<Bound>, num_channels: u32, data: impl pin_init::PinInit<T, Error>, - ) -> Result<ARef<Self>> { + ) -> Result<UnregisteredChip<'a, T>> { let sizeof_priv = core::mem::size_of::<T>(); // SAFETY: `pwmchip_alloc` allocates memory for the C struct and our private data. let c_chip_ptr_raw = @@ -601,19 +600,19 @@ impl<T: PwmOps> Chip<T> { let drvdata_ptr = unsafe { bindings::pwmchip_get_drvdata(c_chip_ptr) }; // SAFETY: We construct the `T` object in-place in the allocated private memory. - unsafe { data.__pinned_init(drvdata_ptr.cast())? }; + unsafe { data.__pinned_init(drvdata_ptr.cast()) }.inspect_err(|_| { + // SAFETY: It is safe to call `pwmchip_put()` with a valid pointer obtained + // from `pwmchip_alloc()`. We will not use pointer after this. + unsafe { bindings::pwmchip_put(c_chip_ptr) } + })?; // SAFETY: `c_chip_ptr` points to a valid chip. - unsafe { - (*c_chip_ptr).dev.release = Some(Adapter::<T>::release_callback); - } + unsafe { (*c_chip_ptr).dev.release = Some(Adapter::<T>::release_callback) }; // SAFETY: `c_chip_ptr` points to a valid chip. // The `Adapter`'s `VTABLE` has a 'static lifetime, so the pointer // returned by `as_raw()` is always valid. - unsafe { - (*c_chip_ptr).ops = Adapter::<T>::VTABLE.as_raw(); - } + unsafe { (*c_chip_ptr).ops = Adapter::<T>::VTABLE.as_raw() }; // Cast the `*mut bindings::pwm_chip` to `*mut Chip`. This is valid because // `Chip` is `repr(transparent)` over `Opaque<bindings::pwm_chip>`, and @@ -623,7 +622,9 @@ impl<T: PwmOps> Chip<T> { // SAFETY: `chip_ptr_as_self` points to a valid `Chip` (layout-compatible with // `bindings::pwm_chip`) whose embedded device has refcount 1. // `ARef::from_raw` takes this pointer and manages it via `AlwaysRefCounted`. - Ok(unsafe { ARef::from_raw(NonNull::new_unchecked(chip_ptr_as_self)) }) + let chip = unsafe { ARef::from_raw(NonNull::new_unchecked(chip_ptr_as_self)) }; + + Ok(UnregisteredChip { chip, parent_dev }) } } @@ -633,9 +634,7 @@ unsafe impl<T: PwmOps> AlwaysRefCounted for Chip<T> { fn inc_ref(&self) { // SAFETY: `self.0.get()` points to a valid `pwm_chip` because `self` exists. // The embedded `dev` is valid. `get_device` increments its refcount. - unsafe { - bindings::get_device(&raw mut (*self.0.get()).dev); - } + unsafe { bindings::get_device(&raw mut (*self.0.get()).dev) }; } #[inline] @@ -644,9 +643,7 @@ unsafe impl<T: PwmOps> AlwaysRefCounted for Chip<T> { // SAFETY: `obj` is a valid pointer to a `Chip` (and thus `bindings::pwm_chip`) // with a non-zero refcount. `put_device` handles decrement and final release. - unsafe { - bindings::put_device(&raw mut (*c_chip_ptr).dev); - } + unsafe { bindings::put_device(&raw mut (*c_chip_ptr).dev) }; } } @@ -654,50 +651,61 @@ unsafe impl<T: PwmOps> AlwaysRefCounted for Chip<T> { // structure's state is managed and synchronized by the kernel's device model // and PWM core locking mechanisms. Therefore, it is safe to move the `Chip` // wrapper (and the pointer it contains) across threads. -unsafe impl<T: PwmOps + Send> Send for Chip<T> {} +unsafe impl<T: PwmOps> Send for Chip<T> {} // SAFETY: It is safe for multiple threads to have shared access (`&Chip`) because // the `Chip` data is immutable from the Rust side without holding the appropriate // kernel locks, which the C core is responsible for. Any interior mutability is // handled and synchronized by the C kernel code. 
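The `Send`/`Sync` hunks that follow work because `PwmOps` now carries `Send + Sync` as supertraits. A simplified stand-alone analogue (the kernel impls stay `unsafe` because of the raw pointers involved; in this sketch the auto traits suffice):

```
trait Ops: Send + Sync + 'static {}

struct Chip<T: Ops>(T);

fn assert_send_sync<T: Send + Sync>() {}

fn check<T: Ops>() {
    // Holds for every `T: Ops`, so no per-`T` conditional bound is needed.
    assert_send_sync::<Chip<T>>();
}

impl Ops for u32 {}

fn main() {
    check::<u32>();
}
```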
-unsafe impl<T: PwmOps + Sync> Sync for Chip<T> {} +unsafe impl<T: PwmOps> Sync for Chip<T> {} -/// A resource guard that ensures `pwmchip_remove` is called on drop. -/// -/// This struct is intended to be managed by the `devres` framework by transferring its ownership -/// via [`devres::register`]. This ties the lifetime of the PWM chip registration -/// to the lifetime of the underlying device. -pub struct Registration<T: PwmOps> { +/// A wrapper around `ARef<Chip<T>>` that ensures that `register` can only be called once. +pub struct UnregisteredChip<'a, T: PwmOps> { chip: ARef<Chip<T>>, + parent_dev: &'a device::Device<Bound>, } -impl<T: 'static + PwmOps + Send + Sync> Registration<T> { +impl<T: PwmOps> UnregisteredChip<'_, T> { /// Registers a PWM chip with the PWM subsystem. /// /// Transfers its ownership to the `devres` framework, which ties its lifetime /// to the parent device. /// On unbind of the parent device, the `devres` entry will be dropped, automatically /// calling `pwmchip_remove`. This function should be called from the driver's `probe`. - pub fn register(dev: &device::Device<Bound>, chip: ARef<Chip<T>>) -> Result { - let chip_parent = chip.device().parent().ok_or(EINVAL)?; - if dev.as_raw() != chip_parent.as_raw() { - return Err(EINVAL); - } - - let c_chip_ptr = chip.as_raw(); + pub fn register(self) -> Result<ARef<Chip<T>>> { + let c_chip_ptr = self.chip.as_raw(); // SAFETY: `c_chip_ptr` points to a valid chip with its ops initialized. // `__pwmchip_add` is the C function to register the chip with the PWM core. - unsafe { - to_result(bindings::__pwmchip_add(c_chip_ptr, core::ptr::null_mut()))?; - } + to_result(unsafe { bindings::__pwmchip_add(c_chip_ptr, core::ptr::null_mut()) })?; + + let registration = Registration { + chip: ARef::clone(&self.chip), + }; + + devres::register(self.parent_dev, registration, GFP_KERNEL)?; + + Ok(self.chip) + } +} - let registration = Registration { chip }; +impl<T: PwmOps> Deref for UnregisteredChip<'_, T> { + type Target = Chip<T>; - devres::register(dev, registration, GFP_KERNEL) + fn deref(&self) -> &Self::Target { + &self.chip } } +/// A resource guard that ensures `pwmchip_remove` is called on drop. +/// +/// This struct is intended to be managed by the `devres` framework by transferring its ownership +/// via [`devres::register`]. This ties the lifetime of the PWM chip registration +/// to the lifetime of the underlying device. +struct Registration<T: PwmOps> { + chip: ARef<Chip<T>>, +} + impl<T: PwmOps> Drop for Registration<T> { fn drop(&mut self) { let chip_raw = self.chip.as_raw(); @@ -705,9 +713,7 @@ impl<T: PwmOps> Drop for Registration<T> { // SAFETY: `chip_raw` points to a chip that was successfully registered. // `bindings::pwmchip_remove` is the correct C function to unregister it. // This `drop` implementation is called automatically by `devres` on driver unbind. - unsafe { - bindings::pwmchip_remove(chip_raw); - } + unsafe { bindings::pwmchip_remove(chip_raw) }; } } diff --git a/rust/kernel/rbtree.rs b/rust/kernel/rbtree.rs index 4729eb56827a..6fbd579d4a43 100644 --- a/rust/kernel/rbtree.rs +++ b/rust/kernel/rbtree.rs @@ -414,14 +414,17 @@ where // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` // point to the links field of `Node<K, V>` objects. let this = unsafe { container_of!(node, Node<K, V>, links) }; + // SAFETY: `this` is a non-null node so it is valid by the type invariants. 
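The `rbtree` rewrite that follows only reshapes the lookup loop around shared references obtained once per node; the control flow is the classic ordered descent. A safe stand-alone analogue of that loop:

```
use core::cmp::Ordering;

struct Node<K, V> {
    key: K,
    value: V,
    left: Option<Box<Node<K, V>>>,
    right: Option<Box<Node<K, V>>>,
}

fn get<'a, K: Ord, V>(mut node: Option<&'a Node<K, V>>, key: &K) -> Option<&'a V> {
    while let Some(n) = node {
        node = match key.cmp(&n.key) {
            Ordering::Less => n.left.as_deref(),
            Ordering::Greater => n.right.as_deref(),
            Ordering::Equal => return Some(&n.value),
        };
    }
    None
}

fn main() {
    let leaf = Node { key: 1, value: "one", left: None, right: None };
    let root = Node {
        key: 2,
        value: "two",
        left: Some(Box::new(leaf)),
        right: None,
    };
    assert_eq!(get(Some(&root), &1), Some(&"one"));
    assert_eq!(get(Some(&root), &3), None);
}
```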
- node = match key.cmp(unsafe { &(*this).key }) { - // SAFETY: `node` is a non-null node so it is valid by the type invariants. - Ordering::Less => unsafe { (*node).rb_left }, - // SAFETY: `node` is a non-null node so it is valid by the type invariants. - Ordering::Greater => unsafe { (*node).rb_right }, - // SAFETY: `node` is a non-null node so it is valid by the type invariants. - Ordering::Equal => return Some(unsafe { &(*this).value }), + let this_ref = unsafe { &*this }; + + // SAFETY: `node` is a non-null node so it is valid by the type invariants. + let node_ref = unsafe { &*node }; + + node = match key.cmp(&this_ref.key) { + Ordering::Less => node_ref.rb_left, + Ordering::Greater => node_ref.rb_right, + Ordering::Equal => return Some(&this_ref.value), } } None @@ -498,10 +501,10 @@ where let this = unsafe { container_of!(node, Node<K, V>, links) }; // SAFETY: `this` is a non-null node so it is valid by the type invariants. let this_key = unsafe { &(*this).key }; + // SAFETY: `node` is a non-null node so it is valid by the type invariants. - let left_child = unsafe { (*node).rb_left }; - // SAFETY: `node` is a non-null node so it is valid by the type invariants. - let right_child = unsafe { (*node).rb_right }; + let node_ref = unsafe { &*node }; + match key.cmp(this_key) { Ordering::Equal => { // SAFETY: `this` is a non-null node so it is valid by the type invariants. @@ -509,7 +512,7 @@ where break; } Ordering::Greater => { - node = right_child; + node = node_ref.rb_right; } Ordering::Less => { let is_better_match = match best_key { @@ -521,7 +524,7 @@ where // SAFETY: `this` is a non-null node so it is valid by the type invariants. best_links = Some(unsafe { NonNull::new_unchecked(&mut (*this).links) }); } - node = left_child; + node = node_ref.rb_left; } }; } @@ -985,7 +988,7 @@ impl<'a, K, V> CursorMut<'a, K, V> { self.peek(Direction::Prev) } - /// Access the previous node without moving the cursor. + /// Access the next node without moving the cursor. pub fn peek_next(&self) -> Option<(&K, &V)> { self.peek(Direction::Next) } @@ -1130,7 +1133,7 @@ pub struct IterMut<'a, K, V> { } // SAFETY: The [`IterMut`] has exclusive access to both `K` and `V`, so it is sufficient to require them to be `Send`. -// The iterator only gives out immutable references to the keys, but since the iterator has excusive access to those same +// The iterator only gives out immutable references to the keys, but since the iterator has exclusive access to those same // keys, `Send` is sufficient. `Sync` would be okay, but it is more restrictive to the user. unsafe impl<'a, K: Send, V: Send> Send for IterMut<'a, K, V> {} diff --git a/rust/kernel/regulator.rs b/rust/kernel/regulator.rs index 2c44827ad0b7..4f7837c7e53a 100644 --- a/rust/kernel/regulator.rs +++ b/rust/kernel/regulator.rs @@ -122,12 +122,11 @@ pub fn devm_enable_optional(dev: &Device<Bound>, name: &CStr) -> Result { /// /// ``` /// # use kernel::prelude::*; -/// # use kernel::c_str; /// # use kernel::device::Device; /// # use kernel::regulator::{Voltage, Regulator, Disabled, Enabled}; /// fn enable(dev: &Device, min_voltage: Voltage, max_voltage: Voltage) -> Result { /// // Obtain a reference to a (fictitious) regulator. 
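The `c_str!("vcc")` to `c"vcc"` conversions in these doc examples rely on Rust's C-string literals (stable since 1.77), which already have type `&'static core::ffi::CStr`, so no conversion macro is needed. A stand-alone illustration:

```
use core::ffi::CStr;

const NAME: &CStr = c"vcc";

fn main() {
    assert_eq!(NAME.to_bytes(), b"vcc");
    // The literal carries its NUL terminator.
    assert_eq!(NAME.to_bytes_with_nul(), b"vcc\0");
}
```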
-/// let regulator: Regulator<Disabled> = Regulator::<Disabled>::get(dev, c_str!("vcc"))?; +/// let regulator: Regulator<Disabled> = Regulator::<Disabled>::get(dev, c"vcc")?; /// /// // The voltage can be set before enabling the regulator if needed, e.g.: /// regulator.set_voltage(min_voltage, max_voltage)?; @@ -166,12 +165,11 @@ pub fn devm_enable_optional(dev: &Device<Bound>, name: &CStr) -> Result { /// /// ``` /// # use kernel::prelude::*; -/// # use kernel::c_str; /// # use kernel::device::Device; /// # use kernel::regulator::{Voltage, Regulator, Enabled}; /// fn enable(dev: &Device) -> Result { /// // Obtain a reference to a (fictitious) regulator and enable it. -/// let regulator: Regulator<Enabled> = Regulator::<Enabled>::get(dev, c_str!("vcc"))?; +/// let regulator: Regulator<Enabled> = Regulator::<Enabled>::get(dev, c"vcc")?; /// /// // Dropping an enabled regulator will disable it. The refcount will be /// // decremented. @@ -193,13 +191,12 @@ pub fn devm_enable_optional(dev: &Device<Bound>, name: &CStr) -> Result { /// /// ``` /// # use kernel::prelude::*; -/// # use kernel::c_str; /// # use kernel::device::{Bound, Device}; /// # use kernel::regulator; /// fn enable(dev: &Device<Bound>) -> Result { /// // Obtain a reference to a (fictitious) regulator and enable it. This /// // call only returns whether the operation succeeded. -/// regulator::devm_enable(dev, c_str!("vcc"))?; +/// regulator::devm_enable(dev, c"vcc")?; /// /// // The regulator will be disabled and put when `dev` is unbound. /// Ok(()) diff --git a/rust/kernel/safety.rs b/rust/kernel/safety.rs new file mode 100644 index 000000000000..c1c6bd0fa2cc --- /dev/null +++ b/rust/kernel/safety.rs @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Safety related APIs. + +/// Checks that a precondition of an unsafe function is followed. +/// +/// The check is enabled at runtime if debug assertions (`CONFIG_RUST_DEBUG_ASSERTIONS`) +/// are enabled. Otherwise, this macro is a no-op. +/// +/// # Examples +/// +/// ```no_run +/// use kernel::unsafe_precondition_assert; +/// +/// struct RawBuffer<T: Copy, const N: usize> { +/// data: [T; N], +/// } +/// +/// impl<T: Copy, const N: usize> RawBuffer<T, N> { +/// /// # Safety +/// /// +/// /// The caller must ensure that `index` is less than `N`. +/// unsafe fn set_unchecked(&mut self, index: usize, value: T) { +/// unsafe_precondition_assert!( +/// index < N, +/// "RawBuffer::set_unchecked() requires index ({index}) < N ({N})" +/// ); +/// +/// // SAFETY: By the safety requirements of this function, `index` is valid. +/// unsafe { +/// *self.data.get_unchecked_mut(index) = value; +/// } +/// } +/// } +/// ``` +/// +/// # Panics +/// +/// Panics if the expression is evaluated to [`false`] at runtime. +#[macro_export] +macro_rules! unsafe_precondition_assert { + ($cond:expr $(,)?) 
=> { + $crate::unsafe_precondition_assert!(@inner $cond, ::core::stringify!($cond)) + }; + + ($cond:expr, $($arg:tt)+) => { + $crate::unsafe_precondition_assert!(@inner $cond, $crate::prelude::fmt!($($arg)+)) + }; + + (@inner $cond:expr, $msg:expr) => { + ::core::debug_assert!($cond, "unsafe precondition violated: {}", $msg) + }; +} diff --git a/rust/kernel/scatterlist.rs b/rust/kernel/scatterlist.rs index 196fdb9a75e7..b83c468b5c63 100644 --- a/rust/kernel/scatterlist.rs +++ b/rust/kernel/scatterlist.rs @@ -38,7 +38,8 @@ use crate::{ io::ResourceSize, page, prelude::*, - types::{ARef, Opaque}, + sync::aref::ARef, + types::Opaque, }; use core::{ops::Deref, ptr::NonNull}; diff --git a/rust/kernel/seq_file.rs b/rust/kernel/seq_file.rs index 855e533813a6..518265558d66 100644 --- a/rust/kernel/seq_file.rs +++ b/rust/kernel/seq_file.rs @@ -4,7 +4,7 @@ //! //! C header: [`include/linux/seq_file.h`](srctree/include/linux/seq_file.h) -use crate::{bindings, c_str, fmt, str::CStrExt as _, types::NotThreadSafe, types::Opaque}; +use crate::{bindings, fmt, str::CStrExt as _, types::NotThreadSafe, types::Opaque}; /// A utility for generating the contents of a seq file. #[repr(transparent)] @@ -36,7 +36,7 @@ impl SeqFile { unsafe { bindings::seq_printf( self.inner.get(), - c_str!("%pA").as_char_ptr(), + c"%pA".as_char_ptr(), core::ptr::from_ref(&args).cast::<crate::ffi::c_void>(), ); } diff --git a/rust/kernel/soc.rs b/rust/kernel/soc.rs new file mode 100644 index 000000000000..0d6a36c83cb6 --- /dev/null +++ b/rust/kernel/soc.rs @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: GPL-2.0 + +// Copyright (C) 2025 Google LLC. + +//! SoC Driver Abstraction. +//! +//! C header: [`include/linux/sys_soc.h`](srctree/include/linux/sys_soc.h) + +use crate::{ + bindings, + error, + prelude::*, + str::CString, + types::Opaque, // +}; +use core::ptr::NonNull; + +/// Attributes for a SoC device. +/// +/// These are both exported to userspace under /sys/devices/socX and provided to other drivers to +/// match against via `soc_device_match` (not yet available in Rust) to enable quirks or +/// device-specific support where necessary. +/// +/// All fields are freeform - they have no specific formatting, just defined meanings. +/// For example, the [`machine`](`Attributes::machine`) field could be "DB8500" or +/// "Qualcomm Technologies, Inc. SM8560 HDK", but regardless it should identify a board or product. +pub struct Attributes { + /// Should generally be a board ID or product ID. Examples + /// include DB8500 (ST-Ericsson) or "Qualcomm Technologies, inc. SM8560 HDK". + /// + /// If this field is not populated, the SoC infrastructure will try to populate it from + /// `/model` in the device tree. + pub machine: Option<CString>, + /// The broader class this SoC belongs to. Examples include ux500 + /// (for DB8500) or Snapdragon (for SM8650). + /// + /// On chips with ARM firmware supporting SMCCC v1.2+, this may be a JEDEC JEP106 manufacturer + /// identification. + pub family: Option<CString>, + /// The manufacturing revision of the part. Frequently this is MAJOR.MINOR, but not always. + pub revision: Option<CString>, + /// Serial Number - uniquely identifies a specific SoC. If present, should be unique (buying a + /// replacement part should change it if present). This field cannot be matched on and is + /// solely present to export through /sys. 
+ pub serial_number: Option<CString>, + /// SoC ID - identifies a specific SoC kind in question, sometimes more specifically than + /// `machine` if the same SoC is used in multiple products. Some devices use this to specify a + /// SoC name, e.g. "I.MX??", and others just print an ID number (e.g. Tegra and Qualcomm). + /// + /// On chips with ARM firmware supporting SMCCC v1.2+, this may be a JEDEC JEP106 manufacturer + /// identification (the family value) followed by a colon and then a 4-digit ID value. + pub soc_id: Option<CString>, +} + +struct BuiltAttributes { + // While `inner` has pointers to `_backing`, it is to the interior of the `CStrings`, not + // `backing` itself, so it does not need to be pinned. + _backing: Attributes, + // `Opaque` makes us `!Unpin`, as the registration holds a pointer to `inner` when used. + inner: Opaque<bindings::soc_device_attribute>, +} + +fn cstring_to_c(mcs: &Option<CString>) -> *const kernel::ffi::c_char { + mcs.as_ref() + .map(|cs| cs.as_char_ptr()) + .unwrap_or(core::ptr::null()) +} + +impl BuiltAttributes { + fn as_mut_ptr(&self) -> *mut bindings::soc_device_attribute { + self.inner.get() + } +} + +impl Attributes { + fn build(self) -> BuiltAttributes { + BuiltAttributes { + inner: Opaque::new(bindings::soc_device_attribute { + machine: cstring_to_c(&self.machine), + family: cstring_to_c(&self.family), + revision: cstring_to_c(&self.revision), + serial_number: cstring_to_c(&self.serial_number), + soc_id: cstring_to_c(&self.soc_id), + data: core::ptr::null(), + custom_attr_group: core::ptr::null(), + }), + _backing: self, + } + } +} + +#[pin_data(PinnedDrop)] +/// Registration handle for your soc_dev. If you let it go out of scope, your soc_dev will be +/// unregistered. +pub struct Registration { + #[pin] + attr: BuiltAttributes, + soc_dev: NonNull<bindings::soc_device>, +} + +// SAFETY: We provide no operations through `&Registration`. +unsafe impl Sync for Registration {} + +// SAFETY: All pointers are normal allocations, not thread-specific. +unsafe impl Send for Registration {} + +#[pinned_drop] +impl PinnedDrop for Registration { + fn drop(self: Pin<&mut Self>) { + // SAFETY: Device always contains a live pointer to a soc_device that can be unregistered + unsafe { bindings::soc_device_unregister(self.soc_dev.as_ptr()) } + } +} + +impl Registration { + /// Register a new SoC device + pub fn new(attr: Attributes) -> impl PinInit<Self, Error> { + try_pin_init!(Self { + attr: attr.build(), + soc_dev: { + // SAFETY: + // * The struct provided through attr is backed by pinned data next to it, + // so as long as attr lives, the strings pointed to by the struct will too. + // * `attr` is pinned, so the pinned data won't move. + // * If it returns a device, and so others may try to read this data, by + // caller invariant, `attr` won't be released until the device is. + let raw_soc = error::from_err_ptr(unsafe { + bindings::soc_device_register(attr.as_mut_ptr()) + })?; + + NonNull::new(raw_soc).ok_or(EINVAL)? + }, + }? Error) + } +} diff --git a/rust/kernel/sync.rs b/rust/kernel/sync.rs index 5df87e2bd212..993dbf2caa0e 100644 --- a/rust/kernel/sync.rs +++ b/rust/kernel/sync.rs @@ -32,7 +32,9 @@ pub use locked_by::LockedBy; pub use refcount::Refcount; pub use set_once::SetOnce; -/// Represents a lockdep class. It's a wrapper around C's `lock_class_key`. +/// Represents a lockdep class. +/// +/// Wraps the kernel's `struct lock_class_key`. 
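Illustrative use of the new `soc::Registration`, in kernel doctest style; the attribute values and the surrounding driver context are made up, and the exact error types are assumed to convert through `?` as usual:

```
use kernel::{prelude::*, soc, str::CString};

fn register_soc() -> Result<Pin<KBox<soc::Registration>>> {
    let attrs = soc::Attributes {
        machine: Some(CString::try_from_fmt(fmt!("Demo Board"))?),
        family: None,
        revision: None,
        serial_number: None,
        soc_id: None,
    };

    // Keep the returned registration alive for as long as the SoC device
    // should stay registered; dropping it unregisters the device.
    KBox::pin_init(soc::Registration::new(attrs), GFP_KERNEL)
}
```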
#[repr(transparent)] #[pin_data(PinnedDrop)] pub struct LockClassKey { @@ -40,20 +42,42 @@ pub struct LockClassKey { inner: Opaque<bindings::lock_class_key>, } +// SAFETY: Unregistering a lock class key from a different thread than where it was registered is +// allowed. +unsafe impl Send for LockClassKey {} + // SAFETY: `bindings::lock_class_key` is designed to be used concurrently from multiple threads and // provides its own synchronization. unsafe impl Sync for LockClassKey {} impl LockClassKey { - /// Initializes a dynamically allocated lock class key. In the common case of using a - /// statically allocated lock class key, the static_lock_class! macro should be used instead. + /// Initializes a statically allocated lock class key. + /// + /// This is usually used indirectly through the [`static_lock_class!`] macro. See its + /// documentation for more information. + /// + /// # Safety + /// + /// * Before using the returned value, it must be pinned in a static memory location. + /// * The destructor must never run on the returned `LockClassKey`. + pub const unsafe fn new_static() -> Self { + LockClassKey { + inner: Opaque::uninit(), + } + } + + /// Initializes a dynamically allocated lock class key. + /// + /// In the common case of using a statically allocated lock class key, the + /// [`static_lock_class!`] macro should be used instead. /// /// # Examples + /// /// ``` - /// # use kernel::alloc::KBox; - /// # use kernel::types::ForeignOwnable; - /// # use kernel::sync::{LockClassKey, SpinLock}; - /// # use pin_init::stack_pin_init; + /// use kernel::alloc::KBox; + /// use kernel::types::ForeignOwnable; + /// use kernel::sync::{LockClassKey, SpinLock}; + /// use pin_init::stack_pin_init; /// /// let key = KBox::pin_init(LockClassKey::new_dynamic(), GFP_KERNEL)?; /// let key_ptr = key.into_foreign(); @@ -71,7 +95,6 @@ impl LockClassKey { /// // SAFETY: We dropped `num`, the only use of the key, so the result of the previous /// // `borrow` has also been dropped. Thus, it's safe to use from_foreign. /// unsafe { drop(<Pin<KBox<LockClassKey>> as ForeignOwnable>::from_foreign(key_ptr)) }; - /// /// # Ok::<(), Error>(()) /// ``` pub fn new_dynamic() -> impl PinInit<Self> { @@ -81,7 +104,10 @@ impl LockClassKey { }) } - pub(crate) fn as_ptr(&self) -> *mut bindings::lock_class_key { + /// Returns a raw pointer to the inner C struct. + /// + /// It is up to the caller to use the raw pointer correctly. + pub fn as_ptr(&self) -> *mut bindings::lock_class_key { self.inner.get() } } @@ -89,27 +115,38 @@ impl LockClassKey { #[pinned_drop] impl PinnedDrop for LockClassKey { fn drop(self: Pin<&mut Self>) { - // SAFETY: self.as_ptr was registered with lockdep and self is pinned, so the address - // hasn't changed. Thus, it's safe to pass to unregister. + // SAFETY: `self.as_ptr()` was registered with lockdep and `self` is pinned, so the address + // hasn't changed. Thus, it's safe to pass it to unregister. unsafe { bindings::lockdep_unregister_key(self.as_ptr()) } } } /// Defines a new static lock class and returns a pointer to it. -#[doc(hidden)] +/// +/// # Examples +/// +/// ``` +/// use kernel::sync::{static_lock_class, Arc, SpinLock}; +/// +/// fn new_locked_int() -> Result<Arc<SpinLock<u32>>> { +/// Arc::pin_init(SpinLock::new( +/// 42, +/// c"new_locked_int", +/// static_lock_class!(), +/// ), GFP_KERNEL) +/// } +/// ``` #[macro_export] macro_rules! 
static_lock_class { () => {{ static CLASS: $crate::sync::LockClassKey = - // Lockdep expects uninitialized memory when it's handed a statically allocated `struct - // lock_class_key`. - // - // SAFETY: `LockClassKey` transparently wraps `Opaque` which permits uninitialized - // memory. - unsafe { ::core::mem::MaybeUninit::uninit().assume_init() }; + // SAFETY: The returned `LockClassKey` is stored in static memory and we pin it. Drop + // never runs on a static global. + unsafe { $crate::sync::LockClassKey::new_static() }; $crate::prelude::Pin::static_ref(&CLASS) }}; } +pub use static_lock_class; /// Returns the given string, if one is provided, otherwise generates one based on the source code /// location. diff --git a/rust/kernel/sync/arc.rs b/rust/kernel/sync/arc.rs index 289f77abf415..921e19333b89 100644 --- a/rust/kernel/sync/arc.rs +++ b/rust/kernel/sync/arc.rs @@ -240,6 +240,9 @@ impl<T> Arc<T> { // `Arc` object. Ok(unsafe { Self::from_inner(inner) }) } + + /// The offset that the value is stored at. + pub const DATA_OFFSET: usize = core::mem::offset_of!(ArcInner<T>, data); } impl<T: ?Sized> Arc<T> { diff --git a/rust/kernel/sync/aref.rs b/rust/kernel/sync/aref.rs index 0d24a0432015..0616c0353c2b 100644 --- a/rust/kernel/sync/aref.rs +++ b/rust/kernel/sync/aref.rs @@ -83,6 +83,9 @@ unsafe impl<T: AlwaysRefCounted + Sync + Send> Send for ARef<T> {} // example, when the reference count reaches zero and `T` is dropped. unsafe impl<T: AlwaysRefCounted + Sync + Send> Sync for ARef<T> {} +// Even if `T` is pinned, pointers to `T` can still move. +impl<T: AlwaysRefCounted> Unpin for ARef<T> {} + impl<T: AlwaysRefCounted> ARef<T> { /// Creates a new instance of [`ARef`]. /// diff --git a/rust/kernel/sync/atomic/internal.rs b/rust/kernel/sync/atomic/internal.rs index 6fdd8e59f45b..0dac58bca2b3 100644 --- a/rust/kernel/sync/atomic/internal.rs +++ b/rust/kernel/sync/atomic/internal.rs @@ -13,17 +13,22 @@ mod private { pub trait Sealed {} } -// `i32` and `i64` are only supported atomic implementations. +// The C side supports atomic primitives only for `i32` and `i64` (`atomic_t` and `atomic64_t`), +// while the Rust side also layers provides atomic support for `i8` and `i16` +// on top of lower-level C primitives. +impl private::Sealed for i8 {} +impl private::Sealed for i16 {} impl private::Sealed for i32 {} impl private::Sealed for i64 {} /// A marker trait for types that implement atomic operations with C side primitives. /// -/// This trait is sealed, and only types that have directly mapping to the C side atomics should -/// impl this: +/// This trait is sealed, and only types that map directly to the C side atomics +/// or can be implemented with lower-level C primitives are allowed to implement this: /// -/// - `i32` maps to `atomic_t`. -/// - `i64` maps to `atomic64_t`. +/// - `i8` and `i16` are implemented with lower-level C primitives. +/// - `i32` map to `atomic_t` +/// - `i64` map to `atomic64_t` pub trait AtomicImpl: Sized + Send + Copy + private::Sealed { /// The type of the delta in arithmetic or logical operations. /// @@ -32,6 +37,20 @@ pub trait AtomicImpl: Sized + Send + Copy + private::Sealed { type Delta; } +// The current helpers of load/store uses `{WRITE,READ}_ONCE()` hence the atomicity is only +// guaranteed against read-modify-write operations if the architecture supports native atomic RmW. 
+#[cfg(CONFIG_ARCH_SUPPORTS_ATOMIC_RMW)] +impl AtomicImpl for i8 { + type Delta = Self; +} + +// The current helpers of load/store uses `{WRITE,READ}_ONCE()` hence the atomicity is only +// guaranteed against read-modify-write operations if the architecture supports native atomic RmW. +#[cfg(CONFIG_ARCH_SUPPORTS_ATOMIC_RMW)] +impl AtomicImpl for i16 { + type Delta = Self; +} + // `atomic_t` implements atomic operations on `i32`. impl AtomicImpl for i32 { type Delta = Self; @@ -156,16 +175,17 @@ macro_rules! impl_atomic_method { } } -// Delcares $ops trait with methods and implements the trait for `i32` and `i64`. -macro_rules! declare_and_impl_atomic_methods { - ($(#[$attr:meta])* $pub:vis trait $ops:ident { - $( - $(#[doc=$doc:expr])* - fn $func:ident [$($variant:ident),*]($($arg_sig:tt)*) $( -> $ret:ty)? { - $unsafe:tt { bindings::#call($($arg:tt)*) } - } - )* - }) => { +macro_rules! declare_atomic_ops_trait { + ( + $(#[$attr:meta])* $pub:vis trait $ops:ident { + $( + $(#[doc=$doc:expr])* + fn $func:ident [$($variant:ident),*]($($arg_sig:tt)*) $( -> $ret:ty)? { + $unsafe:tt { bindings::#call($($arg:tt)*) } + } + )* + } + ) => { $(#[$attr])* $pub trait $ops: AtomicImpl { $( @@ -175,21 +195,25 @@ macro_rules! declare_and_impl_atomic_methods { ); )* } + } +} - impl $ops for i32 { +macro_rules! impl_atomic_ops_for_one { + ( + $ty:ty => $ctype:ident, + $(#[$attr:meta])* $pub:vis trait $ops:ident { $( - impl_atomic_method!( - (atomic) $func[$($variant)*]($($arg_sig)*) $(-> $ret)? { - $unsafe { call($($arg)*) } - } - ); + $(#[doc=$doc:expr])* + fn $func:ident [$($variant:ident),*]($($arg_sig:tt)*) $( -> $ret:ty)? { + $unsafe:tt { bindings::#call($($arg:tt)*) } + } )* } - - impl $ops for i64 { + ) => { + impl $ops for $ty { $( impl_atomic_method!( - (atomic64) $func[$($variant)*]($($arg_sig)*) $(-> $ret)? { + ($ctype) $func[$($variant)*]($($arg_sig)*) $(-> $ret)? { $unsafe { call($($arg)*) } } ); @@ -198,7 +222,47 @@ macro_rules! declare_and_impl_atomic_methods { } } +// Declares $ops trait with methods and implements the trait. +macro_rules! declare_and_impl_atomic_methods { + ( + [ $($map:tt)* ] + $(#[$attr:meta])* $pub:vis trait $ops:ident { $($body:tt)* } + ) => { + declare_and_impl_atomic_methods!( + @with_ops_def + [ $($map)* ] + ( $(#[$attr])* $pub trait $ops { $($body)* } ) + ); + }; + + (@with_ops_def [ $($map:tt)* ] ( $($ops_def:tt)* )) => { + declare_atomic_ops_trait!( $($ops_def)* ); + + declare_and_impl_atomic_methods!( + @munch + [ $($map)* ] + ( $($ops_def)* ) + ); + }; + + (@munch [] ( $($ops_def:tt)* )) => {}; + + (@munch [ $ty:ty => $ctype:ident $(, $($rest:tt)*)? ] ( $($ops_def:tt)* )) => { + impl_atomic_ops_for_one!( + $ty => $ctype, + $($ops_def)* + ); + + declare_and_impl_atomic_methods!( + @munch + [ $($($rest)*)? ] + ( $($ops_def)* ) + ); + }; +} + declare_and_impl_atomic_methods!( + [ i8 => atomic_i8, i16 => atomic_i16, i32 => atomic, i64 => atomic64 ] /// Basic atomic operations pub trait AtomicBasicOps { /// Atomic read (load). @@ -216,6 +280,7 @@ declare_and_impl_atomic_methods!( ); declare_and_impl_atomic_methods!( + [ i8 => atomic_i8, i16 => atomic_i16, i32 => atomic, i64 => atomic64 ] /// Exchange and compare-and-exchange atomic operations pub trait AtomicExchangeOps { /// Atomic exchange. @@ -243,6 +308,7 @@ declare_and_impl_atomic_methods!( ); declare_and_impl_atomic_methods!( + [ i32 => atomic, i64 => atomic64 ] /// Atomic arithmetic operations pub trait AtomicArithmeticOps { /// Atomic add (wrapping). 
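The macro refactor above separates declaring the ops trait from implementing it per type, and walks a `[ type => backend ]` list recursively. A minimal stand-alone sketch of that list-munching shape (toy trait, not the kernel macro):

```
trait Ops {
    fn backend() -> &'static str;
}

macro_rules! impl_ops_for {
    ([]) => {};
    ([ $ty:ty => $backend:literal $(, $($rest:tt)*)? ]) => {
        impl Ops for $ty {
            fn backend() -> &'static str {
                $backend
            }
        }
        // Recurse on the remainder of the list.
        impl_ops_for!([ $($($rest)*)? ]);
    };
}

impl_ops_for!([ i8 => "atomic_i8", i16 => "atomic_i16", i32 => "atomic", i64 => "atomic64" ]);

fn main() {
    assert_eq!(<i32 as Ops>::backend(), "atomic");
    assert_eq!(<i8 as Ops>::backend(), "atomic_i8");
}
```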
diff --git a/rust/kernel/sync/atomic/predefine.rs b/rust/kernel/sync/atomic/predefine.rs index 45a17985cda4..67a0406d3ea4 100644 --- a/rust/kernel/sync/atomic/predefine.rs +++ b/rust/kernel/sync/atomic/predefine.rs @@ -5,6 +5,29 @@ use crate::static_assert; use core::mem::{align_of, size_of}; +// Ensure size and alignment requirements are checked. +static_assert!(size_of::<bool>() == size_of::<i8>()); +static_assert!(align_of::<bool>() == align_of::<i8>()); + +// SAFETY: `bool` has the same size and alignment as `i8`, and Rust guarantees that `bool` has +// only two valid bit patterns: 0 (false) and 1 (true). Those are valid `i8` values, so `bool` is +// round-trip transmutable to `i8`. +unsafe impl super::AtomicType for bool { + type Repr = i8; +} + +// SAFETY: `i8` has the same size and alignment with itself, and is round-trip transmutable to +// itself. +unsafe impl super::AtomicType for i8 { + type Repr = i8; +} + +// SAFETY: `i16` has the same size and alignment with itself, and is round-trip transmutable to +// itself. +unsafe impl super::AtomicType for i16 { + type Repr = i16; +} + // SAFETY: `i32` has the same size and alignment with itself, and is round-trip transmutable to // itself. unsafe impl super::AtomicType for i32 { @@ -35,12 +58,23 @@ unsafe impl super::AtomicAdd<i64> for i64 { // as `isize` and `usize`, and `isize` and `usize` are always bi-directional transmutable to // `isize_atomic_repr`, which also always implements `AtomicImpl`. #[allow(non_camel_case_types)] +#[cfg(not(testlib))] #[cfg(not(CONFIG_64BIT))] type isize_atomic_repr = i32; #[allow(non_camel_case_types)] +#[cfg(not(testlib))] #[cfg(CONFIG_64BIT)] type isize_atomic_repr = i64; +#[allow(non_camel_case_types)] +#[cfg(testlib)] +#[cfg(target_pointer_width = "32")] +type isize_atomic_repr = i32; +#[allow(non_camel_case_types)] +#[cfg(testlib)] +#[cfg(target_pointer_width = "64")] +type isize_atomic_repr = i64; + // Ensure size and alignment requirements are checked. 
static_assert!(size_of::<isize>() == size_of::<isize_atomic_repr>()); static_assert!(align_of::<isize>() == align_of::<isize_atomic_repr>()); @@ -118,7 +152,7 @@ mod tests { #[test] fn atomic_basic_tests() { - for_each_type!(42 in [i32, i64, u32, u64, isize, usize] |v| { + for_each_type!(42 in [i8, i16, i32, i64, u32, u64, isize, usize] |v| { let x = Atomic::new(v); assert_eq!(v, x.load(Relaxed)); @@ -126,8 +160,18 @@ mod tests { } #[test] + fn atomic_acquire_release_tests() { + for_each_type!(42 in [i8, i16, i32, i64, u32, u64, isize, usize] |v| { + let x = Atomic::new(0); + + x.store(v, Release); + assert_eq!(v, x.load(Acquire)); + }); + } + + #[test] fn atomic_xchg_tests() { - for_each_type!(42 in [i32, i64, u32, u64, isize, usize] |v| { + for_each_type!(42 in [i8, i16, i32, i64, u32, u64, isize, usize] |v| { let x = Atomic::new(v); let old = v; @@ -140,7 +184,7 @@ mod tests { #[test] fn atomic_cmpxchg_tests() { - for_each_type!(42 in [i32, i64, u32, u64, isize, usize] |v| { + for_each_type!(42 in [i8, i16, i32, i64, u32, u64, isize, usize] |v| { let x = Atomic::new(v); let old = v; @@ -166,4 +210,20 @@ mod tests { assert_eq!(v + 25, x.load(Relaxed)); }); } + + #[test] + fn atomic_bool_tests() { + let x = Atomic::new(false); + + assert_eq!(false, x.load(Relaxed)); + x.store(true, Relaxed); + assert_eq!(true, x.load(Relaxed)); + + assert_eq!(true, x.xchg(false, Relaxed)); + assert_eq!(false, x.load(Relaxed)); + + assert_eq!(Err(false), x.cmpxchg(true, true, Relaxed)); + assert_eq!(false, x.load(Relaxed)); + assert_eq!(Ok(false), x.cmpxchg(false, true, Full)); + } } diff --git a/rust/kernel/sync/lock.rs b/rust/kernel/sync/lock.rs index 46a57d1fc309..10b6b5e9b024 100644 --- a/rust/kernel/sync/lock.rs +++ b/rust/kernel/sync/lock.rs @@ -156,6 +156,7 @@ impl<B: Backend> Lock<(), B> { /// the whole lifetime of `'a`. /// /// [`State`]: Backend::State + #[inline] pub unsafe fn from_raw<'a>(ptr: *mut B::State) -> &'a Self { // SAFETY: // - By the safety contract `ptr` must point to a valid initialised instance of `B::State` @@ -169,6 +170,7 @@ impl<B: Backend> Lock<(), B> { impl<T: ?Sized, B: Backend> Lock<T, B> { /// Acquires the lock and gives the caller access to the data protected by it. + #[inline] pub fn lock(&self) -> Guard<'_, T, B> { // SAFETY: The constructor of the type calls `init`, so the existence of the object proves // that `init` was called. @@ -182,6 +184,7 @@ impl<T: ?Sized, B: Backend> Lock<T, B> { /// Returns a guard that can be used to access the data protected by the lock if successful. // `Option<T>` is not `#[must_use]` even if `T` is, thus the attribute is needed here. #[must_use = "if unused, the lock will be immediately unlocked"] + #[inline] pub fn try_lock(&self) -> Option<Guard<'_, T, B>> { // SAFETY: The constructor of the type calls `init`, so the existence of the object proves // that `init` was called. @@ -275,6 +278,7 @@ impl<'a, T: ?Sized, B: Backend> Guard<'a, T, B> { impl<T: ?Sized, B: Backend> core::ops::Deref for Guard<'_, T, B> { type Target = T; + #[inline] fn deref(&self) -> &Self::Target { // SAFETY: The caller owns the lock, so it is safe to deref the protected data. unsafe { &*self.lock.data.get() } @@ -285,6 +289,7 @@ impl<T: ?Sized, B: Backend> core::ops::DerefMut for Guard<'_, T, B> where T: Unpin, { + #[inline] fn deref_mut(&mut self) -> &mut Self::Target { // SAFETY: The caller owns the lock, so it is safe to deref the protected data. 
unsafe { &mut *self.lock.data.get() } @@ -292,6 +297,7 @@ where } impl<T: ?Sized, B: Backend> Drop for Guard<'_, T, B> { + #[inline] fn drop(&mut self) { // SAFETY: The caller owns the lock, so it is safe to unlock it. unsafe { B::unlock(self.lock.state.get(), &self.state) }; @@ -304,6 +310,7 @@ impl<'a, T: ?Sized, B: Backend> Guard<'a, T, B> { /// # Safety /// /// The caller must ensure that it owns the lock. + #[inline] pub unsafe fn new(lock: &'a Lock<T, B>, state: B::GuardState) -> Self { // SAFETY: The caller can only hold the lock if `Backend::init` has already been called. unsafe { B::assert_is_held(lock.state.get()) }; diff --git a/rust/kernel/sync/lock/global.rs b/rust/kernel/sync/lock/global.rs index eab48108a4ae..aecbdc34738f 100644 --- a/rust/kernel/sync/lock/global.rs +++ b/rust/kernel/sync/lock/global.rs @@ -77,6 +77,7 @@ impl<B: GlobalLockBackend> GlobalLock<B> { } /// Lock this global lock. + #[inline] pub fn lock(&'static self) -> GlobalGuard<B> { GlobalGuard { inner: self.inner.lock(), @@ -84,6 +85,7 @@ impl<B: GlobalLockBackend> GlobalLock<B> { } /// Try to lock this global lock. + #[inline] pub fn try_lock(&'static self) -> Option<GlobalGuard<B>> { Some(GlobalGuard { inner: self.inner.try_lock()?, diff --git a/rust/kernel/sync/lock/mutex.rs b/rust/kernel/sync/lock/mutex.rs index 581cee7ab842..cda0203efefb 100644 --- a/rust/kernel/sync/lock/mutex.rs +++ b/rust/kernel/sync/lock/mutex.rs @@ -102,6 +102,7 @@ unsafe impl super::Backend for MutexBackend { type State = bindings::mutex; type GuardState = (); + #[inline] unsafe fn init( ptr: *mut Self::State, name: *const crate::ffi::c_char, @@ -112,18 +113,21 @@ unsafe impl super::Backend for MutexBackend { unsafe { bindings::__mutex_init(ptr, name, key) } } + #[inline] unsafe fn lock(ptr: *mut Self::State) -> Self::GuardState { // SAFETY: The safety requirements of this function ensure that `ptr` points to valid // memory, and that it has been initialised before. unsafe { bindings::mutex_lock(ptr) }; } + #[inline] unsafe fn unlock(ptr: *mut Self::State, _guard_state: &Self::GuardState) { // SAFETY: The safety requirements of this function ensure that `ptr` is valid and that the // caller is the owner of the mutex. unsafe { bindings::mutex_unlock(ptr) }; } + #[inline] unsafe fn try_lock(ptr: *mut Self::State) -> Option<Self::GuardState> { // SAFETY: The `ptr` pointer is guaranteed to be valid and initialized before use. let result = unsafe { bindings::mutex_trylock(ptr) }; @@ -135,6 +139,7 @@ unsafe impl super::Backend for MutexBackend { } } + #[inline] unsafe fn assert_is_held(ptr: *mut Self::State) { // SAFETY: The `ptr` pointer is guaranteed to be valid and initialized before use. unsafe { bindings::mutex_assert_is_held(ptr) } diff --git a/rust/kernel/sync/lock/spinlock.rs b/rust/kernel/sync/lock/spinlock.rs index d7be38ccbdc7..ef76fa07ca3a 100644 --- a/rust/kernel/sync/lock/spinlock.rs +++ b/rust/kernel/sync/lock/spinlock.rs @@ -101,6 +101,7 @@ unsafe impl super::Backend for SpinLockBackend { type State = bindings::spinlock_t; type GuardState = (); + #[inline] unsafe fn init( ptr: *mut Self::State, name: *const crate::ffi::c_char, @@ -111,18 +112,21 @@ unsafe impl super::Backend for SpinLockBackend { unsafe { bindings::__spin_lock_init(ptr, name, key) } } + #[inline] unsafe fn lock(ptr: *mut Self::State) -> Self::GuardState { // SAFETY: The safety requirements of this function ensure that `ptr` points to valid // memory, and that it has been initialised before. 
unsafe { bindings::spin_lock(ptr) } } + #[inline] unsafe fn unlock(ptr: *mut Self::State, _guard_state: &Self::GuardState) { // SAFETY: The safety requirements of this function ensure that `ptr` is valid and that the // caller is the owner of the spinlock. unsafe { bindings::spin_unlock(ptr) } } + #[inline] unsafe fn try_lock(ptr: *mut Self::State) -> Option<Self::GuardState> { // SAFETY: The `ptr` pointer is guaranteed to be valid and initialized before use. let result = unsafe { bindings::spin_trylock(ptr) }; @@ -134,6 +138,7 @@ unsafe impl super::Backend for SpinLockBackend { } } + #[inline] unsafe fn assert_is_held(ptr: *mut Self::State) { // SAFETY: The `ptr` pointer is guaranteed to be valid and initialized before use. unsafe { bindings::spin_assert_is_held(ptr) } diff --git a/rust/kernel/sync/refcount.rs b/rust/kernel/sync/refcount.rs index 19236a5bccde..6c7ae8b05a0b 100644 --- a/rust/kernel/sync/refcount.rs +++ b/rust/kernel/sync/refcount.rs @@ -23,7 +23,8 @@ impl Refcount { /// Construct a new [`Refcount`] from an initial value. /// /// The initial value should be non-saturated. - #[inline] + // Always inline to optimize out error path of `build_assert`. + #[inline(always)] pub fn new(value: i32) -> Self { build_assert!(value >= 0, "initial value saturated"); // SAFETY: There are no safety requirements for this FFI call. diff --git a/rust/kernel/sync/set_once.rs b/rust/kernel/sync/set_once.rs index bdba601807d8..139cef05e935 100644 --- a/rust/kernel/sync/set_once.rs +++ b/rust/kernel/sync/set_once.rs @@ -123,3 +123,11 @@ impl<T> Drop for SetOnce<T> { } } } + +// SAFETY: `SetOnce` can be transferred across thread boundaries iff the data it contains can. +unsafe impl<T: Send> Send for SetOnce<T> {} + +// SAFETY: `SetOnce` synchronises access to the inner value via atomic operations, +// so shared references are safe when `T: Sync`. Since the inner `T` may be dropped +// on any thread, we also require `T: Send`. +unsafe impl<T: Send + Sync> Sync for SetOnce<T> {} diff --git a/rust/kernel/task.rs b/rust/kernel/task.rs index 49fad6de0674..cc907fb531bc 100644 --- a/rust/kernel/task.rs +++ b/rust/kernel/task.rs @@ -204,18 +204,6 @@ impl Task { self.0.get() } - /// Returns the group leader of the given task. - pub fn group_leader(&self) -> &Task { - // SAFETY: The group leader of a task never changes after initialization, so reading this - // field is not a data race. - let ptr = unsafe { *ptr::addr_of!((*self.as_ptr()).group_leader) }; - - // SAFETY: The lifetime of the returned task reference is tied to the lifetime of `self`, - // and given that a task has a reference to its group leader, we know it must be valid for - // the lifetime of the returned task reference. - unsafe { &*ptr.cast() } - } - /// Returns the PID of the given task. pub fn pid(&self) -> Pid { // SAFETY: The pid of a task never changes after initialization, so reading this field is @@ -345,6 +333,18 @@ impl CurrentTask { // `release_task()` call. Some(unsafe { PidNamespace::from_ptr(active_ns) }) } + + /// Returns the group leader of the current task. + pub fn group_leader(&self) -> &Task { + // SAFETY: The group leader of a task never changes while the task is running, and `self` + // is the current task, which is guaranteed running. 
+ let ptr = unsafe { (*self.as_ptr()).group_leader }; + + // SAFETY: `current->group_leader` stays valid for at least the duration in which `current` + // is running, and the signature of this function ensures that the returned `&Task` can + // only be used while `current` is still valid, thus still running. + unsafe { &*ptr.cast() } + } } // SAFETY: The type invariants guarantee that `Task` is always refcounted. diff --git a/rust/kernel/transmute.rs b/rust/kernel/transmute.rs index be5dbf3829e2..5711580c9f9b 100644 --- a/rust/kernel/transmute.rs +++ b/rust/kernel/transmute.rs @@ -170,6 +170,10 @@ macro_rules! impl_frombytes { } impl_frombytes! { + // SAFETY: Inhabited ZSTs only have one possible bit pattern, and these two have no invariant. + (), + {<T>} core::marker::PhantomData<T>, + // SAFETY: All bit patterns are acceptable values of the types below. u8, u16, u32, u64, usize, i8, i16, i32, i64, isize, @@ -230,6 +234,10 @@ macro_rules! impl_asbytes { } impl_asbytes! { + // SAFETY: Inhabited ZSTs only have one possible bit pattern, and these two have no invariant. + (), + {<T>} core::marker::PhantomData<T>, + // SAFETY: Instances of the following types have no uninitialized portions. u8, u16, u32, u64, usize, i8, i16, i32, i64, isize, diff --git a/rust/kernel/usb.rs b/rust/kernel/usb.rs index d10b65e9fb6a..0e1b9a88f4f1 100644 --- a/rust/kernel/usb.rs +++ b/rust/kernel/usb.rs @@ -6,14 +6,23 @@ //! C header: [`include/linux/usb.h`](srctree/include/linux/usb.h) use crate::{ - bindings, device, - device_id::{RawDeviceId, RawDeviceIdIndex}, + bindings, + device, + device_id::{ + RawDeviceId, + RawDeviceIdIndex, // + }, driver, - error::{from_result, to_result, Result}, + error::{ + from_result, + to_result, // + }, prelude::*, - str::CStr, - types::{AlwaysRefCounted, Opaque}, - ThisModule, + types::{ + AlwaysRefCounted, + Opaque, // + }, + ThisModule, // }; use core::{ marker::PhantomData, @@ -27,13 +36,22 @@ use core::{ /// An adapter for the registration of USB drivers. pub struct Adapter<T: Driver>(T); -// SAFETY: A call to `unregister` for a given instance of `RegType` is guaranteed to be valid if +// SAFETY: +// - `bindings::usb_driver` is a C type declared as `repr(C)`. +// - `T` is the type of the driver's device private data. +// - `struct usb_driver` embeds a `struct device_driver`. +// - `DEVICE_DRIVER_OFFSET` is the correct byte offset to the embedded `struct device_driver`. +unsafe impl<T: Driver + 'static> driver::DriverLayout for Adapter<T> { + type DriverType = bindings::usb_driver; + type DriverData = T; + const DEVICE_DRIVER_OFFSET: usize = core::mem::offset_of!(Self::DriverType, driver); +} + +// SAFETY: A call to `unregister` for a given instance of `DriverType` is guaranteed to be valid if // a preceding call to `register` has been successful. unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { - type RegType = bindings::usb_driver; - unsafe fn register( - udrv: &Opaque<Self::RegType>, + udrv: &Opaque<Self::DriverType>, name: &'static CStr, module: &'static ThisModule, ) -> Result { @@ -45,14 +63,14 @@ unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { (*udrv.get()).id_table = T::ID_TABLE.as_ptr(); } - // SAFETY: `udrv` is guaranteed to be a valid `RegType`. + // SAFETY: `udrv` is guaranteed to be a valid `DriverType`. to_result(unsafe { bindings::usb_register_driver(udrv.get(), module.0, name.as_char_ptr()) }) } - unsafe fn unregister(udrv: &Opaque<Self::RegType>) { - // SAFETY: `udrv` is guaranteed to be a valid `RegType`. 
+ unsafe fn unregister(udrv: &Opaque<Self::DriverType>) { + // SAFETY: `udrv` is guaranteed to be a valid `DriverType`. unsafe { bindings::usb_deregister(udrv.get()) }; } } @@ -94,9 +112,9 @@ impl<T: Driver + 'static> Adapter<T> { // SAFETY: `disconnect_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called // and stored a `Pin<KBox<T>>`. - let data = unsafe { dev.drvdata_obtain::<T>() }; + let data = unsafe { dev.drvdata_borrow::<T>() }; - T::disconnect(intf, data.as_ref()); + T::disconnect(intf, data); } } diff --git a/rust/macros/concat_idents.rs b/rust/macros/concat_idents.rs index 7e4b450f3a50..47b6add378d2 100644 --- a/rust/macros/concat_idents.rs +++ b/rust/macros/concat_idents.rs @@ -1,23 +1,36 @@ // SPDX-License-Identifier: GPL-2.0 -use proc_macro::{token_stream, Ident, TokenStream, TokenTree}; +use proc_macro2::{ + Ident, + TokenStream, + TokenTree, // +}; +use syn::{ + parse::{ + Parse, + ParseStream, // + }, + Result, + Token, // +}; -use crate::helpers::expect_punct; +pub(crate) struct Input { + a: Ident, + _comma: Token![,], + b: Ident, +} -fn expect_ident(it: &mut token_stream::IntoIter) -> Ident { - if let Some(TokenTree::Ident(ident)) = it.next() { - ident - } else { - panic!("Expected Ident") +impl Parse for Input { + fn parse(input: ParseStream<'_>) -> Result<Self> { + Ok(Self { + a: input.parse()?, + _comma: input.parse()?, + b: input.parse()?, + }) } } -pub(crate) fn concat_idents(ts: TokenStream) -> TokenStream { - let mut it = ts.into_iter(); - let a = expect_ident(&mut it); - assert_eq!(expect_punct(&mut it), ','); - let b = expect_ident(&mut it); - assert!(it.next().is_none(), "only two idents can be concatenated"); +pub(crate) fn concat_idents(Input { a, b, .. }: Input) -> TokenStream { let res = Ident::new(&format!("{a}{b}"), b.span()); TokenStream::from_iter([TokenTree::Ident(res)]) } diff --git a/rust/macros/export.rs b/rust/macros/export.rs index a08f6337d5c8..6d53521f62fc 100644 --- a/rust/macros/export.rs +++ b/rust/macros/export.rs @@ -1,19 +1,16 @@ // SPDX-License-Identifier: GPL-2.0 -use crate::helpers::function_name; -use proc_macro::TokenStream; +use proc_macro2::TokenStream; +use quote::quote; /// Please see [`crate::export`] for documentation. -pub(crate) fn export(_attr: TokenStream, ts: TokenStream) -> TokenStream { - let Some(name) = function_name(ts.clone()) else { - return "::core::compile_error!(\"The #[export] attribute must be used on a function.\");" - .parse::<TokenStream>() - .unwrap(); - }; +pub(crate) fn export(f: syn::ItemFn) -> TokenStream { + let name = &f.sig.ident; - // This verifies that the function has the same signature as the declaration generated by - // bindgen. It makes use of the fact that all branches of an if/else must have the same type. - let signature_check = quote!( + quote! { + // This verifies that the function has the same signature as the declaration generated by + // bindgen. It makes use of the fact that all branches of an if/else must have the same + // type. 
const _: () = { if true { ::kernel::bindings::#name @@ -21,9 +18,8 @@ pub(crate) fn export(_attr: TokenStream, ts: TokenStream) -> TokenStream { #name }; }; - ); - let no_mangle = quote!(#[no_mangle]); - - TokenStream::from_iter([signature_check, no_mangle, ts]) + #[no_mangle] + #f + } } diff --git a/rust/macros/fmt.rs b/rust/macros/fmt.rs index 2f4b9f6e2211..ce6c7249305a 100644 --- a/rust/macros/fmt.rs +++ b/rust/macros/fmt.rs @@ -1,8 +1,10 @@ // SPDX-License-Identifier: GPL-2.0 -use proc_macro::{Ident, TokenStream, TokenTree}; use std::collections::BTreeSet; +use proc_macro2::{Ident, TokenStream, TokenTree}; +use quote::quote_spanned; + /// Please see [`crate::fmt`] for documentation. pub(crate) fn fmt(input: TokenStream) -> TokenStream { let mut input = input.into_iter(); @@ -67,7 +69,7 @@ pub(crate) fn fmt(input: TokenStream) -> TokenStream { } (None, acc) })(); - args.extend(quote_spanned!(first_span => #lhs #adapter(&#rhs))); + args.extend(quote_spanned!(first_span => #lhs #adapter(&(#rhs)))); } }; diff --git a/rust/macros/helpers.rs b/rust/macros/helpers.rs index 365d7eb499c0..37ef6a6f2c85 100644 --- a/rust/macros/helpers.rs +++ b/rust/macros/helpers.rs @@ -1,101 +1,41 @@ // SPDX-License-Identifier: GPL-2.0 -use proc_macro::{token_stream, Group, Ident, TokenStream, TokenTree}; - -pub(crate) fn try_ident(it: &mut token_stream::IntoIter) -> Option<String> { - if let Some(TokenTree::Ident(ident)) = it.next() { - Some(ident.to_string()) - } else { - None - } -} - -pub(crate) fn try_sign(it: &mut token_stream::IntoIter) -> Option<char> { - let peek = it.clone().next(); - match peek { - Some(TokenTree::Punct(punct)) if punct.as_char() == '-' => { - let _ = it.next(); - Some(punct.as_char()) +use proc_macro2::TokenStream; +use quote::ToTokens; +use syn::{ + parse::{ + Parse, + ParseStream, // + }, + Attribute, + Error, + LitStr, + Result, // +}; + +/// A string literal that is required to have ASCII value only. 
+pub(crate) struct AsciiLitStr(LitStr); + +impl Parse for AsciiLitStr { + fn parse(input: ParseStream<'_>) -> Result<Self> { + let s: LitStr = input.parse()?; + if !s.value().is_ascii() { + return Err(Error::new_spanned(s, "expected ASCII-only string literal")); } - _ => None, - } -} - -pub(crate) fn try_literal(it: &mut token_stream::IntoIter) -> Option<String> { - if let Some(TokenTree::Literal(literal)) = it.next() { - Some(literal.to_string()) - } else { - None + Ok(Self(s)) } } -pub(crate) fn try_string(it: &mut token_stream::IntoIter) -> Option<String> { - try_literal(it).and_then(|string| { - if string.starts_with('\"') && string.ends_with('\"') { - let content = &string[1..string.len() - 1]; - if content.contains('\\') { - panic!("Escape sequences in string literals not yet handled"); - } - Some(content.to_string()) - } else if string.starts_with("r\"") { - panic!("Raw string literals are not yet handled"); - } else { - None - } - }) -} - -pub(crate) fn expect_ident(it: &mut token_stream::IntoIter) -> String { - try_ident(it).expect("Expected Ident") -} - -pub(crate) fn expect_punct(it: &mut token_stream::IntoIter) -> char { - if let TokenTree::Punct(punct) = it.next().expect("Reached end of token stream for Punct") { - punct.as_char() - } else { - panic!("Expected Punct"); +impl ToTokens for AsciiLitStr { + fn to_tokens(&self, ts: &mut TokenStream) { + self.0.to_tokens(ts); } } -pub(crate) fn expect_string(it: &mut token_stream::IntoIter) -> String { - try_string(it).expect("Expected string") -} - -pub(crate) fn expect_string_ascii(it: &mut token_stream::IntoIter) -> String { - let string = try_string(it).expect("Expected string"); - assert!(string.is_ascii(), "Expected ASCII string"); - string -} - -pub(crate) fn expect_group(it: &mut token_stream::IntoIter) -> Group { - if let TokenTree::Group(group) = it.next().expect("Reached end of token stream for Group") { - group - } else { - panic!("Expected Group"); - } -} - -pub(crate) fn expect_end(it: &mut token_stream::IntoIter) { - if it.next().is_some() { - panic!("Expected end"); - } -} - -/// Given a function declaration, finds the name of the function. -pub(crate) fn function_name(input: TokenStream) -> Option<Ident> { - let mut input = input.into_iter(); - while let Some(token) = input.next() { - match token { - TokenTree::Ident(i) if i.to_string() == "fn" => { - if let Some(TokenTree::Ident(i)) = input.next() { - return Some(i); - } - return None; - } - _ => continue, - } +impl AsciiLitStr { + pub(crate) fn value(&self) -> String { + self.0.value() } - None } pub(crate) fn file() -> String { @@ -115,16 +55,7 @@ pub(crate) fn file() -> String { } } -/// Parse a token stream of the form `expected_name: "value",` and return the -/// string in the position of "value". -/// -/// # Panics -/// -/// - On parse error. -pub(crate) fn expect_string_field(it: &mut token_stream::IntoIter, expected_name: &str) -> String { - assert_eq!(expect_ident(it), expected_name); - assert_eq!(expect_punct(it), ':'); - let string = expect_string(it); - assert_eq!(expect_punct(it), ','); - string +/// Obtain all `#[cfg]` attributes. +pub(crate) fn gather_cfg_attrs(attr: &[Attribute]) -> impl Iterator<Item = &Attribute> + '_ { + attr.iter().filter(|a| a.path().is_ident("cfg")) } diff --git a/rust/macros/kunit.rs b/rust/macros/kunit.rs index b395bb053695..6be880d634e2 100644 --- a/rust/macros/kunit.rs +++ b/rust/macros/kunit.rs @@ -4,80 +4,50 @@ //! //! 
Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com> -use proc_macro::{Delimiter, Group, TokenStream, TokenTree}; -use std::collections::HashMap; -use std::fmt::Write; - -pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { - let attr = attr.to_string(); - - if attr.is_empty() { - panic!("Missing test name in `#[kunit_tests(test_name)]` macro") - } - - if attr.len() > 255 { - panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes") +use std::ffi::CString; + +use proc_macro2::TokenStream; +use quote::{ + format_ident, + quote, + ToTokens, // +}; +use syn::{ + parse_quote, + Error, + Ident, + Item, + ItemMod, + LitCStr, + Result, // +}; + +pub(crate) fn kunit_tests(test_suite: Ident, mut module: ItemMod) -> Result<TokenStream> { + if test_suite.to_string().len() > 255 { + return Err(Error::new_spanned( + test_suite, + "test suite names cannot exceed the maximum length of 255 bytes", + )); } - let mut tokens: Vec<_> = ts.into_iter().collect(); - - // Scan for the `mod` keyword. - tokens - .iter() - .find_map(|token| match token { - TokenTree::Ident(ident) => match ident.to_string().as_str() { - "mod" => Some(true), - _ => None, - }, - _ => None, - }) - .expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules"); - - // Retrieve the main body. The main body should be the last token tree. - let body = match tokens.pop() { - Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group, - _ => panic!("Cannot locate main body of module"), + // We cannot handle modules that defer to another file (e.g. `mod foo;`). + let Some((module_brace, module_items)) = module.content.take() else { + Err(Error::new_spanned( + module, + "`#[kunit_tests(test_name)]` attribute should only be applied to inline modules", + ))? }; - // Get the functions set as tests. Search for `[test]` -> `fn`. - let mut body_it = body.stream().into_iter(); - let mut tests = Vec::new(); - let mut attributes: HashMap<String, TokenStream> = HashMap::new(); - while let Some(token) = body_it.next() { - match token { - TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() { - Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => { - if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() { - // Collect attributes because we need to find which are tests. We also - // need to copy `cfg` attributes so tests can be conditionally enabled. - attributes - .entry(name.to_string()) - .or_default() - .extend([token, TokenTree::Group(g)]); - } - continue; - } - _ => (), - }, - TokenTree::Ident(i) if i.to_string() == "fn" && attributes.contains_key("test") => { - if let Some(TokenTree::Ident(test_name)) = body_it.next() { - tests.push((test_name, attributes.remove("cfg").unwrap_or_default())) - } - } - - _ => (), - } - attributes.clear(); - } + // Make the entire module gated behind `CONFIG_KUNIT`. + module + .attrs + .insert(0, parse_quote!(#[cfg(CONFIG_KUNIT="y")])); - // Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration. - let config_kunit = "#[cfg(CONFIG_KUNIT=\"y\")]".to_owned().parse().unwrap(); - tokens.insert( - 0, - TokenTree::Group(Group::new(Delimiter::None, config_kunit)), - ); + let mut processed_items = Vec::new(); + let mut test_cases = Vec::new(); // Generate the test KUnit test suite and a test case for each `#[test]`. 
+ // // The code generated for the following test module: // // ``` @@ -102,105 +72,100 @@ pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut ::kernel::bindings::kunit) { bar(); } // // static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [ - // ::kernel::kunit::kunit_case(::kernel::c_str!("foo"), kunit_rust_wrapper_foo), - // ::kernel::kunit::kunit_case(::kernel::c_str!("bar"), kunit_rust_wrapper_bar), - // ::kernel::kunit::kunit_case_null(), + // ::kernel::kunit::kunit_case(c"foo", kunit_rust_wrapper_foo), + // ::kernel::kunit::kunit_case(c"bar", kunit_rust_wrapper_bar), + // ::pin_init::zeroed(), // ]; // // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES); // ``` - let mut kunit_macros = "".to_owned(); - let mut test_cases = "".to_owned(); - let mut assert_macros = "".to_owned(); - let path = crate::helpers::file(); - let num_tests = tests.len(); - for (test, cfg_attr) in tests { - let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}"); - // Append any `cfg` attributes the user might have written on their tests so we don't - // attempt to call them when they are `cfg`'d out. An extra `use` is used here to reduce - // the length of the assert message. - let kunit_wrapper = format!( - r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit) - {{ - (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED; - {cfg_attr} {{ - (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS; - use ::kernel::kunit::is_test_result_ok; - assert!(is_test_result_ok({test}())); + // + // Non-function items (e.g. imports) are preserved. + for item in module_items { + let Item::Fn(mut f) = item else { + processed_items.push(item); + continue; + }; + + // TODO: Replace below with `extract_if` when MSRV is bumped above 1.85. + let before_len = f.attrs.len(); + f.attrs.retain(|attr| !attr.path().is_ident("test")); + if f.attrs.len() == before_len { + processed_items.push(Item::Fn(f)); + continue; + } + + let test = f.sig.ident.clone(); + + // Retrieve `#[cfg]` applied on the function which needs to be present on derived items too. + let cfg_attrs: Vec<_> = f + .attrs + .iter() + .filter(|attr| attr.path().is_ident("cfg")) + .cloned() + .collect(); + + // Before the test, override usual `assert!` and `assert_eq!` macros with ones that call + // KUnit instead. + let test_str = test.to_string(); + let path = CString::new(crate::helpers::file()).expect("file path cannot contain NUL"); + processed_items.push(parse_quote! { + #[allow(unused)] + macro_rules! assert { + ($cond:expr $(,)?) => {{ + kernel::kunit_assert!(#test_str, #path, 0, $cond); + }} + } + }); + processed_items.push(parse_quote! { + #[allow(unused)] + macro_rules! assert_eq { + ($left:expr, $right:expr $(,)?) => {{ + kernel::kunit_assert_eq!(#test_str, #path, 0, $left, $right); }} - }}"#, + } + }); + + // Add back the test item. + processed_items.push(Item::Fn(f)); + + let kunit_wrapper_fn_name = format_ident!("kunit_rust_wrapper_{test}"); + let test_cstr = LitCStr::new( + &CString::new(test_str.as_str()).expect("identifier cannot contain NUL"), + test.span(), ); - writeln!(kunit_macros, "{kunit_wrapper}").unwrap(); - writeln!( - test_cases, - " ::kernel::kunit::kunit_case(::kernel::c_str!(\"{test}\"), {kunit_wrapper_fn_name})," - ) - .unwrap(); - writeln!( - assert_macros, - r#" -/// Overrides the usual [`assert!`] macro with one that calls KUnit instead. 
-#[allow(unused)] -macro_rules! assert {{ - ($cond:expr $(,)?) => {{{{ - kernel::kunit_assert!("{test}", "{path}", 0, $cond); - }}}} -}} - -/// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead. -#[allow(unused)] -macro_rules! assert_eq {{ - ($left:expr, $right:expr $(,)?) => {{{{ - kernel::kunit_assert_eq!("{test}", "{path}", 0, $left, $right); - }}}} -}} - "# - ) - .unwrap(); - } + processed_items.push(parse_quote! { + unsafe extern "C" fn #kunit_wrapper_fn_name(_test: *mut ::kernel::bindings::kunit) { + (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED; - writeln!(kunit_macros).unwrap(); - writeln!( - kunit_macros, - "static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases} ::kernel::kunit::kunit_case_null(),\n];", - num_tests + 1 - ) - .unwrap(); - - writeln!( - kunit_macros, - "::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);" - ) - .unwrap(); - - // Remove the `#[test]` macros. - // We do this at a token level, in order to preserve span information. - let mut new_body = vec![]; - let mut body_it = body.stream().into_iter(); - - while let Some(token) = body_it.next() { - match token { - TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() { - Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (), - Some(next) => { - new_body.extend([token, next]); - } - _ => { - new_body.push(token); + // Append any `cfg` attributes the user might have written on their tests so we + // don't attempt to call them when they are `cfg`'d out. An extra `use` is used + // here to reduce the length of the assert message. + #(#cfg_attrs)* + { + (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS; + use ::kernel::kunit::is_test_result_ok; + assert!(is_test_result_ok(#test())); } - }, - _ => { - new_body.push(token); } - } - } - - let mut final_body = TokenStream::new(); - final_body.extend::<TokenStream>(assert_macros.parse().unwrap()); - final_body.extend(new_body); - final_body.extend::<TokenStream>(kunit_macros.parse().unwrap()); + }); - tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body))); + test_cases.push(quote!( + ::kernel::kunit::kunit_case(#test_cstr, #kunit_wrapper_fn_name) + )); + } - tokens.into_iter().collect() + let num_tests_plus_1 = test_cases.len() + 1; + processed_items.push(parse_quote! { + static mut TEST_CASES: [::kernel::bindings::kunit_case; #num_tests_plus_1] = [ + #(#test_cases,)* + ::pin_init::zeroed(), + ]; + }); + processed_items.push(parse_quote! { + ::kernel::kunit_unsafe_test_suite!(#test_suite, TEST_CASES); + }); + + module.content = Some((module_brace, processed_items)); + Ok(module.to_token_stream()) } diff --git a/rust/macros/lib.rs b/rust/macros/lib.rs index b38002151871..0c36194d9971 100644 --- a/rust/macros/lib.rs +++ b/rust/macros/lib.rs @@ -11,8 +11,6 @@ // to avoid depending on the full `proc_macro_span` on Rust >= 1.88.0. #![cfg_attr(not(CONFIG_RUSTC_HAS_SPAN_FILE), feature(proc_macro_span))] -#[macro_use] -mod quote; mod concat_idents; mod export; mod fmt; @@ -24,6 +22,8 @@ mod vtable; use proc_macro::TokenStream; +use syn::parse_macro_input; + /// Declares a kernel module. /// /// The `type` argument should be a type which implements the [`Module`] @@ -59,7 +59,7 @@ use proc_macro::TokenStream; /// /// # Examples /// -/// ``` +/// ```ignore /// use kernel::prelude::*; /// /// module!{ @@ -131,8 +131,10 @@ use proc_macro::TokenStream; /// - `firmware`: array of ASCII string literals of the firmware files of /// the kernel module. 
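The `params` key accepted by the rewritten parser (see the `Parameter` and `parse_ordered_fields!` changes to rust/macros/module.rs later in this diff) takes entries of the form `name: type { default: ..., description: "...", }`. A hypothetical invocation sketch based on that parser; `MyModule` and `threshold` are made-up names and, like the doc example above, this is not a compilable standalone doctest:

    module! {
        type: MyModule,
        name: "my_module",
        authors: ["Jane Example"],
        description: "Illustration of the key order enforced by the parser",
        license: "GPL",
        params: {
            threshold: u32 {
                default: 16,
                description: "A hypothetical tuning knob",
            },
        },
    }

Keys must appear in the order given by `EXPECTED_KEYS` (`type`, `name` and `license` are required), and each `default`/`description` field ends with a comma because `parse_ordered_fields!` consumes a trailing `Token![,]` after every field.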
#[proc_macro] -pub fn module(ts: TokenStream) -> TokenStream { - module::module(ts) +pub fn module(input: TokenStream) -> TokenStream { + module::module(parse_macro_input!(input)) + .unwrap_or_else(|e| e.into_compile_error()) + .into() } /// Declares or implements a vtable trait. @@ -154,7 +156,7 @@ pub fn module(ts: TokenStream) -> TokenStream { /// case the default implementation will never be executed. The reason for this /// is that the functions will be called through function pointers installed in /// C side vtables. When an optional method is not implemented on a `#[vtable]` -/// trait, a NULL entry is installed in the vtable. Thus the default +/// trait, a `NULL` entry is installed in the vtable. Thus the default /// implementation is never called. Since these traits are not designed to be /// used on the Rust side, it should not be possible to call the default /// implementation. This is done to ensure that we call the vtable methods @@ -206,8 +208,11 @@ pub fn module(ts: TokenStream) -> TokenStream { /// /// [`kernel::error::VTABLE_DEFAULT_ERROR`]: ../kernel/error/constant.VTABLE_DEFAULT_ERROR.html #[proc_macro_attribute] -pub fn vtable(attr: TokenStream, ts: TokenStream) -> TokenStream { - vtable::vtable(attr, ts) +pub fn vtable(attr: TokenStream, input: TokenStream) -> TokenStream { + parse_macro_input!(attr as syn::parse::Nothing); + vtable::vtable(parse_macro_input!(input)) + .unwrap_or_else(|e| e.into_compile_error()) + .into() } /// Export a function so that C code can call it via a header file. @@ -229,8 +234,9 @@ pub fn vtable(attr: TokenStream, ts: TokenStream) -> TokenStream { /// This macro is *not* the same as the C macros `EXPORT_SYMBOL_*`. All Rust symbols are currently /// automatically exported with `EXPORT_SYMBOL_GPL`. #[proc_macro_attribute] -pub fn export(attr: TokenStream, ts: TokenStream) -> TokenStream { - export::export(attr, ts) +pub fn export(attr: TokenStream, input: TokenStream) -> TokenStream { + parse_macro_input!(attr as syn::parse::Nothing); + export::export(parse_macro_input!(input)).into() } /// Like [`core::format_args!`], but automatically wraps arguments in [`kernel::fmt::Adapter`]. @@ -248,7 +254,7 @@ pub fn export(attr: TokenStream, ts: TokenStream) -> TokenStream { /// [`pr_info!`]: ../kernel/macro.pr_info.html #[proc_macro] pub fn fmt(input: TokenStream) -> TokenStream { - fmt::fmt(input) + fmt::fmt(input.into()).into() } /// Concatenate two identifiers. @@ -305,8 +311,8 @@ pub fn fmt(input: TokenStream) -> TokenStream { /// assert_eq!(BR_OK, binder_driver_return_protocol_BR_OK); /// ``` #[proc_macro] -pub fn concat_idents(ts: TokenStream) -> TokenStream { - concat_idents::concat_idents(ts) +pub fn concat_idents(input: TokenStream) -> TokenStream { + concat_idents::concat_idents(parse_macro_input!(input)).into() } /// Paste identifiers together. @@ -444,9 +450,12 @@ pub fn concat_idents(ts: TokenStream) -> TokenStream { /// [`paste`]: https://docs.rs/paste/ #[proc_macro] pub fn paste(input: TokenStream) -> TokenStream { - let mut tokens = input.into_iter().collect(); + let mut tokens = proc_macro2::TokenStream::from(input).into_iter().collect(); paste::expand(&mut tokens); - tokens.into_iter().collect() + tokens + .into_iter() + .collect::<proc_macro2::TokenStream>() + .into() } /// Registers a KUnit test suite and its test cases using a user-space like syntax. 
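The converted entry points above share one error-handling shape: parse with `parse_macro_input!`, keep the real work in a helper returning `syn::Result`, and turn any `syn::Error` into `compile_error!` tokens via `into_compile_error()`, so diagnostics point at the offending tokens instead of surfacing as a proc-macro panic the way the old hand-rolled token iteration did. A small sketch of that shape, using an invented `shout` macro that is not part of the kernel macros and assuming only the `proc-macro2`, `quote` and `syn` dependencies already used in this patch:

    use proc_macro2::TokenStream;
    use quote::quote;
    use syn::{Error, LitStr, Result};

    // Inner, unit-testable helper: parsing and validation failures are reported
    // as spanned `syn::Error`s rather than panics.
    fn shout_impl(input: LitStr) -> Result<TokenStream> {
        if input.value().is_empty() {
            return Err(Error::new_spanned(&input, "expected a non-empty string literal"));
        }
        let upper = input.value().to_uppercase();
        Ok(quote! { #upper })
    }

    // The `#[proc_macro]` wrapper (which must live in a proc-macro crate) would
    // then follow the same pattern as `module`, `vtable` and `kunit_tests` above:
    //
    //     #[proc_macro]
    //     pub fn shout(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    //         let lit = syn::parse_macro_input!(input as syn::LitStr);
    //         shout_impl(lit)
    //             .unwrap_or_else(|e| e.into_compile_error())
    //             .into()
    //     }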
@@ -472,6 +481,8 @@ pub fn paste(input: TokenStream) -> TokenStream { /// } /// ``` #[proc_macro_attribute] -pub fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { - kunit::kunit_tests(attr, ts) +pub fn kunit_tests(attr: TokenStream, input: TokenStream) -> TokenStream { + kunit::kunit_tests(parse_macro_input!(attr), parse_macro_input!(input)) + .unwrap_or_else(|e| e.into_compile_error()) + .into() } diff --git a/rust/macros/module.rs b/rust/macros/module.rs index 80cb9b16f5aa..e16298e520c7 100644 --- a/rust/macros/module.rs +++ b/rust/macros/module.rs @@ -1,32 +1,42 @@ // SPDX-License-Identifier: GPL-2.0 +use std::ffi::CString; + +use proc_macro2::{ + Literal, + TokenStream, // +}; +use quote::{ + format_ident, + quote, // +}; +use syn::{ + braced, + bracketed, + ext::IdentExt, + parse::{ + Parse, + ParseStream, // + }, + parse_quote, + punctuated::Punctuated, + Error, + Expr, + Ident, + LitStr, + Path, + Result, + Token, + Type, // +}; + use crate::helpers::*; -use proc_macro::{token_stream, Delimiter, Literal, TokenStream, TokenTree}; -use std::fmt::Write; - -fn expect_string_array(it: &mut token_stream::IntoIter) -> Vec<String> { - let group = expect_group(it); - assert_eq!(group.delimiter(), Delimiter::Bracket); - let mut values = Vec::new(); - let mut it = group.stream().into_iter(); - - while let Some(val) = try_string(&mut it) { - assert!(val.is_ascii(), "Expected ASCII string"); - values.push(val); - match it.next() { - Some(TokenTree::Punct(punct)) => assert_eq!(punct.as_char(), ','), - None => break, - _ => panic!("Expected ',' or end of array"), - } - } - values -} struct ModInfoBuilder<'a> { module: &'a str, counter: usize, - buffer: String, - param_buffer: String, + ts: TokenStream, + param_ts: TokenStream, } impl<'a> ModInfoBuilder<'a> { @@ -34,8 +44,8 @@ impl<'a> ModInfoBuilder<'a> { ModInfoBuilder { module, counter: 0, - buffer: String::new(), - param_buffer: String::new(), + ts: TokenStream::new(), + param_ts: TokenStream::new(), } } @@ -52,33 +62,31 @@ impl<'a> ModInfoBuilder<'a> { // Loadable modules' modinfo strings go as-is. format!("{field}={content}\0") }; - - let buffer = if param { - &mut self.param_buffer + let length = string.len(); + let string = Literal::byte_string(string.as_bytes()); + let cfg = if builtin { + quote!(#[cfg(not(MODULE))]) } else { - &mut self.buffer + quote!(#[cfg(MODULE)]) }; - write!( - buffer, - " - {cfg} - #[doc(hidden)] - #[cfg_attr(not(target_os = \"macos\"), link_section = \".modinfo\")] - #[used(compiler)] - pub static __{module}_{counter}: [u8; {length}] = *{string}; - ", - cfg = if builtin { - "#[cfg(not(MODULE))]" - } else { - "#[cfg(MODULE)]" - }, + let counter = format_ident!( + "__{module}_{counter}", module = self.module.to_uppercase(), - counter = self.counter, - length = string.len(), - string = Literal::byte_string(string.as_bytes()), - ) - .unwrap(); + counter = self.counter + ); + let item = quote! { + #cfg + #[cfg_attr(not(target_os = "macos"), link_section = ".modinfo")] + #[used(compiler)] + pub static #counter: [u8; #length] = *#string; + }; + + if param { + self.param_ts.extend(item); + } else { + self.ts.extend(item); + } self.counter += 1; } @@ -111,201 +119,160 @@ impl<'a> ModInfoBuilder<'a> { }; for param in params { - let ops = param_ops_path(¶m.ptype); + let param_name_str = param.name.to_string(); + let param_type_str = param.ptype.to_string(); + + let ops = param_ops_path(¶m_type_str); // Note: The spelling of these fields is dictated by the user space // tool `modinfo`. 
- self.emit_param("parmtype", ¶m.name, ¶m.ptype); - self.emit_param("parm", ¶m.name, ¶m.description); - - write!( - self.param_buffer, - " - pub(crate) static {param_name}: - ::kernel::module_param::ModuleParamAccess<{param_type}> = - ::kernel::module_param::ModuleParamAccess::new({param_default}); - - const _: () = {{ - #[link_section = \"__param\"] - #[used] - static __{module_name}_{param_name}_struct: + self.emit_param("parmtype", ¶m_name_str, ¶m_type_str); + self.emit_param("parm", ¶m_name_str, ¶m.description.value()); + + let static_name = format_ident!("__{}_{}_struct", self.module, param.name); + let param_name_cstr = + CString::new(param_name_str).expect("name contains NUL-terminator"); + let param_name_cstr_with_module = + CString::new(format!("{}.{}", self.module, param.name)) + .expect("name contains NUL-terminator"); + + let param_name = ¶m.name; + let param_type = ¶m.ptype; + let param_default = ¶m.default; + + self.param_ts.extend(quote! { + #[allow(non_upper_case_globals)] + pub(crate) static #param_name: + ::kernel::module_param::ModuleParamAccess<#param_type> = + ::kernel::module_param::ModuleParamAccess::new(#param_default); + + const _: () = { + #[allow(non_upper_case_globals)] + #[link_section = "__param"] + #[used(compiler)] + static #static_name: ::kernel::module_param::KernelParam = ::kernel::module_param::KernelParam::new( - ::kernel::bindings::kernel_param {{ - name: if ::core::cfg!(MODULE) {{ - ::kernel::c_str!(\"{param_name}\").to_bytes_with_nul() - }} else {{ - ::kernel::c_str!(\"{module_name}.{param_name}\") - .to_bytes_with_nul() - }}.as_ptr(), + ::kernel::bindings::kernel_param { + name: kernel::str::as_char_ptr_in_const_context( + if ::core::cfg!(MODULE) { + #param_name_cstr + } else { + #param_name_cstr_with_module + } + ), // SAFETY: `__this_module` is constructed by the kernel at load // time and will not be freed until the module is unloaded. 
#[cfg(MODULE)] - mod_: unsafe {{ + mod_: unsafe { core::ptr::from_ref(&::kernel::bindings::__this_module) .cast_mut() - }}, + }, #[cfg(not(MODULE))] mod_: ::core::ptr::null_mut(), - ops: core::ptr::from_ref(&{ops}), + ops: core::ptr::from_ref(&#ops), perm: 0, // Will not appear in sysfs level: -1, flags: 0, - __bindgen_anon_1: ::kernel::bindings::kernel_param__bindgen_ty_1 {{ - arg: {param_name}.as_void_ptr() - }}, - }} + __bindgen_anon_1: ::kernel::bindings::kernel_param__bindgen_ty_1 { + arg: #param_name.as_void_ptr() + }, + } ); - }}; - ", - module_name = info.name, - param_type = param.ptype, - param_default = param.default, - param_name = param.name, - ops = ops, - ) - .unwrap(); + }; + }); } } } -fn param_ops_path(param_type: &str) -> &'static str { +fn param_ops_path(param_type: &str) -> Path { match param_type { - "i8" => "::kernel::module_param::PARAM_OPS_I8", - "u8" => "::kernel::module_param::PARAM_OPS_U8", - "i16" => "::kernel::module_param::PARAM_OPS_I16", - "u16" => "::kernel::module_param::PARAM_OPS_U16", - "i32" => "::kernel::module_param::PARAM_OPS_I32", - "u32" => "::kernel::module_param::PARAM_OPS_U32", - "i64" => "::kernel::module_param::PARAM_OPS_I64", - "u64" => "::kernel::module_param::PARAM_OPS_U64", - "isize" => "::kernel::module_param::PARAM_OPS_ISIZE", - "usize" => "::kernel::module_param::PARAM_OPS_USIZE", + "i8" => parse_quote!(::kernel::module_param::PARAM_OPS_I8), + "u8" => parse_quote!(::kernel::module_param::PARAM_OPS_U8), + "i16" => parse_quote!(::kernel::module_param::PARAM_OPS_I16), + "u16" => parse_quote!(::kernel::module_param::PARAM_OPS_U16), + "i32" => parse_quote!(::kernel::module_param::PARAM_OPS_I32), + "u32" => parse_quote!(::kernel::module_param::PARAM_OPS_U32), + "i64" => parse_quote!(::kernel::module_param::PARAM_OPS_I64), + "u64" => parse_quote!(::kernel::module_param::PARAM_OPS_U64), + "isize" => parse_quote!(::kernel::module_param::PARAM_OPS_ISIZE), + "usize" => parse_quote!(::kernel::module_param::PARAM_OPS_USIZE), t => panic!("Unsupported parameter type {}", t), } } -fn expect_param_default(param_it: &mut token_stream::IntoIter) -> String { - assert_eq!(expect_ident(param_it), "default"); - assert_eq!(expect_punct(param_it), ':'); - let sign = try_sign(param_it); - let default = try_literal(param_it).expect("Expected default param value"); - assert_eq!(expect_punct(param_it), ','); - let mut value = sign.map(String::from).unwrap_or_default(); - value.push_str(&default); - value -} - -#[derive(Debug, Default)] -struct ModuleInfo { - type_: String, - license: String, - name: String, - authors: Option<Vec<String>>, - description: Option<String>, - alias: Option<Vec<String>>, - firmware: Option<Vec<String>>, - imports_ns: Option<Vec<String>>, - params: Option<Vec<Parameter>>, -} - -#[derive(Debug)] -struct Parameter { - name: String, - ptype: String, - default: String, - description: String, -} - -fn expect_params(it: &mut token_stream::IntoIter) -> Vec<Parameter> { - let params = expect_group(it); - assert_eq!(params.delimiter(), Delimiter::Brace); - let mut it = params.stream().into_iter(); - let mut parsed = Vec::new(); - - loop { - let param_name = match it.next() { - Some(TokenTree::Ident(ident)) => ident.to_string(), - Some(_) => panic!("Expected Ident or end"), - None => break, - }; - - assert_eq!(expect_punct(&mut it), ':'); - let param_type = expect_ident(&mut it); - let group = expect_group(&mut it); - assert_eq!(group.delimiter(), Delimiter::Brace); - assert_eq!(expect_punct(&mut it), ','); - - let mut param_it = 
group.stream().into_iter(); - let param_default = expect_param_default(&mut param_it); - let param_description = expect_string_field(&mut param_it, "description"); - expect_end(&mut param_it); - - parsed.push(Parameter { - name: param_name, - ptype: param_type, - default: param_default, - description: param_description, - }) - } - - parsed -} - -impl ModuleInfo { - fn parse(it: &mut token_stream::IntoIter) -> Self { - let mut info = ModuleInfo::default(); - - const EXPECTED_KEYS: &[&str] = &[ - "type", - "name", - "authors", - "description", - "license", - "alias", - "firmware", - "imports_ns", - "params", - ]; - const REQUIRED_KEYS: &[&str] = &["type", "name", "license"]; +/// Parse fields that are required to use a specific order. +/// +/// As fields must follow a specific order, we *could* just parse fields one by one by peeking. +/// However the error message generated when implementing that way is not very friendly. +/// +/// So instead we parse fields in an arbitrary order, but only enforce the ordering after parsing, +/// and if the wrong order is used, the proper order is communicated to the user with error message. +/// +/// Usage looks like this: +/// ```ignore +/// parse_ordered_fields! { +/// from input; +/// +/// // This will extract "foo: <field>" into a variable named "foo". +/// // The variable will have type `Option<_>`. +/// foo => <expression that parses the field>, +/// +/// // If you need the variable name to be different than the key name. +/// // This extracts "baz: <field>" into a variable named "bar". +/// // You might want this if "baz" is a keyword. +/// baz as bar => <expression that parse the field>, +/// +/// // You can mark a key as required, and the variable will no longer be `Option`. +/// // foobar will be of type `Expr` instead of `Option<Expr>`. +/// foobar [required] => input.parse::<Expr>()?, +/// } +/// ``` +macro_rules! parse_ordered_fields { + (@gen + [$input:expr] + [$([$name:ident; $key:ident; $parser:expr])*] + [$([$req_name:ident; $req_key:ident])*] + ) => { + $(let mut $name = None;)* + + const EXPECTED_KEYS: &[&str] = &[$(stringify!($key),)*]; + const REQUIRED_KEYS: &[&str] = &[$(stringify!($req_key),)*]; + + let span = $input.span(); let mut seen_keys = Vec::new(); - loop { - let key = match it.next() { - Some(TokenTree::Ident(ident)) => ident.to_string(), - Some(_) => panic!("Expected Ident or end"), - None => break, - }; + while !$input.is_empty() { + let key = $input.call(Ident::parse_any)?; if seen_keys.contains(&key) { - panic!("Duplicated key \"{key}\". Keys can only be specified once."); + Err(Error::new_spanned( + &key, + format!(r#"duplicated key "{key}". Keys can only be specified once."#), + ))? } - assert_eq!(expect_punct(it), ':'); - - match key.as_str() { - "type" => info.type_ = expect_ident(it), - "name" => info.name = expect_string_ascii(it), - "authors" => info.authors = Some(expect_string_array(it)), - "description" => info.description = Some(expect_string(it)), - "license" => info.license = expect_string_ascii(it), - "alias" => info.alias = Some(expect_string_array(it)), - "firmware" => info.firmware = Some(expect_string_array(it)), - "imports_ns" => info.imports_ns = Some(expect_string_array(it)), - "params" => info.params = Some(expect_params(it)), - _ => panic!("Unknown key \"{key}\". Valid keys are: {EXPECTED_KEYS:?}."), + $input.parse::<Token![:]>()?; + + match &*key.to_string() { + $( + stringify!($key) => $name = Some($parser), + )* + _ => { + Err(Error::new_spanned( + &key, + format!(r#"unknown key "{key}". 
Valid keys are: {EXPECTED_KEYS:?}."#), + ))? + } } - assert_eq!(expect_punct(it), ','); - + $input.parse::<Token![,]>()?; seen_keys.push(key); } - expect_end(it); - for key in REQUIRED_KEYS { if !seen_keys.iter().any(|e| e == key) { - panic!("Missing required key \"{key}\"."); + Err(Error::new(span, format!(r#"missing required key "{key}""#)))? } } @@ -317,43 +284,190 @@ impl ModuleInfo { } if seen_keys != ordered_keys { - panic!("Keys are not ordered as expected. Order them like: {ordered_keys:?}."); + Err(Error::new( + span, + format!(r#"keys are not ordered as expected. Order them like: {ordered_keys:?}."#), + ))? + } + + $(let $req_name = $req_name.expect("required field");)* + }; + + // Handle required fields. + (@gen + [$input:expr] [$($tok:tt)*] [$($req:tt)*] + $key:ident as $name:ident [required] => $parser:expr, + $($rest:tt)* + ) => { + parse_ordered_fields!( + @gen [$input] [$($tok)* [$name; $key; $parser]] [$($req)* [$name; $key]] $($rest)* + ) + }; + (@gen + [$input:expr] [$($tok:tt)*] [$($req:tt)*] + $name:ident [required] => $parser:expr, + $($rest:tt)* + ) => { + parse_ordered_fields!( + @gen [$input] [$($tok)* [$name; $name; $parser]] [$($req)* [$name; $name]] $($rest)* + ) + }; + + // Handle optional fields. + (@gen + [$input:expr] [$($tok:tt)*] [$($req:tt)*] + $key:ident as $name:ident => $parser:expr, + $($rest:tt)* + ) => { + parse_ordered_fields!( + @gen [$input] [$($tok)* [$name; $key; $parser]] [$($req)*] $($rest)* + ) + }; + (@gen + [$input:expr] [$($tok:tt)*] [$($req:tt)*] + $name:ident => $parser:expr, + $($rest:tt)* + ) => { + parse_ordered_fields!( + @gen [$input] [$($tok)* [$name; $name; $parser]] [$($req)*] $($rest)* + ) + }; + + (from $input:expr; $($tok:tt)*) => { + parse_ordered_fields!(@gen [$input] [] [] $($tok)*) + } +} + +struct Parameter { + name: Ident, + ptype: Ident, + default: Expr, + description: LitStr, +} + +impl Parse for Parameter { + fn parse(input: ParseStream<'_>) -> Result<Self> { + let name = input.parse()?; + input.parse::<Token![:]>()?; + let ptype = input.parse()?; + + let fields; + braced!(fields in input); + + parse_ordered_fields! { + from fields; + default [required] => fields.parse()?, + description [required] => fields.parse()?, } - info + Ok(Self { + name, + ptype, + default, + description, + }) } } -pub(crate) fn module(ts: TokenStream) -> TokenStream { - let mut it = ts.into_iter(); +pub(crate) struct ModuleInfo { + type_: Type, + license: AsciiLitStr, + name: AsciiLitStr, + authors: Option<Punctuated<AsciiLitStr, Token![,]>>, + description: Option<LitStr>, + alias: Option<Punctuated<AsciiLitStr, Token![,]>>, + firmware: Option<Punctuated<AsciiLitStr, Token![,]>>, + imports_ns: Option<Punctuated<AsciiLitStr, Token![,]>>, + params: Option<Punctuated<Parameter, Token![,]>>, +} + +impl Parse for ModuleInfo { + fn parse(input: ParseStream<'_>) -> Result<Self> { + parse_ordered_fields!( + from input; + type as type_ [required] => input.parse()?, + name [required] => input.parse()?, + authors => { + let list; + bracketed!(list in input); + Punctuated::parse_terminated(&list)? + }, + description => input.parse()?, + license [required] => input.parse()?, + alias => { + let list; + bracketed!(list in input); + Punctuated::parse_terminated(&list)? + }, + firmware => { + let list; + bracketed!(list in input); + Punctuated::parse_terminated(&list)? + }, + imports_ns => { + let list; + bracketed!(list in input); + Punctuated::parse_terminated(&list)? 
+ }, + params => { + let list; + braced!(list in input); + Punctuated::parse_terminated(&list)? + }, + ); + + Ok(ModuleInfo { + type_, + license, + name, + authors, + description, + alias, + firmware, + imports_ns, + params, + }) + } +} - let info = ModuleInfo::parse(&mut it); +pub(crate) fn module(info: ModuleInfo) -> Result<TokenStream> { + let ModuleInfo { + type_, + license, + name, + authors, + description, + alias, + firmware, + imports_ns, + params: _, + } = &info; // Rust does not allow hyphens in identifiers, use underscore instead. - let ident = info.name.replace('-', "_"); + let ident = name.value().replace('-', "_"); let mut modinfo = ModInfoBuilder::new(ident.as_ref()); - if let Some(authors) = &info.authors { + if let Some(authors) = authors { for author in authors { - modinfo.emit("author", author); + modinfo.emit("author", &author.value()); } } - if let Some(description) = &info.description { - modinfo.emit("description", description); + if let Some(description) = description { + modinfo.emit("description", &description.value()); } - modinfo.emit("license", &info.license); - if let Some(aliases) = &info.alias { + modinfo.emit("license", &license.value()); + if let Some(aliases) = alias { for alias in aliases { - modinfo.emit("alias", alias); + modinfo.emit("alias", &alias.value()); } } - if let Some(firmware) = &info.firmware { + if let Some(firmware) = firmware { for fw in firmware { - modinfo.emit("firmware", fw); + modinfo.emit("firmware", &fw.value()); } } - if let Some(imports) = &info.imports_ns { + if let Some(imports) = imports_ns { for ns in imports { - modinfo.emit("import_ns", ns); + modinfo.emit("import_ns", &ns.value()); } } @@ -364,182 +478,181 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream { modinfo.emit_params(&info); - format!( - " - /// The module name. - /// - /// Used by the printing macros, e.g. [`info!`]. - const __LOG_PREFIX: &[u8] = b\"{name}\\0\"; - - // SAFETY: `__this_module` is constructed by the kernel at load time and will not be - // freed until the module is unloaded. - #[cfg(MODULE)] - static THIS_MODULE: ::kernel::ThisModule = unsafe {{ - extern \"C\" {{ - static __this_module: ::kernel::types::Opaque<::kernel::bindings::module>; - }} - - ::kernel::ThisModule::from_ptr(__this_module.get()) - }}; - #[cfg(not(MODULE))] - static THIS_MODULE: ::kernel::ThisModule = unsafe {{ - ::kernel::ThisModule::from_ptr(::core::ptr::null_mut()) - }}; - - /// The `LocalModule` type is the type of the module created by `module!`, - /// `module_pci_driver!`, `module_platform_driver!`, etc. - type LocalModule = {type_}; - - impl ::kernel::ModuleMetadata for {type_} {{ - const NAME: &'static ::kernel::str::CStr = c\"{name}\"; - }} - - // Double nested modules, since then nobody can access the public items inside. - mod __module_init {{ - mod __module_init {{ - use super::super::{type_}; - use pin_init::PinInit; - - /// The \"Rust loadable module\" mark. - // - // This may be best done another way later on, e.g. as a new modinfo - // key or a new section. For the moment, keep it simple. - #[cfg(MODULE)] - #[doc(hidden)] - #[used(compiler)] - static __IS_RUST_MODULE: () = (); - - static mut __MOD: ::core::mem::MaybeUninit<{type_}> = - ::core::mem::MaybeUninit::uninit(); - - // Loadable modules need to export the `{{init,cleanup}}_module` identifiers. - /// # Safety - /// - /// This function must not be called after module initialization, because it may be - /// freed after that completes. 
- #[cfg(MODULE)] - #[doc(hidden)] - #[no_mangle] - #[link_section = \".init.text\"] - pub unsafe extern \"C\" fn init_module() -> ::kernel::ffi::c_int {{ - // SAFETY: This function is inaccessible to the outside due to the double - // module wrapping it. It is called exactly once by the C side via its - // unique name. - unsafe {{ __init() }} - }} - - #[cfg(MODULE)] - #[doc(hidden)] - #[used(compiler)] - #[link_section = \".init.data\"] - static __UNIQUE_ID___addressable_init_module: unsafe extern \"C\" fn() -> i32 = init_module; - - #[cfg(MODULE)] - #[doc(hidden)] - #[no_mangle] - #[link_section = \".exit.text\"] - pub extern \"C\" fn cleanup_module() {{ - // SAFETY: - // - This function is inaccessible to the outside due to the double - // module wrapping it. It is called exactly once by the C side via its - // unique name, - // - furthermore it is only called after `init_module` has returned `0` - // (which delegates to `__init`). - unsafe {{ __exit() }} - }} - - #[cfg(MODULE)] - #[doc(hidden)] - #[used(compiler)] - #[link_section = \".exit.data\"] - static __UNIQUE_ID___addressable_cleanup_module: extern \"C\" fn() = cleanup_module; - - // Built-in modules are initialized through an initcall pointer - // and the identifiers need to be unique. - #[cfg(not(MODULE))] - #[cfg(not(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS))] - #[doc(hidden)] - #[link_section = \"{initcall_section}\"] - #[used(compiler)] - pub static __{ident}_initcall: extern \"C\" fn() -> - ::kernel::ffi::c_int = __{ident}_init; - - #[cfg(not(MODULE))] - #[cfg(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS)] - ::core::arch::global_asm!( - r#\".section \"{initcall_section}\", \"a\" - __{ident}_initcall: - .long __{ident}_init - . - .previous - \"# + let modinfo_ts = modinfo.ts; + let params_ts = modinfo.param_ts; + + let ident_init = format_ident!("__{ident}_init"); + let ident_exit = format_ident!("__{ident}_exit"); + let ident_initcall = format_ident!("__{ident}_initcall"); + let initcall_section = ".initcall6.init"; + + let global_asm = format!( + r#".section "{initcall_section}", "a" + __{ident}_initcall: + .long __{ident}_init - . + .previous + "# + ); + + let name_cstr = CString::new(name.value()).expect("name contains NUL-terminator"); + + Ok(quote! { + /// The module name. + /// + /// Used by the printing macros, e.g. [`info!`]. + const __LOG_PREFIX: &[u8] = #name_cstr.to_bytes_with_nul(); + + // SAFETY: `__this_module` is constructed by the kernel at load time and will not be + // freed until the module is unloaded. + #[cfg(MODULE)] + static THIS_MODULE: ::kernel::ThisModule = unsafe { + extern "C" { + static __this_module: ::kernel::types::Opaque<::kernel::bindings::module>; + }; + + ::kernel::ThisModule::from_ptr(__this_module.get()) + }; + + #[cfg(not(MODULE))] + static THIS_MODULE: ::kernel::ThisModule = unsafe { + ::kernel::ThisModule::from_ptr(::core::ptr::null_mut()) + }; + + /// The `LocalModule` type is the type of the module created by `module!`, + /// `module_pci_driver!`, `module_platform_driver!`, etc. + type LocalModule = #type_; + + impl ::kernel::ModuleMetadata for #type_ { + const NAME: &'static ::kernel::str::CStr = #name_cstr; + } + + // Double nested modules, since then nobody can access the public items inside. + #[doc(hidden)] + mod __module_init { + mod __module_init { + use pin_init::PinInit; + + /// The "Rust loadable module" mark. + // + // This may be best done another way later on, e.g. as a new modinfo + // key or a new section. For the moment, keep it simple. 
+ #[cfg(MODULE)] + #[used(compiler)] + static __IS_RUST_MODULE: () = (); + + static mut __MOD: ::core::mem::MaybeUninit<super::super::LocalModule> = + ::core::mem::MaybeUninit::uninit(); + + // Loadable modules need to export the `{init,cleanup}_module` identifiers. + /// # Safety + /// + /// This function must not be called after module initialization, because it may be + /// freed after that completes. + #[cfg(MODULE)] + #[no_mangle] + #[link_section = ".init.text"] + pub unsafe extern "C" fn init_module() -> ::kernel::ffi::c_int { + // SAFETY: This function is inaccessible to the outside due to the double + // module wrapping it. It is called exactly once by the C side via its + // unique name. + unsafe { __init() } + } + + #[cfg(MODULE)] + #[used(compiler)] + #[link_section = ".init.data"] + static __UNIQUE_ID___addressable_init_module: unsafe extern "C" fn() -> i32 = + init_module; + + #[cfg(MODULE)] + #[no_mangle] + #[link_section = ".exit.text"] + pub extern "C" fn cleanup_module() { + // SAFETY: + // - This function is inaccessible to the outside due to the double + // module wrapping it. It is called exactly once by the C side via its + // unique name, + // - furthermore it is only called after `init_module` has returned `0` + // (which delegates to `__init`). + unsafe { __exit() } + } + + #[cfg(MODULE)] + #[used(compiler)] + #[link_section = ".exit.data"] + static __UNIQUE_ID___addressable_cleanup_module: extern "C" fn() = cleanup_module; + + // Built-in modules are initialized through an initcall pointer + // and the identifiers need to be unique. + #[cfg(not(MODULE))] + #[cfg(not(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS))] + #[link_section = #initcall_section] + #[used(compiler)] + pub static #ident_initcall: extern "C" fn() -> + ::kernel::ffi::c_int = #ident_init; + + #[cfg(not(MODULE))] + #[cfg(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS)] + ::core::arch::global_asm!(#global_asm); + + #[cfg(not(MODULE))] + #[no_mangle] + pub extern "C" fn #ident_init() -> ::kernel::ffi::c_int { + // SAFETY: This function is inaccessible to the outside due to the double + // module wrapping it. It is called exactly once by the C side via its + // placement above in the initcall section. + unsafe { __init() } + } + + #[cfg(not(MODULE))] + #[no_mangle] + pub extern "C" fn #ident_exit() { + // SAFETY: + // - This function is inaccessible to the outside due to the double + // module wrapping it. It is called exactly once by the C side via its + // unique name, + // - furthermore it is only called after `#ident_init` has + // returned `0` (which delegates to `__init`). + unsafe { __exit() } + } + + /// # Safety + /// + /// This function must only be called once. + unsafe fn __init() -> ::kernel::ffi::c_int { + let initer = <super::super::LocalModule as ::kernel::InPlaceModule>::init( + &super::super::THIS_MODULE ); + // SAFETY: No data race, since `__MOD` can only be accessed by this module + // and there only `__init` and `__exit` access it. These functions are only + // called once and `__exit` cannot be called before or during `__init`. + match unsafe { initer.__pinned_init(__MOD.as_mut_ptr()) } { + Ok(m) => 0, + Err(e) => e.to_errno(), + } + } + + /// # Safety + /// + /// This function must + /// - only be called once, + /// - be called after `__init` has been called and returned `0`. + unsafe fn __exit() { + // SAFETY: No data race, since `__MOD` can only be accessed by this module + // and there only `__init` and `__exit` access it. 
These functions are only + // called once and `__init` was already called. + unsafe { + // Invokes `drop()` on `__MOD`, which should be used for cleanup. + __MOD.assume_init_drop(); + } + } + + #modinfo_ts + } + } - #[cfg(not(MODULE))] - #[doc(hidden)] - #[no_mangle] - pub extern \"C\" fn __{ident}_init() -> ::kernel::ffi::c_int {{ - // SAFETY: This function is inaccessible to the outside due to the double - // module wrapping it. It is called exactly once by the C side via its - // placement above in the initcall section. - unsafe {{ __init() }} - }} - - #[cfg(not(MODULE))] - #[doc(hidden)] - #[no_mangle] - pub extern \"C\" fn __{ident}_exit() {{ - // SAFETY: - // - This function is inaccessible to the outside due to the double - // module wrapping it. It is called exactly once by the C side via its - // unique name, - // - furthermore it is only called after `__{ident}_init` has - // returned `0` (which delegates to `__init`). - unsafe {{ __exit() }} - }} - - /// # Safety - /// - /// This function must only be called once. - unsafe fn __init() -> ::kernel::ffi::c_int {{ - let initer = - <{type_} as ::kernel::InPlaceModule>::init(&super::super::THIS_MODULE); - // SAFETY: No data race, since `__MOD` can only be accessed by this module - // and there only `__init` and `__exit` access it. These functions are only - // called once and `__exit` cannot be called before or during `__init`. - match unsafe {{ initer.__pinned_init(__MOD.as_mut_ptr()) }} {{ - Ok(m) => 0, - Err(e) => e.to_errno(), - }} - }} - - /// # Safety - /// - /// This function must - /// - only be called once, - /// - be called after `__init` has been called and returned `0`. - unsafe fn __exit() {{ - // SAFETY: No data race, since `__MOD` can only be accessed by this module - // and there only `__init` and `__exit` access it. These functions are only - // called once and `__init` was already called. - unsafe {{ - // Invokes `drop()` on `__MOD`, which should be used for cleanup. 
- __MOD.assume_init_drop(); - }} - }} - {modinfo} - }} - }} - mod module_parameters {{ - {params} - }} - ", - type_ = info.type_, - name = info.name, - ident = ident, - modinfo = modinfo.buffer, - params = modinfo.param_buffer, - initcall_section = ".initcall6.init" - ) - .parse() - .expect("Error parsing formatted string into token stream.") + mod module_parameters { + #params_ts + } + }) } diff --git a/rust/macros/paste.rs b/rust/macros/paste.rs index cce712d19855..2181e312a7d3 100644 --- a/rust/macros/paste.rs +++ b/rust/macros/paste.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: GPL-2.0 -use proc_macro::{Delimiter, Group, Ident, Spacing, Span, TokenTree}; +use proc_macro2::{Delimiter, Group, Ident, Spacing, Span, TokenTree}; fn concat_helper(tokens: &[TokenTree]) -> Vec<(String, Span)> { let mut tokens = tokens.iter(); diff --git a/rust/macros/quote.rs b/rust/macros/quote.rs deleted file mode 100644 index ddfc21577539..000000000000 --- a/rust/macros/quote.rs +++ /dev/null @@ -1,182 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 OR MIT - -use proc_macro::{TokenStream, TokenTree}; - -pub(crate) trait ToTokens { - fn to_tokens(&self, tokens: &mut TokenStream); -} - -impl<T: ToTokens> ToTokens for Option<T> { - fn to_tokens(&self, tokens: &mut TokenStream) { - if let Some(v) = self { - v.to_tokens(tokens); - } - } -} - -impl ToTokens for proc_macro::Group { - fn to_tokens(&self, tokens: &mut TokenStream) { - tokens.extend([TokenTree::from(self.clone())]); - } -} - -impl ToTokens for proc_macro::Ident { - fn to_tokens(&self, tokens: &mut TokenStream) { - tokens.extend([TokenTree::from(self.clone())]); - } -} - -impl ToTokens for TokenTree { - fn to_tokens(&self, tokens: &mut TokenStream) { - tokens.extend([self.clone()]); - } -} - -impl ToTokens for TokenStream { - fn to_tokens(&self, tokens: &mut TokenStream) { - tokens.extend(self.clone()); - } -} - -/// Converts tokens into [`proc_macro::TokenStream`] and performs variable interpolations with -/// the given span. -/// -/// This is a similar to the -/// [`quote_spanned!`](https://docs.rs/quote/latest/quote/macro.quote_spanned.html) macro from the -/// `quote` crate but provides only just enough functionality needed by the current `macros` crate. -macro_rules! 
quote_spanned { - ($span:expr => $($tt:tt)*) => {{ - let mut tokens = ::proc_macro::TokenStream::new(); - { - #[allow(unused_variables)] - let span = $span; - quote_spanned!(@proc tokens span $($tt)*); - } - tokens - }}; - (@proc $v:ident $span:ident) => {}; - (@proc $v:ident $span:ident #$id:ident $($tt:tt)*) => { - $crate::quote::ToTokens::to_tokens(&$id, &mut $v); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident #(#$id:ident)* $($tt:tt)*) => { - for token in $id { - $crate::quote::ToTokens::to_tokens(&token, &mut $v); - } - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident ( $($inner:tt)* ) $($tt:tt)*) => { - #[allow(unused_mut)] - let mut tokens = ::proc_macro::TokenStream::new(); - quote_spanned!(@proc tokens $span $($inner)*); - $v.extend([::proc_macro::TokenTree::Group(::proc_macro::Group::new( - ::proc_macro::Delimiter::Parenthesis, - tokens, - ))]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident [ $($inner:tt)* ] $($tt:tt)*) => { - let mut tokens = ::proc_macro::TokenStream::new(); - quote_spanned!(@proc tokens $span $($inner)*); - $v.extend([::proc_macro::TokenTree::Group(::proc_macro::Group::new( - ::proc_macro::Delimiter::Bracket, - tokens, - ))]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident { $($inner:tt)* } $($tt:tt)*) => { - let mut tokens = ::proc_macro::TokenStream::new(); - quote_spanned!(@proc tokens $span $($inner)*); - $v.extend([::proc_macro::TokenTree::Group(::proc_macro::Group::new( - ::proc_macro::Delimiter::Brace, - tokens, - ))]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident :: $($tt:tt)*) => { - $v.extend([::proc_macro::Spacing::Joint, ::proc_macro::Spacing::Alone].map(|spacing| { - ::proc_macro::TokenTree::Punct(::proc_macro::Punct::new(':', spacing)) - })); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident : $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new(':', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident , $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new(',', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident @ $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new('@', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident ! 
$($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new('!', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident ; $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new(';', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident + $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new('+', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident = $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new('=', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident # $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new('#', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident & $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new('&', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident _ $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Ident( - ::proc_macro::Ident::new("_", $span), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident $id:ident $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Ident( - ::proc_macro::Ident::new(stringify!($id), $span), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; -} - -/// Converts tokens into [`proc_macro::TokenStream`] and performs variable interpolations with -/// mixed site span ([`Span::mixed_site()`]). -/// -/// This is a similar to the [`quote!`](https://docs.rs/quote/latest/quote/macro.quote.html) macro -/// from the `quote` crate but provides only just enough functionality needed by the current -/// `macros` crate. -/// -/// [`Span::mixed_site()`]: https://doc.rust-lang.org/proc_macro/struct.Span.html#method.mixed_site -macro_rules! quote { - ($($tt:tt)*) => { - quote_spanned!(::proc_macro::Span::mixed_site() => $($tt)*) - } -} diff --git a/rust/macros/vtable.rs b/rust/macros/vtable.rs index ee06044fcd4f..c6510b0c4ea1 100644 --- a/rust/macros/vtable.rs +++ b/rust/macros/vtable.rs @@ -1,96 +1,105 @@ // SPDX-License-Identifier: GPL-2.0 -use proc_macro::{Delimiter, Group, TokenStream, TokenTree}; -use std::collections::HashSet; -use std::fmt::Write; +use std::{ + collections::HashSet, + iter::Extend, // +}; -pub(crate) fn vtable(_attr: TokenStream, ts: TokenStream) -> TokenStream { - let mut tokens: Vec<_> = ts.into_iter().collect(); +use proc_macro2::{ + Ident, + TokenStream, // +}; +use quote::ToTokens; +use syn::{ + parse_quote, + Error, + ImplItem, + Item, + ItemImpl, + ItemTrait, + Result, + TraitItem, // +}; - // Scan for the `trait` or `impl` keyword. - let is_trait = tokens - .iter() - .find_map(|token| match token { - TokenTree::Ident(ident) => match ident.to_string().as_str() { - "trait" => Some(true), - "impl" => Some(false), - _ => None, - }, - _ => None, - }) - .expect("#[vtable] attribute should only be applied to trait or impl block"); +fn handle_trait(mut item: ItemTrait) -> Result<ItemTrait> { + let mut gen_items = Vec::new(); - // Retrieve the main body. The main body should be the last token tree. 
- let body = match tokens.pop() { - Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group, - _ => panic!("cannot locate main body of trait or impl block"), - }; + gen_items.push(parse_quote! { + /// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable) + /// attribute when implementing this trait. + const USE_VTABLE_ATTR: (); + }); - let mut body_it = body.stream().into_iter(); - let mut functions = Vec::new(); - let mut consts = HashSet::new(); - while let Some(token) = body_it.next() { - match token { - TokenTree::Ident(ident) if ident.to_string() == "fn" => { - let fn_name = match body_it.next() { - Some(TokenTree::Ident(ident)) => ident.to_string(), - // Possibly we've encountered a fn pointer type instead. - _ => continue, - }; - functions.push(fn_name); - } - TokenTree::Ident(ident) if ident.to_string() == "const" => { - let const_name = match body_it.next() { - Some(TokenTree::Ident(ident)) => ident.to_string(), - // Possibly we've encountered an inline const block instead. - _ => continue, - }; - consts.insert(const_name); - } - _ => (), + for item in &item.items { + if let TraitItem::Fn(fn_item) = item { + let name = &fn_item.sig.ident; + let gen_const_name = Ident::new( + &format!("HAS_{}", name.to_string().to_uppercase()), + name.span(), + ); + + // We don't know on the implementation-site whether a method is required or provided + // so we have to generate a const for all methods. + let cfg_attrs = crate::helpers::gather_cfg_attrs(&fn_item.attrs); + let comment = + format!("Indicates if the `{name}` method is overridden by the implementor."); + gen_items.push(parse_quote! { + #(#cfg_attrs)* + #[doc = #comment] + const #gen_const_name: bool = false; + }); } } - let mut const_items; - if is_trait { - const_items = " - /// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable) - /// attribute when implementing this trait. - const USE_VTABLE_ATTR: (); - " - .to_owned(); + item.items.extend(gen_items); + Ok(item) +} - for f in functions { - let gen_const_name = format!("HAS_{}", f.to_uppercase()); - // Skip if it's declared already -- this allows user override. - if consts.contains(&gen_const_name) { - continue; - } - // We don't know on the implementation-site whether a method is required or provided - // so we have to generate a const for all methods. - write!( - const_items, - "/// Indicates if the `{f}` method is overridden by the implementor. - const {gen_const_name}: bool = false;", - ) - .unwrap(); - consts.insert(gen_const_name); +fn handle_impl(mut item: ItemImpl) -> Result<ItemImpl> { + let mut gen_items = Vec::new(); + let mut defined_consts = HashSet::new(); + + // Iterate over all user-defined constants to gather any possible explicit overrides. + for item in &item.items { + if let ImplItem::Const(const_item) = item { + defined_consts.insert(const_item.ident.clone()); } - } else { - const_items = "const USE_VTABLE_ATTR: () = ();".to_owned(); + } + + gen_items.push(parse_quote! { + const USE_VTABLE_ATTR: () = (); + }); - for f in functions { - let gen_const_name = format!("HAS_{}", f.to_uppercase()); - if consts.contains(&gen_const_name) { + for item in &item.items { + if let ImplItem::Fn(fn_item) = item { + let name = &fn_item.sig.ident; + let gen_const_name = Ident::new( + &format!("HAS_{}", name.to_string().to_uppercase()), + name.span(), + ); + // Skip if it's declared already -- this allows user override. 
+ if defined_consts.contains(&gen_const_name) { continue; } - write!(const_items, "const {gen_const_name}: bool = true;").unwrap(); + let cfg_attrs = crate::helpers::gather_cfg_attrs(&fn_item.attrs); + gen_items.push(parse_quote! { + #(#cfg_attrs)* + const #gen_const_name: bool = true; + }); } } - let new_body = vec![const_items.parse().unwrap(), body.stream()] - .into_iter() - .collect(); - tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, new_body))); - tokens.into_iter().collect() + item.items.extend(gen_items); + Ok(item) +} + +pub(crate) fn vtable(input: Item) -> Result<TokenStream> { + match input { + Item::Trait(item) => Ok(handle_trait(item)?.into_token_stream()), + Item::Impl(item) => Ok(handle_impl(item)?.into_token_stream()), + _ => Err(Error::new_spanned( + input, + "`#[vtable]` attribute should only be applied to trait or impl block", + ))?, + } } diff --git a/rust/pin-init/README.md b/rust/pin-init/README.md index 74bbb4e0a2f7..6cee6ab1eb57 100644 --- a/rust/pin-init/README.md +++ b/rust/pin-init/README.md @@ -135,7 +135,7 @@ struct DriverData { impl DriverData { fn new() -> impl PinInit<Self, Error> { - try_pin_init!(Self { + pin_init!(Self { status <- CMutex::new(0), buffer: Box::init(pin_init::init_zeroed())?, }? Error) diff --git a/rust/pin-init/examples/linked_list.rs b/rust/pin-init/examples/linked_list.rs index f9e117c7dfe0..8445a5890cb7 100644 --- a/rust/pin-init/examples/linked_list.rs +++ b/rust/pin-init/examples/linked_list.rs @@ -6,7 +6,6 @@ use core::{ cell::Cell, - convert::Infallible, marker::PhantomPinned, pin::Pin, ptr::{self, NonNull}, @@ -31,31 +30,31 @@ pub struct ListHead { impl ListHead { #[inline] - pub fn new() -> impl PinInit<Self, Infallible> { - try_pin_init!(&this in Self { + pub fn new() -> impl PinInit<Self> { + pin_init!(&this in Self { next: unsafe { Link::new_unchecked(this) }, prev: unsafe { Link::new_unchecked(this) }, pin: PhantomPinned, - }? Infallible) + }) } #[inline] #[allow(dead_code)] - pub fn insert_next(list: &ListHead) -> impl PinInit<Self, Infallible> + '_ { - try_pin_init!(&this in Self { + pub fn insert_next(list: &ListHead) -> impl PinInit<Self> + '_ { + pin_init!(&this in Self { prev: list.next.prev().replace(unsafe { Link::new_unchecked(this)}), next: list.next.replace(unsafe { Link::new_unchecked(this)}), pin: PhantomPinned, - }? Infallible) + }) } #[inline] - pub fn insert_prev(list: &ListHead) -> impl PinInit<Self, Infallible> + '_ { - try_pin_init!(&this in Self { + pub fn insert_prev(list: &ListHead) -> impl PinInit<Self> + '_ { + pin_init!(&this in Self { next: list.prev.next().replace(unsafe { Link::new_unchecked(this)}), prev: list.prev.replace(unsafe { Link::new_unchecked(this)}), pin: PhantomPinned, - }? Infallible) + }) } #[inline] diff --git a/rust/pin-init/examples/pthread_mutex.rs b/rust/pin-init/examples/pthread_mutex.rs index 49b004c8c137..4e082ec7d5de 100644 --- a/rust/pin-init/examples/pthread_mutex.rs +++ b/rust/pin-init/examples/pthread_mutex.rs @@ -98,11 +98,11 @@ mod pthread_mtx { // SAFETY: mutex has been initialized unsafe { pin_init_from_closure(init) } } - try_pin_init!(Self { - data: UnsafeCell::new(data), - raw <- init_raw(), - pin: PhantomPinned, - }? Error) + pin_init!(Self { + data: UnsafeCell::new(data), + raw <- init_raw(), + pin: PhantomPinned, + }? 
Error) } #[allow(dead_code)] diff --git a/rust/pin-init/internal/src/diagnostics.rs b/rust/pin-init/internal/src/diagnostics.rs new file mode 100644 index 000000000000..3bdb477c2f2b --- /dev/null +++ b/rust/pin-init/internal/src/diagnostics.rs @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +use std::fmt::Display; + +use proc_macro2::TokenStream; +use syn::{spanned::Spanned, Error}; + +pub(crate) struct DiagCtxt(TokenStream); +pub(crate) struct ErrorGuaranteed(()); + +impl DiagCtxt { + pub(crate) fn error(&mut self, span: impl Spanned, msg: impl Display) -> ErrorGuaranteed { + let error = Error::new(span.span(), msg); + self.0.extend(error.into_compile_error()); + ErrorGuaranteed(()) + } + + pub(crate) fn with( + fun: impl FnOnce(&mut DiagCtxt) -> Result<TokenStream, ErrorGuaranteed>, + ) -> TokenStream { + let mut dcx = Self(TokenStream::new()); + match fun(&mut dcx) { + Ok(mut stream) => { + stream.extend(dcx.0); + stream + } + Err(ErrorGuaranteed(())) => dcx.0, + } + } +} diff --git a/rust/pin-init/internal/src/helpers.rs b/rust/pin-init/internal/src/helpers.rs deleted file mode 100644 index 236f989a50f2..000000000000 --- a/rust/pin-init/internal/src/helpers.rs +++ /dev/null @@ -1,152 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 OR MIT - -#[cfg(not(kernel))] -use proc_macro2 as proc_macro; - -use proc_macro::{TokenStream, TokenTree}; - -/// Parsed generics. -/// -/// See the field documentation for an explanation what each of the fields represents. -/// -/// # Examples -/// -/// ```rust,ignore -/// # let input = todo!(); -/// let (Generics { decl_generics, impl_generics, ty_generics }, rest) = parse_generics(input); -/// quote! { -/// struct Foo<$($decl_generics)*> { -/// // ... -/// } -/// -/// impl<$impl_generics> Foo<$ty_generics> { -/// fn foo() { -/// // ... -/// } -/// } -/// } -/// ``` -pub(crate) struct Generics { - /// The generics with bounds and default values (e.g. `T: Clone, const N: usize = 0`). - /// - /// Use this on type definitions e.g. `struct Foo<$decl_generics> ...` (or `union`/`enum`). - pub(crate) decl_generics: Vec<TokenTree>, - /// The generics with bounds (e.g. `T: Clone, const N: usize`). - /// - /// Use this on `impl` blocks e.g. `impl<$impl_generics> Trait for ...`. - pub(crate) impl_generics: Vec<TokenTree>, - /// The generics without bounds and without default values (e.g. `T, N`). - /// - /// Use this when you use the type that is declared with these generics e.g. - /// `Foo<$ty_generics>`. - pub(crate) ty_generics: Vec<TokenTree>, -} - -/// Parses the given `TokenStream` into `Generics` and the rest. -/// -/// The generics are not present in the rest, but a where clause might remain. -pub(crate) fn parse_generics(input: TokenStream) -> (Generics, Vec<TokenTree>) { - // The generics with bounds and default values. - let mut decl_generics = vec![]; - // `impl_generics`, the declared generics with their bounds. - let mut impl_generics = vec![]; - // Only the names of the generics, without any bounds. - let mut ty_generics = vec![]; - // Tokens not related to the generics e.g. the `where` token and definition. - let mut rest = vec![]; - // The current level of `<`. - let mut nesting = 0; - let mut toks = input.into_iter(); - // If we are at the beginning of a generic parameter. - let mut at_start = true; - let mut skip_until_comma = false; - while let Some(tt) = toks.next() { - if nesting == 1 && matches!(&tt, TokenTree::Punct(p) if p.as_char() == '>') { - // Found the end of the generics. 
- break; - } else if nesting >= 1 { - decl_generics.push(tt.clone()); - } - match tt.clone() { - TokenTree::Punct(p) if p.as_char() == '<' => { - if nesting >= 1 && !skip_until_comma { - // This is inside of the generics and part of some bound. - impl_generics.push(tt); - } - nesting += 1; - } - TokenTree::Punct(p) if p.as_char() == '>' => { - // This is a parsing error, so we just end it here. - if nesting == 0 { - break; - } else { - nesting -= 1; - if nesting >= 1 && !skip_until_comma { - // We are still inside of the generics and part of some bound. - impl_generics.push(tt); - } - } - } - TokenTree::Punct(p) if skip_until_comma && p.as_char() == ',' => { - if nesting == 1 { - impl_generics.push(tt.clone()); - impl_generics.push(tt); - skip_until_comma = false; - } - } - _ if !skip_until_comma => { - match nesting { - // If we haven't entered the generics yet, we still want to keep these tokens. - 0 => rest.push(tt), - 1 => { - // Here depending on the token, it might be a generic variable name. - match tt.clone() { - TokenTree::Ident(i) if at_start && i.to_string() == "const" => { - let Some(name) = toks.next() else { - // Parsing error. - break; - }; - impl_generics.push(tt); - impl_generics.push(name.clone()); - ty_generics.push(name.clone()); - decl_generics.push(name); - at_start = false; - } - TokenTree::Ident(_) if at_start => { - impl_generics.push(tt.clone()); - ty_generics.push(tt); - at_start = false; - } - TokenTree::Punct(p) if p.as_char() == ',' => { - impl_generics.push(tt.clone()); - ty_generics.push(tt); - at_start = true; - } - // Lifetimes begin with `'`. - TokenTree::Punct(p) if p.as_char() == '\'' && at_start => { - impl_generics.push(tt.clone()); - ty_generics.push(tt); - } - // Generics can have default values, we skip these. - TokenTree::Punct(p) if p.as_char() == '=' => { - skip_until_comma = true; - } - _ => impl_generics.push(tt), - } - } - _ => impl_generics.push(tt), - } - } - _ => {} - } - } - rest.extend(toks); - ( - Generics { - impl_generics, - decl_generics, - ty_generics, - }, - rest, - ) -} diff --git a/rust/pin-init/internal/src/init.rs b/rust/pin-init/internal/src/init.rs new file mode 100644 index 000000000000..42936f915a07 --- /dev/null +++ b/rust/pin-init/internal/src/init.rs @@ -0,0 +1,548 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote, quote_spanned}; +use syn::{ + braced, + parse::{End, Parse}, + parse_quote, + punctuated::Punctuated, + spanned::Spanned, + token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type, +}; + +use crate::diagnostics::{DiagCtxt, ErrorGuaranteed}; + +pub(crate) struct Initializer { + attrs: Vec<InitializerAttribute>, + this: Option<This>, + path: Path, + brace_token: token::Brace, + fields: Punctuated<InitializerField, Token![,]>, + rest: Option<(Token![..], Expr)>, + error: Option<(Token![?], Type)>, +} + +struct This { + _and_token: Token![&], + ident: Ident, + _in_token: Token![in], +} + +struct InitializerField { + attrs: Vec<Attribute>, + kind: InitializerKind, +} + +enum InitializerKind { + Value { + ident: Ident, + value: Option<(Token![:], Expr)>, + }, + Init { + ident: Ident, + _left_arrow_token: Token![<-], + value: Expr, + }, + Code { + _underscore_token: Token![_], + _colon_token: Token![:], + block: Block, + }, +} + +impl InitializerKind { + fn ident(&self) -> Option<&Ident> { + match self { + Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident), + Self::Code { .. 
} => None, + } + } +} + +enum InitializerAttribute { + DefaultError(DefaultErrorAttribute), + DisableInitializedFieldAccess, +} + +struct DefaultErrorAttribute { + ty: Box<Type>, +} + +pub(crate) fn expand( + Initializer { + attrs, + this, + path, + brace_token, + fields, + rest, + error, + }: Initializer, + default_error: Option<&'static str>, + pinned: bool, + dcx: &mut DiagCtxt, +) -> Result<TokenStream, ErrorGuaranteed> { + let error = error.map_or_else( + || { + if let Some(default_error) = attrs.iter().fold(None, |acc, attr| { + if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr { + Some(ty.clone()) + } else { + acc + } + }) { + default_error + } else if let Some(default_error) = default_error { + syn::parse_str(default_error).unwrap() + } else { + dcx.error(brace_token.span.close(), "expected `? <type>` after `}`"); + parse_quote!(::core::convert::Infallible) + } + }, + |(_, err)| Box::new(err), + ); + let slot = format_ident!("slot"); + let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned { + ( + format_ident!("HasPinData"), + format_ident!("PinData"), + format_ident!("__pin_data"), + format_ident!("pin_init_from_closure"), + ) + } else { + ( + format_ident!("HasInitData"), + format_ident!("InitData"), + format_ident!("__init_data"), + format_ident!("init_from_closure"), + ) + }; + let init_kind = get_init_kind(rest, dcx); + let zeroable_check = match init_kind { + InitKind::Normal => quote!(), + InitKind::Zeroing => quote! { + // The user specified `..Zeroable::zeroed()` at the end of the list of fields. + // Therefore we check if the struct implements `Zeroable` and then zero the memory. + // This allows us to also remove the check that all fields are present (since we + // already set the memory to zero and that is a valid bit pattern). + fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T) + where T: ::pin_init::Zeroable + {} + // Ensure that the struct is indeed `Zeroable`. + assert_zeroable(#slot); + // SAFETY: The type implements `Zeroable` by the check above. + unsafe { ::core::ptr::write_bytes(#slot, 0, 1) }; + }, + }; + let this = match this { + None => quote!(), + Some(This { ident, .. }) => quote! { + // Create the `this` so it can be referenced by the user inside of the + // expressions creating the individual fields. + let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) }; + }, + }; + // `mixed_site` ensures that the data is not accessible to the user-controlled code. + let data = Ident::new("__data", Span::mixed_site()); + let init_fields = init_fields( + &fields, + pinned, + !attrs + .iter() + .any(|attr| matches!(attr, InitializerAttribute::DisableInitializedFieldAccess)), + &data, + &slot, + ); + let field_check = make_field_check(&fields, init_kind, &path); + Ok(quote! {{ + // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return + // type and shadow it later when we insert the arbitrary user code. That way there will be + // no possibility of returning without `unsafe`. + struct __InitOk; + + // Get the data about fields from the supplied type. + // SAFETY: TODO + let #data = unsafe { + use ::pin_init::__internal::#has_data_trait; + // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit + // generics (which need to be present with that syntax). 
+ #path::#get_data() + }; + // Ensure that `#data` really is of type `#data` and help with type inference: + let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>( + #data, + move |slot| { + { + // Shadow the structure so it cannot be used to return early. + struct __InitOk; + #zeroable_check + #this + #init_fields + #field_check + } + Ok(__InitOk) + } + ); + let init = move |slot| -> ::core::result::Result<(), #error> { + init(slot).map(|__InitOk| ()) + }; + // SAFETY: TODO + let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) }; + init + }}) +} + +enum InitKind { + Normal, + Zeroing, +} + +fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind { + let Some((dotdot, expr)) = rest else { + return InitKind::Normal; + }; + match &expr { + Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func { + Expr::Path(ExprPath { + attrs, + qself: None, + path: + Path { + leading_colon: None, + segments, + }, + }) if attrs.is_empty() + && segments.len() == 2 + && segments[0].ident == "Zeroable" + && segments[0].arguments.is_none() + && segments[1].ident == "init_zeroed" + && segments[1].arguments.is_none() => + { + return InitKind::Zeroing; + } + _ => {} + }, + _ => {} + } + dcx.error( + dotdot.span().join(expr.span()).unwrap_or(expr.span()), + "expected nothing or `..Zeroable::init_zeroed()`.", + ); + InitKind::Normal +} + +/// Generate the code that initializes the fields of the struct using the initializers in `field`. +fn init_fields( + fields: &Punctuated<InitializerField, Token![,]>, + pinned: bool, + generate_initialized_accessors: bool, + data: &Ident, + slot: &Ident, +) -> TokenStream { + let mut guards = vec![]; + let mut guard_attrs = vec![]; + let mut res = TokenStream::new(); + for InitializerField { attrs, kind } in fields { + let cfgs = { + let mut cfgs = attrs.clone(); + cfgs.retain(|attr| attr.path().is_ident("cfg")); + cfgs + }; + let init = match kind { + InitializerKind::Value { ident, value } => { + let mut value_ident = ident.clone(); + let value_prep = value.as_ref().map(|value| &value.1).map(|value| { + // Setting the span of `value_ident` to `value`'s span improves error messages + // when the type of `value` is wrong. + value_ident.set_span(value.span()); + quote!(let #value_ident = #value;) + }); + // Again span for better diagnostics + let write = quote_spanned!(ident.span()=> ::core::ptr::write); + let accessor = if pinned { + let project_ident = format_ident!("__project_{ident}"); + quote! { + // SAFETY: TODO + unsafe { #data.#project_ident(&mut (*#slot).#ident) } + } + } else { + quote! { + // SAFETY: TODO + unsafe { &mut (*#slot).#ident } + } + }; + let accessor = generate_initialized_accessors.then(|| { + quote! { + #(#cfgs)* + #[allow(unused_variables)] + let #ident = #accessor; + } + }); + quote! { + #(#attrs)* + { + #value_prep + // SAFETY: TODO + unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) }; + } + #accessor + } + } + InitializerKind::Init { ident, value, .. } => { + // Again span for better diagnostics + let init = format_ident!("init", span = value.span()); + let (value_init, accessor) = if pinned { + let project_ident = format_ident!("__project_{ident}"); + ( + quote! { + // SAFETY: + // - `slot` is valid, because we are inside of an initializer closure, we + // return when an error/panic occurs. + // - We also use `#data` to require the correct trait (`Init` or `PinInit`) + // for `#ident`. 
+ unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? }; + }, + quote! { + // SAFETY: TODO + unsafe { #data.#project_ident(&mut (*#slot).#ident) } + }, + ) + } else { + ( + quote! { + // SAFETY: `slot` is valid, because we are inside of an initializer + // closure, we return when an error/panic occurs. + unsafe { + ::pin_init::Init::__init( + #init, + ::core::ptr::addr_of_mut!((*#slot).#ident), + )? + }; + }, + quote! { + // SAFETY: TODO + unsafe { &mut (*#slot).#ident } + }, + ) + }; + let accessor = generate_initialized_accessors.then(|| { + quote! { + #(#cfgs)* + #[allow(unused_variables)] + let #ident = #accessor; + } + }); + quote! { + #(#attrs)* + { + let #init = #value; + #value_init + } + #accessor + } + } + InitializerKind::Code { block: value, .. } => quote! { + #(#attrs)* + #[allow(unused_braces)] + #value + }, + }; + res.extend(init); + if let Some(ident) = kind.ident() { + // `mixed_site` ensures that the guard is not accessible to the user-controlled code. + let guard = format_ident!("__{ident}_guard", span = Span::mixed_site()); + res.extend(quote! { + #(#cfgs)* + // Create the drop guard: + // + // We rely on macro hygiene to make it impossible for users to access this local + // variable. + // SAFETY: We forget the guard later when initialization has succeeded. + let #guard = unsafe { + ::pin_init::__internal::DropGuard::new( + ::core::ptr::addr_of_mut!((*slot).#ident) + ) + }; + }); + guards.push(guard); + guard_attrs.push(cfgs); + } + } + quote! { + #res + // If execution reaches this point, all fields have been initialized. Therefore we can now + // dismiss the guards by forgetting them. + #( + #(#guard_attrs)* + ::core::mem::forget(#guards); + )* + } +} + +/// Generate the check for ensuring that every field has been initialized. +fn make_field_check( + fields: &Punctuated<InitializerField, Token![,]>, + init_kind: InitKind, + path: &Path, +) -> TokenStream { + let field_attrs = fields + .iter() + .filter_map(|f| f.kind.ident().map(|_| &f.attrs)); + let field_name = fields.iter().filter_map(|f| f.kind.ident()); + match init_kind { + InitKind::Normal => quote! { + // We use unreachable code to ensure that all fields have been mentioned exactly once, + // this struct initializer will still be type-checked and complain with a very natural + // error message if a field is forgotten/mentioned more than once. + #[allow(unreachable_code, clippy::diverging_sub_expression)] + // SAFETY: this code is never executed. + let _ = || unsafe { + ::core::ptr::write(slot, #path { + #( + #(#field_attrs)* + #field_name: ::core::panic!(), + )* + }) + }; + }, + InitKind::Zeroing => quote! { + // We use unreachable code to ensure that all fields have been mentioned at most once. + // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will + // be zeroed. This struct initializer will still be type-checked and complain with a + // very natural error message if a field is mentioned more than once, or doesn't exist. + #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)] + // SAFETY: this code is never executed. 
+ let _ = || unsafe { + ::core::ptr::write(slot, #path { + #( + #(#field_attrs)* + #field_name: ::core::panic!(), + )* + ..::core::mem::zeroed() + }) + }; + }, + } +} + +impl Parse for Initializer { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { + let attrs = input.call(Attribute::parse_outer)?; + let this = input.peek(Token![&]).then(|| input.parse()).transpose()?; + let path = input.parse()?; + let content; + let brace_token = braced!(content in input); + let mut fields = Punctuated::new(); + loop { + let lh = content.lookahead1(); + if lh.peek(End) || lh.peek(Token![..]) { + break; + } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) { + fields.push_value(content.parse()?); + let lh = content.lookahead1(); + if lh.peek(End) { + break; + } else if lh.peek(Token![,]) { + fields.push_punct(content.parse()?); + } else { + return Err(lh.error()); + } + } else { + return Err(lh.error()); + } + } + let rest = content + .peek(Token![..]) + .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?))) + .transpose()?; + let error = input + .peek(Token![?]) + .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?))) + .transpose()?; + let attrs = attrs + .into_iter() + .map(|a| { + if a.path().is_ident("default_error") { + a.parse_args::<DefaultErrorAttribute>() + .map(InitializerAttribute::DefaultError) + } else if a.path().is_ident("disable_initialized_field_access") { + a.meta + .require_path_only() + .map(|_| InitializerAttribute::DisableInitializedFieldAccess) + } else { + Err(syn::Error::new_spanned(a, "unknown initializer attribute")) + } + }) + .collect::<Result<Vec<_>, _>>()?; + Ok(Self { + attrs, + this, + path, + brace_token, + fields, + rest, + error, + }) + } +} + +impl Parse for DefaultErrorAttribute { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { + Ok(Self { ty: input.parse()? }) + } +} + +impl Parse for This { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { + Ok(Self { + _and_token: input.parse()?, + ident: input.parse()?, + _in_token: input.parse()?, + }) + } +} + +impl Parse for InitializerField { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { + let attrs = input.call(Attribute::parse_outer)?; + Ok(Self { + attrs, + kind: input.parse()?, + }) + } +} + +impl Parse for InitializerKind { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { + let lh = input.lookahead1(); + if lh.peek(Token![_]) { + Ok(Self::Code { + _underscore_token: input.parse()?, + _colon_token: input.parse()?, + block: input.parse()?, + }) + } else if lh.peek(Ident) { + let ident = input.parse()?; + let lh = input.lookahead1(); + if lh.peek(Token![<-]) { + Ok(Self::Init { + ident, + _left_arrow_token: input.parse()?, + value: input.parse()?, + }) + } else if lh.peek(Token![:]) { + Ok(Self::Value { + ident, + value: Some((input.parse()?, input.parse()?)), + }) + } else if lh.peek(Token![,]) || lh.peek(End) { + Ok(Self::Value { ident, value: None }) + } else { + Err(lh.error()) + } + } else { + Err(lh.error()) + } + } +} diff --git a/rust/pin-init/internal/src/lib.rs b/rust/pin-init/internal/src/lib.rs index 297b0129a5bf..08372c8f65f0 100644 --- a/rust/pin-init/internal/src/lib.rs +++ b/rust/pin-init/internal/src/lib.rs @@ -7,48 +7,54 @@ //! `pin-init` proc macros. #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))] -// Allow `.into()` to convert -// - `proc_macro2::TokenStream` into `proc_macro::TokenStream` in the user-space version. 
-// - `proc_macro::TokenStream` into `proc_macro::TokenStream` in the kernel version. -// Clippy warns on this conversion, but it's required by the user-space version. -// -// Remove once we have `proc_macro2` in the kernel. -#![allow(clippy::useless_conversion)] // Documentation is done in the pin-init crate instead. #![allow(missing_docs)] use proc_macro::TokenStream; +use syn::parse_macro_input; -#[cfg(kernel)] -#[path = "../../../macros/quote.rs"] -#[macro_use] -#[cfg_attr(not(kernel), rustfmt::skip)] -mod quote; -#[cfg(not(kernel))] -#[macro_use] -extern crate quote; +use crate::diagnostics::DiagCtxt; -mod helpers; +mod diagnostics; +mod init; mod pin_data; mod pinned_drop; mod zeroable; #[proc_macro_attribute] -pub fn pin_data(inner: TokenStream, item: TokenStream) -> TokenStream { - pin_data::pin_data(inner.into(), item.into()).into() +pub fn pin_data(args: TokenStream, input: TokenStream) -> TokenStream { + let args = parse_macro_input!(args); + let input = parse_macro_input!(input); + DiagCtxt::with(|dcx| pin_data::pin_data(args, input, dcx)).into() } #[proc_macro_attribute] pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream { - pinned_drop::pinned_drop(args.into(), input.into()).into() + let args = parse_macro_input!(args); + let input = parse_macro_input!(input); + DiagCtxt::with(|dcx| pinned_drop::pinned_drop(args, input, dcx)).into() } #[proc_macro_derive(Zeroable)] pub fn derive_zeroable(input: TokenStream) -> TokenStream { - zeroable::derive(input.into()).into() + let input = parse_macro_input!(input); + DiagCtxt::with(|dcx| zeroable::derive(input, dcx)).into() } #[proc_macro_derive(MaybeZeroable)] pub fn maybe_derive_zeroable(input: TokenStream) -> TokenStream { - zeroable::maybe_derive(input.into()).into() + let input = parse_macro_input!(input); + DiagCtxt::with(|dcx| zeroable::maybe_derive(input, dcx)).into() +} +#[proc_macro] +pub fn init(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input); + DiagCtxt::with(|dcx| init::expand(input, Some("::core::convert::Infallible"), false, dcx)) + .into() +} + +#[proc_macro] +pub fn pin_init(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input); + DiagCtxt::with(|dcx| init::expand(input, Some("::core::convert::Infallible"), true, dcx)).into() } diff --git a/rust/pin-init/internal/src/pin_data.rs b/rust/pin-init/internal/src/pin_data.rs index 87d4a7eb1d35..7d871236b49c 100644 --- a/rust/pin-init/internal/src/pin_data.rs +++ b/rust/pin-init/internal/src/pin_data.rs @@ -1,132 +1,513 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT -#[cfg(not(kernel))] -use proc_macro2 as proc_macro; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{ + parse::{End, Nothing, Parse}, + parse_quote, parse_quote_spanned, + spanned::Spanned, + visit_mut::VisitMut, + Field, Generics, Ident, Item, PathSegment, Type, TypePath, Visibility, WhereClause, +}; -use crate::helpers::{parse_generics, Generics}; -use proc_macro::{Group, Punct, Spacing, TokenStream, TokenTree}; +use crate::diagnostics::{DiagCtxt, ErrorGuaranteed}; -pub(crate) fn pin_data(args: TokenStream, input: TokenStream) -> TokenStream { - // This proc-macro only does some pre-parsing and then delegates the actual parsing to - // `pin_init::__pin_data!`. 
+pub(crate) mod kw { + syn::custom_keyword!(PinnedDrop); +} + +pub(crate) enum Args { + Nothing(Nothing), + #[allow(dead_code)] + PinnedDrop(kw::PinnedDrop), +} + +impl Parse for Args { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { + let lh = input.lookahead1(); + if lh.peek(End) { + input.parse().map(Self::Nothing) + } else if lh.peek(kw::PinnedDrop) { + input.parse().map(Self::PinnedDrop) + } else { + Err(lh.error()) + } + } +} + +pub(crate) fn pin_data( + args: Args, + input: Item, + dcx: &mut DiagCtxt, +) -> Result<TokenStream, ErrorGuaranteed> { + let mut struct_ = match input { + Item::Struct(struct_) => struct_, + Item::Enum(enum_) => { + return Err(dcx.error( + enum_.enum_token, + "`#[pin_data]` only supports structs for now", + )); + } + Item::Union(union) => { + return Err(dcx.error( + union.union_token, + "`#[pin_data]` only supports structs for now", + )); + } + rest => { + return Err(dcx.error( + rest, + "`#[pin_data]` can only be applied to struct, enum and union definitions", + )); + } + }; + + // The generics might contain the `Self` type. Since this macro will define a new type with the + // same generics and bounds, this poses a problem: `Self` will refer to the new type as opposed + // to this struct definition. Therefore we have to replace `Self` with the concrete name. + let mut replacer = { + let name = &struct_.ident; + let (_, ty_generics, _) = struct_.generics.split_for_impl(); + SelfReplacer(parse_quote!(#name #ty_generics)) + }; + replacer.visit_generics_mut(&mut struct_.generics); + replacer.visit_fields_mut(&mut struct_.fields); + + let fields: Vec<(bool, &Field)> = struct_ + .fields + .iter_mut() + .map(|field| { + let len = field.attrs.len(); + field.attrs.retain(|a| !a.path().is_ident("pin")); + (len != field.attrs.len(), &*field) + }) + .collect(); + + for (pinned, field) in &fields { + if !pinned && is_phantom_pinned(&field.ty) { + dcx.error( + field, + format!( + "The field `{}` of type `PhantomPinned` only has an effect \ + if it has the `#[pin]` attribute", + field.ident.as_ref().unwrap(), + ), + ); + } + } + + let unpin_impl = generate_unpin_impl(&struct_.ident, &struct_.generics, &fields); + let drop_impl = generate_drop_impl(&struct_.ident, &struct_.generics, args); + let projections = + generate_projections(&struct_.vis, &struct_.ident, &struct_.generics, &fields); + let the_pin_data = + generate_the_pin_data(&struct_.vis, &struct_.ident, &struct_.generics, &fields); + + Ok(quote! { + #struct_ + #projections + // We put the rest into this const item, because it then will not be accessible to anything + // outside. + const _: () = { + #the_pin_data + #unpin_impl + #drop_impl + }; + }) +} + +fn is_phantom_pinned(ty: &Type) -> bool { + match ty { + Type::Path(TypePath { qself: None, path }) => { + // Cannot possibly refer to `PhantomPinned` (except alias, but that's on the user). + if path.segments.len() > 3 { + return false; + } + // If there is a `::`, then the path needs to be `::core::marker::PhantomPinned` or + // `::std::marker::PhantomPinned`. 
+ if path.leading_colon.is_some() && path.segments.len() != 3 { + return false; + } + let expected: Vec<&[&str]> = vec![&["PhantomPinned"], &["marker"], &["core", "std"]]; + for (actual, expected) in path.segments.iter().rev().zip(expected) { + if !actual.arguments.is_empty() || expected.iter().all(|e| actual.ident != e) { + return false; + } + } + true + } + _ => false, + } +} +fn generate_unpin_impl( + ident: &Ident, + generics: &Generics, + fields: &[(bool, &Field)], +) -> TokenStream { + let (_, ty_generics, _) = generics.split_for_impl(); + let mut generics_with_pin_lt = generics.clone(); + generics_with_pin_lt.params.insert(0, parse_quote!('__pin)); + generics_with_pin_lt.make_where_clause(); let ( - Generics { - impl_generics, - decl_generics, - ty_generics, - }, - rest, - ) = parse_generics(input); - // The struct definition might contain the `Self` type. Since `__pin_data!` will define a new - // type with the same generics and bounds, this poses a problem, since `Self` will refer to the - // new type as opposed to this struct definition. Therefore we have to replace `Self` with the - // concrete name. - - // Errors that occur when replacing `Self` with `struct_name`. - let mut errs = TokenStream::new(); - // The name of the struct with ty_generics. - let struct_name = rest - .iter() - .skip_while(|tt| !matches!(tt, TokenTree::Ident(i) if i.to_string() == "struct")) - .nth(1) - .and_then(|tt| match tt { - TokenTree::Ident(_) => { - let tt = tt.clone(); - let mut res = vec![tt]; - if !ty_generics.is_empty() { - // We add this, so it is maximally compatible with e.g. `Self::CONST` which - // will be replaced by `StructName::<$generics>::CONST`. - res.push(TokenTree::Punct(Punct::new(':', Spacing::Joint))); - res.push(TokenTree::Punct(Punct::new(':', Spacing::Alone))); - res.push(TokenTree::Punct(Punct::new('<', Spacing::Alone))); - res.extend(ty_generics.iter().cloned()); - res.push(TokenTree::Punct(Punct::new('>', Spacing::Alone))); + impl_generics_with_pin_lt, + ty_generics_with_pin_lt, + Some(WhereClause { + where_token, + predicates, + }), + ) = generics_with_pin_lt.split_for_impl() + else { + unreachable!() + }; + let pinned_fields = fields.iter().filter_map(|(b, f)| b.then_some(f)); + quote! { + // This struct will be used for the unpin analysis. It is needed, because only structurally + // pinned fields are relevant whether the struct should implement `Unpin`. + #[allow(dead_code)] // The fields below are never used. + struct __Unpin #generics_with_pin_lt + #where_token + #predicates + { + __phantom_pin: ::core::marker::PhantomData<fn(&'__pin ()) -> &'__pin ()>, + __phantom: ::core::marker::PhantomData< + fn(#ident #ty_generics) -> #ident #ty_generics + >, + #(#pinned_fields),* + } + + #[doc(hidden)] + impl #impl_generics_with_pin_lt ::core::marker::Unpin for #ident #ty_generics + #where_token + __Unpin #ty_generics_with_pin_lt: ::core::marker::Unpin, + #predicates + {} + } +} + +fn generate_drop_impl(ident: &Ident, generics: &Generics, args: Args) -> TokenStream { + let (impl_generics, ty_generics, whr) = generics.split_for_impl(); + let has_pinned_drop = matches!(args, Args::PinnedDrop(_)); + // We need to disallow normal `Drop` implementation, the exact behavior depends on whether + // `PinnedDrop` was specified in `args`. + if has_pinned_drop { + // When `PinnedDrop` was specified we just implement `Drop` and delegate. + quote! 
{ + impl #impl_generics ::core::ops::Drop for #ident #ty_generics + #whr + { + fn drop(&mut self) { + // SAFETY: Since this is a destructor, `self` will not move after this function + // terminates, since it is inaccessible. + let pinned = unsafe { ::core::pin::Pin::new_unchecked(self) }; + // SAFETY: Since this is a drop function, we can create this token to call the + // pinned destructor of this type. + let token = unsafe { ::pin_init::__internal::OnlyCallFromDrop::new() }; + ::pin_init::PinnedDrop::drop(pinned, token); } - Some(res) } - _ => None, - }) - .unwrap_or_else(|| { - // If we did not find the name of the struct then we will use `Self` as the replacement - // and add a compile error to ensure it does not compile. - errs.extend( - "::core::compile_error!(\"Could not locate type name.\");" - .parse::<TokenStream>() - .unwrap(), - ); - "Self".parse::<TokenStream>().unwrap().into_iter().collect() - }); - let impl_generics = impl_generics - .into_iter() - .flat_map(|tt| replace_self_and_deny_type_defs(&struct_name, tt, &mut errs)) - .collect::<Vec<_>>(); - let mut rest = rest - .into_iter() - .flat_map(|tt| { - // We ignore top level `struct` tokens, since they would emit a compile error. - if matches!(&tt, TokenTree::Ident(i) if i.to_string() == "struct") { - vec![tt] + } + } else { + // When no `PinnedDrop` was specified, then we have to prevent implementing drop. + quote! { + // We prevent this by creating a trait that will be implemented for all types implementing + // `Drop`. Additionally we will implement this trait for the struct leading to a conflict, + // if it also implements `Drop` + trait MustNotImplDrop {} + #[expect(drop_bounds)] + impl<T: ::core::ops::Drop + ?::core::marker::Sized> MustNotImplDrop for T {} + impl #impl_generics MustNotImplDrop for #ident #ty_generics + #whr + {} + // We also take care to prevent users from writing a useless `PinnedDrop` implementation. + // They might implement `PinnedDrop` correctly for the struct, but forget to give + // `PinnedDrop` as the parameter to `#[pin_data]`. + #[expect(non_camel_case_types)] + trait UselessPinnedDropImpl_you_need_to_specify_PinnedDrop {} + impl<T: ::pin_init::PinnedDrop + ?::core::marker::Sized> + UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for T {} + impl #impl_generics + UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for #ident #ty_generics + #whr + {} + } + } +} + +fn generate_projections( + vis: &Visibility, + ident: &Ident, + generics: &Generics, + fields: &[(bool, &Field)], +) -> TokenStream { + let (impl_generics, ty_generics, _) = generics.split_for_impl(); + let mut generics_with_pin_lt = generics.clone(); + generics_with_pin_lt.params.insert(0, parse_quote!('__pin)); + let (_, ty_generics_with_pin_lt, whr) = generics_with_pin_lt.split_for_impl(); + let projection = format_ident!("{ident}Projection"); + let this = format_ident!("this"); + + let (fields_decl, fields_proj) = collect_tuple(fields.iter().map( + |( + pinned, + Field { + vis, + ident, + ty, + attrs, + .. + }, + )| { + let mut attrs = attrs.clone(); + attrs.retain(|a| !a.path().is_ident("pin")); + let mut no_doc_attrs = attrs.clone(); + no_doc_attrs.retain(|a| !a.path().is_ident("doc")); + let ident = ident + .as_ref() + .expect("only structs with named fields are supported"); + if *pinned { + ( + quote!( + #(#attrs)* + #vis #ident: ::core::pin::Pin<&'__pin mut #ty>, + ), + quote!( + #(#no_doc_attrs)* + // SAFETY: this field is structurally pinned. 
+ #ident: unsafe { ::core::pin::Pin::new_unchecked(&mut #this.#ident) }, + ), + ) } else { - replace_self_and_deny_type_defs(&struct_name, tt, &mut errs) + ( + quote!( + #(#attrs)* + #vis #ident: &'__pin mut #ty, + ), + quote!( + #(#no_doc_attrs)* + #ident: &mut #this.#ident, + ), + ) } - }) - .collect::<Vec<_>>(); - // This should be the body of the struct `{...}`. - let last = rest.pop(); - let mut quoted = quote!(::pin_init::__pin_data! { - parse_input: - @args(#args), - @sig(#(#rest)*), - @impl_generics(#(#impl_generics)*), - @ty_generics(#(#ty_generics)*), - @decl_generics(#(#decl_generics)*), - @body(#last), - }); - quoted.extend(errs); - quoted + }, + )); + let structurally_pinned_fields_docs = fields + .iter() + .filter_map(|(pinned, field)| pinned.then_some(field)) + .map(|Field { ident, .. }| format!(" - `{}`", ident.as_ref().unwrap())); + let not_structurally_pinned_fields_docs = fields + .iter() + .filter_map(|(pinned, field)| (!pinned).then_some(field)) + .map(|Field { ident, .. }| format!(" - `{}`", ident.as_ref().unwrap())); + let docs = format!(" Pin-projections of [`{ident}`]"); + quote! { + #[doc = #docs] + #[allow(dead_code)] + #[doc(hidden)] + #vis struct #projection #generics_with_pin_lt { + #(#fields_decl)* + ___pin_phantom_data: ::core::marker::PhantomData<&'__pin mut ()>, + } + + impl #impl_generics #ident #ty_generics + #whr + { + /// Pin-projects all fields of `Self`. + /// + /// These fields are structurally pinned: + #(#[doc = #structurally_pinned_fields_docs])* + /// + /// These fields are **not** structurally pinned: + #(#[doc = #not_structurally_pinned_fields_docs])* + #[inline] + #vis fn project<'__pin>( + self: ::core::pin::Pin<&'__pin mut Self>, + ) -> #projection #ty_generics_with_pin_lt { + // SAFETY: we only give access to `&mut` for fields not structurally pinned. + let #this = unsafe { ::core::pin::Pin::get_unchecked_mut(self) }; + #projection { + #(#fields_proj)* + ___pin_phantom_data: ::core::marker::PhantomData, + } + } + } + } } -/// Replaces `Self` with `struct_name` and errors on `enum`, `trait`, `struct` `union` and `impl` -/// keywords. -/// -/// The error is appended to `errs` to allow normal parsing to continue. -fn replace_self_and_deny_type_defs( - struct_name: &Vec<TokenTree>, - tt: TokenTree, - errs: &mut TokenStream, -) -> Vec<TokenTree> { - match tt { - TokenTree::Ident(ref i) - if i.to_string() == "enum" - || i.to_string() == "trait" - || i.to_string() == "struct" - || i.to_string() == "union" - || i.to_string() == "impl" => +fn generate_the_pin_data( + vis: &Visibility, + ident: &Ident, + generics: &Generics, + fields: &[(bool, &Field)], +) -> TokenStream { + let (impl_generics, ty_generics, whr) = generics.split_for_impl(); + + // For every field, we create an initializing projection function according to its projection + // type. If a field is structurally pinned, then it must be initialized via `PinInit`, if it is + // not structurally pinned, then it can be initialized via `Init`. + // + // The functions are `unsafe` to prevent accidentally calling them. + fn handle_field( + Field { + vis, + ident, + ty, + attrs, + .. 
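The projection struct generated here follows the `{ident}Projection` naming and hands out `Pin<&mut T>` for `#[pin]` fields and plain `&mut T` for the rest. A minimal usage sketch (standalone `pin-init` crate; `Driver` is hypothetical):

```rust
use core::marker::PhantomPinned;
use core::pin::Pin;
use pin_init::pin_data;

#[pin_data]
struct Driver {
    #[pin]
    _anchor: PhantomPinned,
    counter: u64,
}

fn tick(driver: Pin<&mut Driver>) {
    // `project()` is generated by `#[pin_data]` and returns a `DriverProjection`.
    let proj = driver.project();
    // Structurally pinned field: handed back re-pinned.
    let _anchor: Pin<&mut PhantomPinned> = proj._anchor;
    // Not structurally pinned: a plain mutable reference.
    *proj.counter += 1;
}
```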
+ }: &Field, + struct_ident: &Ident, + pinned: bool, + ) -> TokenStream { + let mut attrs = attrs.clone(); + attrs.retain(|a| !a.path().is_ident("pin")); + let ident = ident + .as_ref() + .expect("only structs with named fields are supported"); + let project_ident = format_ident!("__project_{ident}"); + let (init_ty, init_fn, project_ty, project_body, pin_safety) = if pinned { + ( + quote!(PinInit), + quote!(__pinned_init), + quote!(::core::pin::Pin<&'__slot mut #ty>), + // SAFETY: this field is structurally pinned. + quote!(unsafe { ::core::pin::Pin::new_unchecked(slot) }), + quote!( + /// - `slot` will not move until it is dropped, i.e. it will be pinned. + ), + ) + } else { + ( + quote!(Init), + quote!(__init), + quote!(&'__slot mut #ty), + quote!(slot), + quote!(), + ) + }; + let slot_safety = format!( + " `slot` points at the field `{ident}` inside of `{struct_ident}`, which is pinned.", + ); + quote! { + /// # Safety + /// + /// - `slot` is a valid pointer to uninitialized memory. + /// - the caller does not touch `slot` when `Err` is returned, they are only permitted + /// to deallocate. + #pin_safety + #(#attrs)* + #vis unsafe fn #ident<E>( + self, + slot: *mut #ty, + init: impl ::pin_init::#init_ty<#ty, E>, + ) -> ::core::result::Result<(), E> { + // SAFETY: this function has the same safety requirements as the __init function + // called below. + unsafe { ::pin_init::#init_ty::#init_fn(init, slot) } + } + + /// # Safety + /// + #[doc = #slot_safety] + #(#attrs)* + #vis unsafe fn #project_ident<'__slot>( + self, + slot: &'__slot mut #ty, + ) -> #project_ty { + #project_body + } + } + } + + let field_accessors = fields + .iter() + .map(|(pinned, field)| handle_field(field, ident, *pinned)) + .collect::<TokenStream>(); + quote! { + // We declare this struct which will host all of the projection function for our type. It + // will be invariant over all generic parameters which are inherited from the struct. + #[doc(hidden)] + #vis struct __ThePinData #generics + #whr { - errs.extend( - format!( - "::core::compile_error!(\"Cannot use `{i}` inside of struct definition with \ - `#[pin_data]`.\");" - ) - .parse::<TokenStream>() - .unwrap() - .into_iter() - .map(|mut tok| { - tok.set_span(tt.span()); - tok - }), - ); - vec![tt] - } - TokenTree::Ident(i) if i.to_string() == "Self" => struct_name.clone(), - TokenTree::Literal(_) | TokenTree::Punct(_) | TokenTree::Ident(_) => vec![tt], - TokenTree::Group(g) => vec![TokenTree::Group(Group::new( - g.delimiter(), - g.stream() - .into_iter() - .flat_map(|tt| replace_self_and_deny_type_defs(struct_name, tt, errs)) - .collect(), - ))], + __phantom: ::core::marker::PhantomData< + fn(#ident #ty_generics) -> #ident #ty_generics + >, + } + + impl #impl_generics ::core::clone::Clone for __ThePinData #ty_generics + #whr + { + fn clone(&self) -> Self { *self } + } + + impl #impl_generics ::core::marker::Copy for __ThePinData #ty_generics + #whr + {} + + #[allow(dead_code)] // Some functions might never be used and private. + #[expect(clippy::missing_safety_doc)] + impl #impl_generics __ThePinData #ty_generics + #whr + { + #field_accessors + } + + // SAFETY: We have added the correct projection functions above to `__ThePinData` and + // we also use the least restrictive generics possible. 
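Seen from the initializer macros, the per-field accessors on `__ThePinData` are what the `<-` arrow routes through: a `#[pin]` field's accessor accepts a `PinInit`, a non-pinned field's accessor accepts an `Init`, and plain `field: value` syntax writes the value directly. A sketch assuming the standalone `pin-init` crate and made-up types:

```rust
use pin_init::{pin_data, pin_init, PinInit};

#[pin_data]
struct Inner {
    ready: bool,
}

impl Inner {
    fn new() -> impl PinInit<Self> {
        pin_init!(Self { ready: false })
    }
}

#[pin_data]
struct Outer {
    #[pin]
    inner: Inner,  // structurally pinned: its accessor takes a `PinInit`
    count: usize,  // not pinned: its accessor takes an `Init`
}

impl Outer {
    fn new() -> impl PinInit<Self> {
        pin_init!(Self {
            inner <- Inner::new(), // `<-` goes through the pinned accessor
            count: 0,              // a plain value is written directly
        })
    }
}
```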
+ unsafe impl #impl_generics ::pin_init::__internal::HasPinData for #ident #ty_generics + #whr + { + type PinData = __ThePinData #ty_generics; + + unsafe fn __pin_data() -> Self::PinData { + __ThePinData { __phantom: ::core::marker::PhantomData } + } + } + + // SAFETY: TODO + unsafe impl #impl_generics ::pin_init::__internal::PinData for __ThePinData #ty_generics + #whr + { + type Datee = #ident #ty_generics; + } + } +} + +struct SelfReplacer(PathSegment); + +impl VisitMut for SelfReplacer { + fn visit_path_mut(&mut self, i: &mut syn::Path) { + if i.is_ident("Self") { + let span = i.span(); + let seg = &self.0; + *i = parse_quote_spanned!(span=> #seg); + } else { + syn::visit_mut::visit_path_mut(self, i); + } + } + + fn visit_path_segment_mut(&mut self, seg: &mut PathSegment) { + if seg.ident == "Self" { + let span = seg.span(); + let this = &self.0; + *seg = parse_quote_spanned!(span=> #this); + } else { + syn::visit_mut::visit_path_segment_mut(self, seg); + } + } + + fn visit_item_mut(&mut self, _: &mut Item) { + // Do not descend into items, since items reset/change what `Self` refers to. + } +} + +// replace with `.collect()` once MSRV is above 1.79 +fn collect_tuple<A, B>(iter: impl Iterator<Item = (A, B)>) -> (Vec<A>, Vec<B>) { + let mut res_a = vec![]; + let mut res_b = vec![]; + for (a, b) in iter { + res_a.push(a); + res_b.push(b); } + (res_a, res_b) } diff --git a/rust/pin-init/internal/src/pinned_drop.rs b/rust/pin-init/internal/src/pinned_drop.rs index c4ca7a70b726..a20ac314ca82 100644 --- a/rust/pin-init/internal/src/pinned_drop.rs +++ b/rust/pin-init/internal/src/pinned_drop.rs @@ -1,51 +1,61 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT -#[cfg(not(kernel))] -use proc_macro2 as proc_macro; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{parse::Nothing, parse_quote, spanned::Spanned, ImplItem, ItemImpl, Token}; -use proc_macro::{TokenStream, TokenTree}; +use crate::diagnostics::{DiagCtxt, ErrorGuaranteed}; -pub(crate) fn pinned_drop(_args: TokenStream, input: TokenStream) -> TokenStream { - let mut toks = input.into_iter().collect::<Vec<_>>(); - assert!(!toks.is_empty()); - // Ensure that we have an `impl` item. - assert!(matches!(&toks[0], TokenTree::Ident(i) if i.to_string() == "impl")); - // Ensure that we are implementing `PinnedDrop`. - let mut nesting: usize = 0; - let mut pinned_drop_idx = None; - for (i, tt) in toks.iter().enumerate() { - match tt { - TokenTree::Punct(p) if p.as_char() == '<' => { - nesting += 1; +pub(crate) fn pinned_drop( + _args: Nothing, + mut input: ItemImpl, + dcx: &mut DiagCtxt, +) -> Result<TokenStream, ErrorGuaranteed> { + if let Some(unsafety) = input.unsafety { + dcx.error(unsafety, "implementing `PinnedDrop` is safe"); + } + input.unsafety = Some(Token); + match &mut input.trait_ { + Some((not, path, _for)) => { + if let Some(not) = not { + dcx.error(not, "cannot implement `!PinnedDrop`"); } - TokenTree::Punct(p) if p.as_char() == '>' => { - nesting = nesting.checked_sub(1).unwrap(); - continue; + for (seg, expected) in path + .segments + .iter() + .rev() + .zip(["PinnedDrop", "pin_init", ""]) + { + if expected.is_empty() || seg.ident != expected { + dcx.error(seg, "bad import path for `PinnedDrop`"); + } + if !seg.arguments.is_none() { + dcx.error(&seg.arguments, "unexpected arguments for `PinnedDrop` path"); + } } - _ => {} + *path = parse_quote!(::pin_init::PinnedDrop); } - if i >= 1 && nesting == 0 { - // Found the end of the generics, this should be `PinnedDrop`. 
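`SelfReplacer` exists because field types are copied into generated helper items (such as `__Unpin`, the `__ThePinData` accessors, and the projection struct), where a bare `Self` would no longer name the user's struct. A hypothetical case that relies on the rewrite (sketch; `Node` is made up):

```rust
use pin_init::pin_data;

#[pin_data]
struct Node<T> {
    value: T,
    // `Self` here is rewritten to `Node<T>` before the type is copied into the
    // generated helpers, where `Self` would otherwise resolve to the helper type.
    next: Option<Box<Self>>,
}
```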
- assert!( - matches!(tt, TokenTree::Ident(i) if i.to_string() == "PinnedDrop"), - "expected 'PinnedDrop', found: '{tt:?}'" + None => { + let span = input + .impl_token + .span + .join(input.self_ty.span()) + .unwrap_or(input.impl_token.span); + dcx.error( + span, + "expected `impl ... PinnedDrop for ...`, got inherent impl", ); - pinned_drop_idx = Some(i); - break; } } - let idx = pinned_drop_idx - .unwrap_or_else(|| panic!("Expected an `impl` block implementing `PinnedDrop`.")); - // Fully qualify the `PinnedDrop`, as to avoid any tampering. - toks.splice(idx..idx, quote!(::pin_init::)); - // Take the `{}` body and call the declarative macro. - if let Some(TokenTree::Group(last)) = toks.pop() { - let last = last.stream(); - quote!(::pin_init::__pinned_drop! { - @impl_sig(#(#toks)*), - @impl_body(#last), - }) - } else { - TokenStream::from_iter(toks) + for item in &mut input.items { + if let ImplItem::Fn(fn_item) = item { + if fn_item.sig.ident == "drop" { + fn_item + .sig + .inputs + .push(parse_quote!(_: ::pin_init::__internal::OnlyCallFromDrop)); + } + } } + Ok(quote!(#input)) } diff --git a/rust/pin-init/internal/src/zeroable.rs b/rust/pin-init/internal/src/zeroable.rs index e0ed3998445c..05683319b0f7 100644 --- a/rust/pin-init/internal/src/zeroable.rs +++ b/rust/pin-init/internal/src/zeroable.rs @@ -1,101 +1,78 @@ // SPDX-License-Identifier: GPL-2.0 -#[cfg(not(kernel))] -use proc_macro2 as proc_macro; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{parse_quote, Data, DeriveInput, Field, Fields}; -use crate::helpers::{parse_generics, Generics}; -use proc_macro::{TokenStream, TokenTree}; +use crate::{diagnostics::ErrorGuaranteed, DiagCtxt}; -pub(crate) fn parse_zeroable_derive_input( - input: TokenStream, -) -> ( - Vec<TokenTree>, - Vec<TokenTree>, - Vec<TokenTree>, - Option<TokenTree>, -) { - let ( - Generics { - impl_generics, - decl_generics: _, - ty_generics, - }, - mut rest, - ) = parse_generics(input); - // This should be the body of the struct `{...}`. - let last = rest.pop(); - // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`. - let mut new_impl_generics = Vec::with_capacity(impl_generics.len()); - // Are we inside of a generic where we want to add `Zeroable`? - let mut in_generic = !impl_generics.is_empty(); - // Have we already inserted `Zeroable`? - let mut inserted = false; - // Level of `<>` nestings. - let mut nested = 0; - for tt in impl_generics { - match &tt { - // If we find a `,`, then we have finished a generic/constant/lifetime parameter. - TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => { - if in_generic && !inserted { - new_impl_generics.extend(quote! { : ::pin_init::Zeroable }); - } - in_generic = true; - inserted = false; - new_impl_generics.push(tt); - } - // If we find `'`, then we are entering a lifetime. - TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => { - in_generic = false; - new_impl_generics.push(tt); - } - TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => { - new_impl_generics.push(tt); - if in_generic { - new_impl_generics.extend(quote! 
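For reference, the shapes the new diagnostics reject, with the error strings from the code above (illustrative only; these snippets intentionally do not compile, and `Logger` is the hypothetical type from earlier):

```rust
// #[pinned_drop]
// unsafe impl PinnedDrop for Logger { /* ... */ } // "implementing `PinnedDrop` is safe"
//
// #[pinned_drop]
// impl !PinnedDrop for Logger { /* ... */ }       // "cannot implement `!PinnedDrop`"
//
// #[pinned_drop]
// impl Logger { /* ... */ }                       // "expected `impl ... PinnedDrop for ...`, got inherent impl"
```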
{ ::pin_init::Zeroable + }); - inserted = true; - } - } - TokenTree::Punct(p) if p.as_char() == '<' => { - nested += 1; - new_impl_generics.push(tt); - } - TokenTree::Punct(p) if p.as_char() == '>' => { - assert!(nested > 0); - nested -= 1; - new_impl_generics.push(tt); - } - _ => new_impl_generics.push(tt), +pub(crate) fn derive( + input: DeriveInput, + dcx: &mut DiagCtxt, +) -> Result<TokenStream, ErrorGuaranteed> { + let fields = match input.data { + Data::Struct(data_struct) => data_struct.fields, + Data::Union(data_union) => Fields::Named(data_union.fields), + Data::Enum(data_enum) => { + return Err(dcx.error(data_enum.enum_token, "cannot derive `Zeroable` for an enum")); } + }; + let name = input.ident; + let mut generics = input.generics; + for param in generics.type_params_mut() { + param.bounds.insert(0, parse_quote!(::pin_init::Zeroable)); } - assert_eq!(nested, 0); - if in_generic && !inserted { - new_impl_generics.extend(quote! { : ::pin_init::Zeroable }); - } - (rest, new_impl_generics, ty_generics, last) + let (impl_gen, ty_gen, whr) = generics.split_for_impl(); + let field_type = fields.iter().map(|field| &field.ty); + Ok(quote! { + // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero. + #[automatically_derived] + unsafe impl #impl_gen ::pin_init::Zeroable for #name #ty_gen + #whr + {} + const _: () = { + fn assert_zeroable<T: ?::core::marker::Sized + ::pin_init::Zeroable>() {} + fn ensure_zeroable #impl_gen () + #whr + { + #( + assert_zeroable::<#field_type>(); + )* + } + }; + }) } -pub(crate) fn derive(input: TokenStream) -> TokenStream { - let (rest, new_impl_generics, ty_generics, last) = parse_zeroable_derive_input(input); - quote! { - ::pin_init::__derive_zeroable!( - parse_input: - @sig(#(#rest)*), - @impl_generics(#(#new_impl_generics)*), - @ty_generics(#(#ty_generics)*), - @body(#last), - ); +pub(crate) fn maybe_derive( + input: DeriveInput, + dcx: &mut DiagCtxt, +) -> Result<TokenStream, ErrorGuaranteed> { + let fields = match input.data { + Data::Struct(data_struct) => data_struct.fields, + Data::Union(data_union) => Fields::Named(data_union.fields), + Data::Enum(data_enum) => { + return Err(dcx.error(data_enum.enum_token, "cannot derive `Zeroable` for an enum")); + } + }; + let name = input.ident; + let mut generics = input.generics; + for param in generics.type_params_mut() { + param.bounds.insert(0, parse_quote!(::pin_init::Zeroable)); } -} - -pub(crate) fn maybe_derive(input: TokenStream) -> TokenStream { - let (rest, new_impl_generics, ty_generics, last) = parse_zeroable_derive_input(input); - quote! { - ::pin_init::__maybe_derive_zeroable!( - parse_input: - @sig(#(#rest)*), - @impl_generics(#(#new_impl_generics)*), - @ty_generics(#(#ty_generics)*), - @body(#last), - ); + for Field { ty, .. } in fields { + generics + .make_where_clause() + .predicates + // the `for<'__dummy>` HRTB makes this not error without the `trivial_bounds` + // feature <https://github.com/rust-lang/rust/issues/48214#issuecomment-2557829956>. + .push(parse_quote!(#ty: for<'__dummy> ::pin_init::Zeroable)); } + let (impl_gen, ty_gen, whr) = generics.split_for_impl(); + Ok(quote! { + // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero. + #[automatically_derived] + unsafe impl #impl_gen ::pin_init::Zeroable for #name #ty_gen + #whr + {} + }) } diff --git a/rust/pin-init/src/lib.rs b/rust/pin-init/src/lib.rs index 8dc9dd5ac6fd..49945fc07f25 100644 --- a/rust/pin-init/src/lib.rs +++ b/rust/pin-init/src/lib.rs @@ -146,7 +146,7 @@ //! //! 
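Seen from the outside, the syn-based derive behaves roughly as follows: every type parameter gains a `::pin_init::Zeroable` bound and every field type is checked by the hidden `assert_zeroable` helper. A sketch assuming the standalone `pin-init` crate (`Config` is hypothetical):

```rust
use pin_init::{init_zeroed, Init, Zeroable};

#[derive(Zeroable)]
struct Config<T> {
    flags: u32,   // checked: `u32: Zeroable`
    payload: T,   // the derive adds `T: Zeroable` to the generics
}

// Since `Config<T>: Zeroable` whenever `T: Zeroable`, it can be zero-initialized:
fn zeroed_config<T: Zeroable>() -> impl Init<Config<T>> {
    init_zeroed()
}
```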
impl DriverData { //! fn new() -> impl PinInit<Self, Error> { -//! try_pin_init!(Self { +//! pin_init!(Self { //! status <- CMutex::new(0), //! buffer: Box::init(pin_init::init_zeroed())?, //! }? Error) @@ -290,10 +290,13 @@ use core::{ ptr::{self, NonNull}, }; +// This is used by doc-tests -- the proc-macros expand to `::pin_init::...` and without this the +// doc-tests wouldn't have an extern crate named `pin_init`. +#[allow(unused_extern_crates)] +extern crate self as pin_init; + #[doc(hidden)] pub mod __internal; -#[doc(hidden)] -pub mod macros; #[cfg(any(feature = "std", feature = "alloc"))] mod alloc; @@ -528,7 +531,7 @@ macro_rules! stack_pin_init { /// x: u32, /// } /// -/// stack_try_pin_init!(let foo: Foo = try_pin_init!(Foo { +/// stack_try_pin_init!(let foo: Foo = pin_init!(Foo { /// a <- CMutex::new(42), /// b: Box::try_new(Bar { /// x: 64, @@ -555,7 +558,7 @@ macro_rules! stack_pin_init { /// x: u32, /// } /// -/// stack_try_pin_init!(let foo: Foo =? try_pin_init!(Foo { +/// stack_try_pin_init!(let foo: Foo =? pin_init!(Foo { /// a <- CMutex::new(42), /// b: Box::try_new(Bar { /// x: 64, @@ -584,10 +587,10 @@ macro_rules! stack_try_pin_init { }; } -/// Construct an in-place, pinned initializer for `struct`s. +/// Construct an in-place, fallible pinned initializer for `struct`s. /// -/// This macro defaults the error to [`Infallible`]. If you need a different error, then use -/// [`try_pin_init!`]. +/// The error type defaults to [`Infallible`]; if you need a different one, write `? Error` at the +/// end, after the struct initializer. /// /// The syntax is almost identical to that of a normal `struct` initializer: /// @@ -776,81 +779,12 @@ macro_rules! stack_try_pin_init { /// ``` /// /// [`NonNull<Self>`]: core::ptr::NonNull -// For a detailed example of how this macro works, see the module documentation of the hidden -// module `macros` inside of `macros.rs`. -#[macro_export] -macro_rules! pin_init { - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }) => { - $crate::try_pin_init!($(&$this in)? $t $(::<$($generics),*>)? { - $($fields)* - }? ::core::convert::Infallible) - }; -} - -/// Construct an in-place, fallible pinned initializer for `struct`s. -/// -/// If the initialization can complete without error (or [`Infallible`]), then use [`pin_init!`]. -/// -/// You can use the `?` operator or use `return Err(err)` inside the initializer to stop -/// initialization and return the error. -/// -/// IMPORTANT: if you have `unsafe` code inside of the initializer you have to ensure that when -/// initialization fails, the memory can be safely deallocated without any further modifications. -/// -/// The syntax is identical to [`pin_init!`] with the following exception: you must append `? $type` -/// after the `struct` initializer to specify the error type you want to use. -/// -/// # Examples -/// -/// ```rust -/// # #![feature(allocator_api)] -/// # #[path = "../examples/error.rs"] mod error; use error::Error; -/// use pin_init::{pin_data, try_pin_init, PinInit, InPlaceInit, init_zeroed}; -/// -/// #[pin_data] -/// struct BigBuf { -/// big: Box<[u8; 1024 * 1024 * 1024]>, -/// small: [u8; 1024 * 1024], -/// ptr: *mut u8, -/// } -/// -/// impl BigBuf { -/// fn new() -> impl PinInit<Self, Error> { -/// try_pin_init!(Self { -/// big: Box::init(init_zeroed())?, -/// small: [0; 1024 * 1024], -/// ptr: core::ptr::null_mut(), -/// }? 
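With `try_pin_init!` folded into `pin_init!`, the two spellings now differ only in the optional `? Error` suffix after the struct initializer. A sketch (standalone `pin-init` crate; `Buf` and `Error` are hypothetical):

```rust
use pin_init::{pin_data, pin_init, PinInit};

struct Error;

#[pin_data]
struct Buf {
    len: usize,
}

fn check(len: usize) -> Result<usize, Error> {
    if len <= 4096 { Ok(len) } else { Err(Error) }
}

impl Buf {
    // Infallible: no suffix needed, the error defaults to `Infallible`.
    fn new() -> impl PinInit<Self> {
        pin_init!(Self { len: 0 })
    }

    // Fallible: append `? Error` where `try_pin_init!` used to be required.
    fn with_len(len: usize) -> impl PinInit<Self, Error> {
        pin_init!(Self {
            len: check(len)?,
        }? Error)
    }
}
```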
Error) -/// } -/// } -/// # let _ = Box::pin_init(BigBuf::new()); -/// ``` -// For a detailed example of how this macro works, see the module documentation of the hidden -// module `macros` inside of `macros.rs`. -#[macro_export] -macro_rules! try_pin_init { - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }? $err:ty) => { - $crate::__init_internal!( - @this($($this)?), - @typ($t $(::<$($generics),*>)? ), - @fields($($fields)*), - @error($err), - @data(PinData, use_data), - @has_data(HasPinData, __pin_data), - @construct_closure(pin_init_from_closure), - @munch_fields($($fields)*), - ) - } -} +pub use pin_init_internal::pin_init; -/// Construct an in-place initializer for `struct`s. +/// Construct an in-place, fallible initializer for `struct`s. /// -/// This macro defaults the error to [`Infallible`]. If you need a different error, then use -/// [`try_init!`]. +/// This macro defaults the error to [`Infallible`]; if you need a different one, write `? Error` +/// at the end, after the struct initializer. /// /// The syntax is identical to [`pin_init!`] and its safety caveats also apply: /// - `unsafe` code must guarantee either full initialization or return an error and allow @@ -883,74 +817,7 @@ macro_rules! try_pin_init { /// } /// # let _ = Box::init(BigBuf::new()); /// ``` -// For a detailed example of how this macro works, see the module documentation of the hidden -// module `macros` inside of `macros.rs`. -#[macro_export] -macro_rules! init { - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }) => { - $crate::try_init!($(&$this in)? $t $(::<$($generics),*>)? { - $($fields)* - }? ::core::convert::Infallible) - } -} - -/// Construct an in-place fallible initializer for `struct`s. -/// -/// If the initialization can complete without error (or [`Infallible`]), then use -/// [`init!`]. -/// -/// The syntax is identical to [`try_pin_init!`]. You need to specify a custom error -/// via `? $type` after the `struct` initializer. -/// The safety caveats from [`try_pin_init!`] also apply: -/// - `unsafe` code must guarantee either full initialization or return an error and allow -/// deallocation of the memory. -/// - the fields are initialized in the order given in the initializer. -/// - no references to fields are allowed to be created inside of the initializer. -/// -/// # Examples -/// -/// ```rust -/// # #![feature(allocator_api)] -/// # use core::alloc::AllocError; -/// # use pin_init::InPlaceInit; -/// use pin_init::{try_init, Init, init_zeroed}; -/// -/// struct BigBuf { -/// big: Box<[u8; 1024 * 1024 * 1024]>, -/// small: [u8; 1024 * 1024], -/// } -/// -/// impl BigBuf { -/// fn new() -> impl Init<Self, AllocError> { -/// try_init!(Self { -/// big: Box::init(init_zeroed())?, -/// small: [0; 1024 * 1024], -/// }? AllocError) -/// } -/// } -/// # let _ = Box::init(BigBuf::new()); -/// ``` -// For a detailed example of how this macro works, see the module documentation of the hidden -// module `macros` inside of `macros.rs`. -#[macro_export] -macro_rules! try_init { - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }? 
$err:ty) => { - $crate::__init_internal!( - @this($($this)?), - @typ($t $(::<$($generics),*>)?), - @fields($($fields)*), - @error($err), - @data(InitData, /*no use_data*/), - @has_data(HasInitData, __init_data), - @construct_closure(init_from_closure), - @munch_fields($($fields)*), - ) - } -} +pub use pin_init_internal::init; /// Asserts that a field on a struct using `#[pin_data]` is marked with `#[pin]` ie. that it is /// structurally pinned. @@ -1410,14 +1277,14 @@ where /// fn init_foo() -> impl PinInit<Foo, Error> { /// pin_init_scope(|| { /// let bar = lookup_bar()?; -/// Ok(try_pin_init!(Foo { a: bar.a.into(), b: bar.b }? Error)) +/// Ok(pin_init!(Foo { a: bar.a.into(), b: bar.b }? Error)) /// }) /// } /// ``` /// /// This initializer will first execute `lookup_bar()`, match on it, if it returned an error, the /// initializer itself will fail with that error. If it returned `Ok`, then it will run the -/// initializer returned by the [`try_pin_init!`] invocation. +/// initializer returned by the [`pin_init!`] invocation. pub fn pin_init_scope<T, E, F, I>(make_init: F) -> impl PinInit<T, E> where F: FnOnce() -> Result<I, E>, @@ -1453,14 +1320,14 @@ where /// fn init_foo() -> impl Init<Foo, Error> { /// init_scope(|| { /// let bar = lookup_bar()?; -/// Ok(try_init!(Foo { a: bar.a.into(), b: bar.b }? Error)) +/// Ok(init!(Foo { a: bar.a.into(), b: bar.b }? Error)) /// }) /// } /// ``` /// /// This initializer will first execute `lookup_bar()`, match on it, if it returned an error, the /// initializer itself will fail with that error. If it returned `Ok`, then it will run the -/// initializer returned by the [`try_init!`] invocation. +/// initializer returned by the [`init!`] invocation. pub fn init_scope<T, E, F, I>(make_init: F) -> impl Init<T, E> where F: FnOnce() -> Result<I, E>, @@ -1536,6 +1403,33 @@ pub trait InPlaceWrite<T> { fn write_pin_init<E>(self, init: impl PinInit<T, E>) -> Result<Pin<Self::Initialized>, E>; } +impl<T> InPlaceWrite<T> for &'static mut MaybeUninit<T> { + type Initialized = &'static mut T; + + fn write_init<E>(self, init: impl Init<T, E>) -> Result<Self::Initialized, E> { + let slot = self.as_mut_ptr(); + + // SAFETY: `slot` is a valid pointer to uninitialized memory. + unsafe { init.__init(slot)? }; + + // SAFETY: The above call initialized the memory. + unsafe { Ok(self.assume_init_mut()) } + } + + fn write_pin_init<E>(self, init: impl PinInit<T, E>) -> Result<Pin<Self::Initialized>, E> { + let slot = self.as_mut_ptr(); + + // SAFETY: `slot` is a valid pointer to uninitialized memory. + // + // The `'static` borrow guarantees the data will not be + // moved/invalidated until it gets dropped (which is never). + unsafe { init.__pinned_init(slot)? }; + + // SAFETY: The above call initialized the memory. + Ok(Pin::static_mut(unsafe { self.assume_init_mut() })) + } +} + /// Trait facilitating pinned destruction. /// /// Use [`pinned_drop`] to implement this trait safely: diff --git a/rust/pin-init/src/macros.rs b/rust/pin-init/src/macros.rs deleted file mode 100644 index 682c61a587a0..000000000000 --- a/rust/pin-init/src/macros.rs +++ /dev/null @@ -1,1677 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 OR MIT - -//! This module provides the macros that actually implement the proc-macros `pin_data` and -//! `pinned_drop`. It also contains `__init_internal`, the implementation of the -//! `{try_}{pin_}init!` macros. -//! -//! These macros should never be called directly, since they expect their input to be -//! in a certain format which is internal. 
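The new `InPlaceWrite` impl for `&'static mut MaybeUninit<T>` can be exercised roughly like this (a sketch with hypothetical names; how the caller legitimately obtains the `'static` exclusive reference, e.g. from a once-only boot path, is out of scope here):

```rust
use core::mem::MaybeUninit;
use core::pin::Pin;
use pin_init::{pin_data, pin_init, InPlaceWrite};

#[pin_data]
struct Fixed {
    value: u32,
}

fn init_in_place(slot: &'static mut MaybeUninit<Fixed>) -> Pin<&'static mut Fixed> {
    // Runs the pin-initializer in the provided slot; the `'static` lifetime is what
    // makes handing back `Pin::static_mut(...)` sound.
    match slot.write_pin_init(pin_init!(Fixed { value: 7 })) {
        Ok(pinned) => pinned,
        // The initializer above is infallible.
        Err(e) => match e {},
    }
}
```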
If used incorrectly, these macros can lead to UB even in -//! safe code! Use the public facing macros instead. -//! -//! This architecture has been chosen because the kernel does not yet have access to `syn` which -//! would make matters a lot easier for implementing these as proc-macros. -//! -//! Since this library and the kernel implementation should diverge as little as possible, the same -//! approach has been taken here. -//! -//! # Macro expansion example -//! -//! This section is intended for readers trying to understand the macros in this module and the -//! `[try_][pin_]init!` macros from `lib.rs`. -//! -//! We will look at the following example: -//! -//! ```rust,ignore -//! #[pin_data] -//! #[repr(C)] -//! struct Bar<T> { -//! #[pin] -//! t: T, -//! pub x: usize, -//! } -//! -//! impl<T> Bar<T> { -//! fn new(t: T) -> impl PinInit<Self> { -//! pin_init!(Self { t, x: 0 }) -//! } -//! } -//! -//! #[pin_data(PinnedDrop)] -//! struct Foo { -//! a: usize, -//! #[pin] -//! b: Bar<u32>, -//! } -//! -//! #[pinned_drop] -//! impl PinnedDrop for Foo { -//! fn drop(self: Pin<&mut Self>) { -//! println!("{self:p} is getting dropped."); -//! } -//! } -//! -//! let a = 42; -//! let initializer = pin_init!(Foo { -//! a, -//! b <- Bar::new(36), -//! }); -//! ``` -//! -//! This example includes the most common and important features of the pin-init API. -//! -//! Below you can find individual section about the different macro invocations. Here are some -//! general things we need to take into account when designing macros: -//! - use global paths, similarly to file paths, these start with the separator: `::core::panic!()` -//! this ensures that the correct item is used, since users could define their own `mod core {}` -//! and then their own `panic!` inside to execute arbitrary code inside of our macro. -//! - macro `unsafe` hygiene: we need to ensure that we do not expand arbitrary, user-supplied -//! expressions inside of an `unsafe` block in the macro, because this would allow users to do -//! `unsafe` operations without an associated `unsafe` block. -//! -//! ## `#[pin_data]` on `Bar` -//! -//! This macro is used to specify which fields are structurally pinned and which fields are not. It -//! is placed on the struct definition and allows `#[pin]` to be placed on the fields. -//! -//! Here is the definition of `Bar` from our example: -//! -//! ```rust,ignore -//! #[pin_data] -//! #[repr(C)] -//! struct Bar<T> { -//! #[pin] -//! t: T, -//! pub x: usize, -//! } -//! ``` -//! -//! This expands to the following code: -//! -//! ```rust,ignore -//! // Firstly the normal definition of the struct, attributes are preserved: -//! #[repr(C)] -//! struct Bar<T> { -//! t: T, -//! pub x: usize, -//! } -//! // Then an anonymous constant is defined, this is because we do not want any code to access the -//! // types that we define inside: -//! const _: () = { -//! // We define the pin-data carrying struct, it is a ZST and needs to have the same generics, -//! // since we need to implement access functions for each field and thus need to know its -//! // type. -//! struct __ThePinData<T> { -//! __phantom: ::core::marker::PhantomData<fn(Bar<T>) -> Bar<T>>, -//! } -//! // We implement `Copy` for the pin-data struct, since all functions it defines will take -//! // `self` by value. -//! impl<T> ::core::clone::Clone for __ThePinData<T> { -//! fn clone(&self) -> Self { -//! *self -//! } -//! } -//! impl<T> ::core::marker::Copy for __ThePinData<T> {} -//! 
// For every field of `Bar`, the pin-data struct will define a function with the same name -//! // and accessor (`pub` or `pub(crate)` etc.). This function will take a pointer to the -//! // field (`slot`) and a `PinInit` or `Init` depending on the projection kind of the field -//! // (if pinning is structural for the field, then `PinInit` otherwise `Init`). -//! #[allow(dead_code)] -//! impl<T> __ThePinData<T> { -//! unsafe fn t<E>( -//! self, -//! slot: *mut T, -//! // Since `t` is `#[pin]`, this is `PinInit`. -//! init: impl ::pin_init::PinInit<T, E>, -//! ) -> ::core::result::Result<(), E> { -//! unsafe { ::pin_init::PinInit::__pinned_init(init, slot) } -//! } -//! pub unsafe fn x<E>( -//! self, -//! slot: *mut usize, -//! // Since `x` is not `#[pin]`, this is `Init`. -//! init: impl ::pin_init::Init<usize, E>, -//! ) -> ::core::result::Result<(), E> { -//! unsafe { ::pin_init::Init::__init(init, slot) } -//! } -//! } -//! // Implement the internal `HasPinData` trait that associates `Bar` with the pin-data struct -//! // that we constructed above. -//! unsafe impl<T> ::pin_init::__internal::HasPinData for Bar<T> { -//! type PinData = __ThePinData<T>; -//! unsafe fn __pin_data() -> Self::PinData { -//! __ThePinData { -//! __phantom: ::core::marker::PhantomData, -//! } -//! } -//! } -//! // Implement the internal `PinData` trait that marks the pin-data struct as a pin-data -//! // struct. This is important to ensure that no user can implement a rogue `__pin_data` -//! // function without using `unsafe`. -//! unsafe impl<T> ::pin_init::__internal::PinData for __ThePinData<T> { -//! type Datee = Bar<T>; -//! } -//! // Now we only want to implement `Unpin` for `Bar` when every structurally pinned field is -//! // `Unpin`. In other words, whether `Bar` is `Unpin` only depends on structurally pinned -//! // fields (those marked with `#[pin]`). These fields will be listed in this struct, in our -//! // case no such fields exist, hence this is almost empty. The two phantomdata fields exist -//! // for two reasons: -//! // - `__phantom`: every generic must be used, since we cannot really know which generics -//! // are used, we declare all and then use everything here once. -//! // - `__phantom_pin`: uses the `'__pin` lifetime and ensures that this struct is invariant -//! // over it. The lifetime is needed to work around the limitation that trait bounds must -//! // not be trivial, e.g. the user has a `#[pin] PhantomPinned` field -- this is -//! // unconditionally `!Unpin` and results in an error. The lifetime tricks the compiler -//! // into accepting these bounds regardless. -//! #[allow(dead_code)] -//! struct __Unpin<'__pin, T> { -//! __phantom_pin: ::core::marker::PhantomData<fn(&'__pin ()) -> &'__pin ()>, -//! __phantom: ::core::marker::PhantomData<fn(Bar<T>) -> Bar<T>>, -//! // Our only `#[pin]` field is `t`. -//! t: T, -//! } -//! #[doc(hidden)] -//! impl<'__pin, T> ::core::marker::Unpin for Bar<T> -//! where -//! __Unpin<'__pin, T>: ::core::marker::Unpin, -//! {} -//! // Now we need to ensure that `Bar` does not implement `Drop`, since that would give users -//! // access to `&mut self` inside of `drop` even if the struct was pinned. This could lead to -//! // UB with only safe code, so we disallow this by giving a trait implementation error using -//! // a direct impl and a blanket implementation. -//! trait MustNotImplDrop {} -//! // Normally `Drop` bounds do not have the correct semantics, but for this purpose they do -//! 
// (normally people want to know if a type has any kind of drop glue at all, here we want -//! // to know if it has any kind of custom drop glue, which is exactly what this bound does). -//! #[expect(drop_bounds)] -//! impl<T: ::core::ops::Drop> MustNotImplDrop for T {} -//! impl<T> MustNotImplDrop for Bar<T> {} -//! // Here comes a convenience check, if one implemented `PinnedDrop`, but forgot to add it to -//! // `#[pin_data]`, then this will error with the same mechanic as above, this is not needed -//! // for safety, but a good sanity check, since no normal code calls `PinnedDrop::drop`. -//! #[expect(non_camel_case_types)] -//! trait UselessPinnedDropImpl_you_need_to_specify_PinnedDrop {} -//! impl< -//! T: ::pin_init::PinnedDrop, -//! > UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for T {} -//! impl<T> UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for Bar<T> {} -//! }; -//! ``` -//! -//! ## `pin_init!` in `impl Bar` -//! -//! This macro creates an pin-initializer for the given struct. It requires that the struct is -//! annotated by `#[pin_data]`. -//! -//! Here is the impl on `Bar` defining the new function: -//! -//! ```rust,ignore -//! impl<T> Bar<T> { -//! fn new(t: T) -> impl PinInit<Self> { -//! pin_init!(Self { t, x: 0 }) -//! } -//! } -//! ``` -//! -//! This expands to the following code: -//! -//! ```rust,ignore -//! impl<T> Bar<T> { -//! fn new(t: T) -> impl PinInit<Self> { -//! { -//! // We do not want to allow arbitrary returns, so we declare this type as the `Ok` -//! // return type and shadow it later when we insert the arbitrary user code. That way -//! // there will be no possibility of returning without `unsafe`. -//! struct __InitOk; -//! // Get the data about fields from the supplied type. -//! // - the function is unsafe, hence the unsafe block -//! // - we `use` the `HasPinData` trait in the block, it is only available in that -//! // scope. -//! let data = unsafe { -//! use ::pin_init::__internal::HasPinData; -//! Self::__pin_data() -//! }; -//! // Ensure that `data` really is of type `PinData` and help with type inference: -//! let init = ::pin_init::__internal::PinData::make_closure::< -//! _, -//! __InitOk, -//! ::core::convert::Infallible, -//! >(data, move |slot| { -//! { -//! // Shadow the structure so it cannot be used to return early. If a user -//! // tries to write `return Ok(__InitOk)`, then they get a type error, -//! // since that will refer to this struct instead of the one defined -//! // above. -//! struct __InitOk; -//! // This is the expansion of `t,`, which is syntactic sugar for `t: t,`. -//! { -//! unsafe { ::core::ptr::write(::core::addr_of_mut!((*slot).t), t) }; -//! } -//! // Since initialization could fail later (not in this case, since the -//! // error type is `Infallible`) we will need to drop this field if there -//! // is an error later. This `DropGuard` will drop the field when it gets -//! // dropped and has not yet been forgotten. -//! let __t_guard = unsafe { -//! ::pin_init::__internal::DropGuard::new(::core::addr_of_mut!((*slot).t)) -//! }; -//! // Expansion of `x: 0,`: -//! // Since this can be an arbitrary expression we cannot place it inside -//! // of the `unsafe` block, so we bind it here. -//! { -//! let x = 0; -//! unsafe { ::core::ptr::write(::core::addr_of_mut!((*slot).x), x) }; -//! } -//! // We again create a `DropGuard`. -//! let __x_guard = unsafe { -//! ::pin_init::__internal::DropGuard::new(::core::addr_of_mut!((*slot).x)) -//! }; -//! 
// Since initialization has successfully completed, we can now forget -//! // the guards. This is not `mem::forget`, since we only have -//! // `&DropGuard`. -//! ::core::mem::forget(__x_guard); -//! ::core::mem::forget(__t_guard); -//! // Here we use the type checker to ensure that every field has been -//! // initialized exactly once, since this is `if false` it will never get -//! // executed, but still type-checked. -//! // Additionally we abuse `slot` to automatically infer the correct type -//! // for the struct. This is also another check that every field is -//! // accessible from this scope. -//! #[allow(unreachable_code, clippy::diverging_sub_expression)] -//! let _ = || { -//! unsafe { -//! ::core::ptr::write( -//! slot, -//! Self { -//! // We only care about typecheck finding every field -//! // here, the expression does not matter, just conjure -//! // one using `panic!()`: -//! t: ::core::panic!(), -//! x: ::core::panic!(), -//! }, -//! ); -//! }; -//! }; -//! } -//! // We leave the scope above and gain access to the previously shadowed -//! // `__InitOk` that we need to return. -//! Ok(__InitOk) -//! }); -//! // Change the return type from `__InitOk` to `()`. -//! let init = move | -//! slot, -//! | -> ::core::result::Result<(), ::core::convert::Infallible> { -//! init(slot).map(|__InitOk| ()) -//! }; -//! // Construct the initializer. -//! let init = unsafe { -//! ::pin_init::pin_init_from_closure::< -//! _, -//! ::core::convert::Infallible, -//! >(init) -//! }; -//! init -//! } -//! } -//! } -//! ``` -//! -//! ## `#[pin_data]` on `Foo` -//! -//! Since we already took a look at `#[pin_data]` on `Bar`, this section will only explain the -//! differences/new things in the expansion of the `Foo` definition: -//! -//! ```rust,ignore -//! #[pin_data(PinnedDrop)] -//! struct Foo { -//! a: usize, -//! #[pin] -//! b: Bar<u32>, -//! } -//! ``` -//! -//! This expands to the following code: -//! -//! ```rust,ignore -//! struct Foo { -//! a: usize, -//! b: Bar<u32>, -//! } -//! const _: () = { -//! struct __ThePinData { -//! __phantom: ::core::marker::PhantomData<fn(Foo) -> Foo>, -//! } -//! impl ::core::clone::Clone for __ThePinData { -//! fn clone(&self) -> Self { -//! *self -//! } -//! } -//! impl ::core::marker::Copy for __ThePinData {} -//! #[allow(dead_code)] -//! impl __ThePinData { -//! unsafe fn b<E>( -//! self, -//! slot: *mut Bar<u32>, -//! init: impl ::pin_init::PinInit<Bar<u32>, E>, -//! ) -> ::core::result::Result<(), E> { -//! unsafe { ::pin_init::PinInit::__pinned_init(init, slot) } -//! } -//! unsafe fn a<E>( -//! self, -//! slot: *mut usize, -//! init: impl ::pin_init::Init<usize, E>, -//! ) -> ::core::result::Result<(), E> { -//! unsafe { ::pin_init::Init::__init(init, slot) } -//! } -//! } -//! unsafe impl ::pin_init::__internal::HasPinData for Foo { -//! type PinData = __ThePinData; -//! unsafe fn __pin_data() -> Self::PinData { -//! __ThePinData { -//! __phantom: ::core::marker::PhantomData, -//! } -//! } -//! } -//! unsafe impl ::pin_init::__internal::PinData for __ThePinData { -//! type Datee = Foo; -//! } -//! #[allow(dead_code)] -//! struct __Unpin<'__pin> { -//! __phantom_pin: ::core::marker::PhantomData<fn(&'__pin ()) -> &'__pin ()>, -//! __phantom: ::core::marker::PhantomData<fn(Foo) -> Foo>, -//! b: Bar<u32>, -//! } -//! #[doc(hidden)] -//! impl<'__pin> ::core::marker::Unpin for Foo -//! where -//! __Unpin<'__pin>: ::core::marker::Unpin, -//! {} -//! // Since we specified `PinnedDrop` as the argument to `#[pin_data]`, we expect `Foo` to -//! 
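The `DropGuard` choreography walked through above is what gives these initializers their cleanup-on-error guarantee: fields already written are dropped if a later step fails. A small sketch of the observable effect (standalone `pin-init` crate; the types are made up):

```rust
use pin_init::{pin_data, pin_init, PinInit};

struct Error;

struct Noisy;
impl Drop for Noisy {
    fn drop(&mut self) {
        println!("Noisy dropped");
    }
}

fn checked(fail: bool) -> Result<u32, Error> {
    if fail { Err(Error) } else { Ok(0) }
}

#[pin_data]
struct Two {
    first: Noisy,
    second: u32,
}

fn make(fail: bool) -> impl PinInit<Two, Error> {
    pin_init!(Two {
        first: Noisy,
        // If `checked` fails, the guard created after `first` was written drops it
        // ("Noisy dropped"), so the partially initialized slot is cleaned up.
        second: checked(fail)?,
    }? Error)
}
```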
// implement `PinnedDrop`. Thus we do not need to prevent `Drop` implementations like -//! // before, instead we implement `Drop` here and delegate to `PinnedDrop`. -//! impl ::core::ops::Drop for Foo { -//! fn drop(&mut self) { -//! // Since we are getting dropped, no one else has a reference to `self` and thus we -//! // can assume that we never move. -//! let pinned = unsafe { ::core::pin::Pin::new_unchecked(self) }; -//! // Create the unsafe token that proves that we are inside of a destructor, this -//! // type is only allowed to be created in a destructor. -//! let token = unsafe { ::pin_init::__internal::OnlyCallFromDrop::new() }; -//! ::pin_init::PinnedDrop::drop(pinned, token); -//! } -//! } -//! }; -//! ``` -//! -//! ## `#[pinned_drop]` on `impl PinnedDrop for Foo` -//! -//! This macro is used to implement the `PinnedDrop` trait, since that trait is `unsafe` and has an -//! extra parameter that should not be used at all. The macro hides that parameter. -//! -//! Here is the `PinnedDrop` impl for `Foo`: -//! -//! ```rust,ignore -//! #[pinned_drop] -//! impl PinnedDrop for Foo { -//! fn drop(self: Pin<&mut Self>) { -//! println!("{self:p} is getting dropped."); -//! } -//! } -//! ``` -//! -//! This expands to the following code: -//! -//! ```rust,ignore -//! // `unsafe`, full path and the token parameter are added, everything else stays the same. -//! unsafe impl ::pin_init::PinnedDrop for Foo { -//! fn drop(self: Pin<&mut Self>, _: ::pin_init::__internal::OnlyCallFromDrop) { -//! println!("{self:p} is getting dropped."); -//! } -//! } -//! ``` -//! -//! ## `pin_init!` on `Foo` -//! -//! Since we already took a look at `pin_init!` on `Bar`, this section will only show the expansion -//! of `pin_init!` on `Foo`: -//! -//! ```rust,ignore -//! let a = 42; -//! let initializer = pin_init!(Foo { -//! a, -//! b <- Bar::new(36), -//! }); -//! ``` -//! -//! This expands to the following code: -//! -//! ```rust,ignore -//! let a = 42; -//! let initializer = { -//! struct __InitOk; -//! let data = unsafe { -//! use ::pin_init::__internal::HasPinData; -//! Foo::__pin_data() -//! }; -//! let init = ::pin_init::__internal::PinData::make_closure::< -//! _, -//! __InitOk, -//! ::core::convert::Infallible, -//! >(data, move |slot| { -//! { -//! struct __InitOk; -//! { -//! unsafe { ::core::ptr::write(::core::addr_of_mut!((*slot).a), a) }; -//! } -//! let __a_guard = unsafe { -//! ::pin_init::__internal::DropGuard::new(::core::addr_of_mut!((*slot).a)) -//! }; -//! let init = Bar::new(36); -//! unsafe { data.b(::core::addr_of_mut!((*slot).b), b)? }; -//! let __b_guard = unsafe { -//! ::pin_init::__internal::DropGuard::new(::core::addr_of_mut!((*slot).b)) -//! }; -//! ::core::mem::forget(__b_guard); -//! ::core::mem::forget(__a_guard); -//! #[allow(unreachable_code, clippy::diverging_sub_expression)] -//! let _ = || { -//! unsafe { -//! ::core::ptr::write( -//! slot, -//! Foo { -//! a: ::core::panic!(), -//! b: ::core::panic!(), -//! }, -//! ); -//! }; -//! }; -//! } -//! Ok(__InitOk) -//! }); -//! let init = move | -//! slot, -//! | -> ::core::result::Result<(), ::core::convert::Infallible> { -//! init(slot).map(|__InitOk| ()) -//! }; -//! let init = unsafe { -//! ::pin_init::pin_init_from_closure::<_, ::core::convert::Infallible>(init) -//! }; -//! init -//! }; -//! ``` - -#[cfg(kernel)] -pub use ::macros::paste; -#[cfg(not(kernel))] -pub use ::paste::paste; - -/// Creates a `unsafe impl<...> PinnedDrop for $type` block. -/// -/// See [`PinnedDrop`] for more information. 
-/// -/// [`PinnedDrop`]: crate::PinnedDrop -#[doc(hidden)] -#[macro_export] -macro_rules! __pinned_drop { - ( - @impl_sig($($impl_sig:tt)*), - @impl_body( - $(#[$($attr:tt)*])* - fn drop($($sig:tt)*) { - $($inner:tt)* - } - ), - ) => { - // SAFETY: TODO. - unsafe $($impl_sig)* { - // Inherit all attributes and the type/ident tokens for the signature. - $(#[$($attr)*])* - fn drop($($sig)*, _: $crate::__internal::OnlyCallFromDrop) { - $($inner)* - } - } - } -} - -/// This macro first parses the struct definition such that it separates pinned and not pinned -/// fields. Afterwards it declares the struct and implement the `PinData` trait safely. -#[doc(hidden)] -#[macro_export] -macro_rules! __pin_data { - // Proc-macro entry point, this is supplied by the proc-macro pre-parsing. - (parse_input: - @args($($pinned_drop:ident)?), - @sig( - $(#[$($struct_attr:tt)*])* - $vis:vis struct $name:ident - $(where $($whr:tt)*)? - ), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @decl_generics($($decl_generics:tt)*), - @body({ $($fields:tt)* }), - ) => { - // We now use token munching to iterate through all of the fields. While doing this we - // identify fields marked with `#[pin]`, these fields are the 'pinned fields'. The user - // wants these to be structurally pinned. The rest of the fields are the - // 'not pinned fields'. Additionally we collect all fields, since we need them in the right - // order to declare the struct. - // - // In this call we also put some explaining comments for the parameters. - $crate::__pin_data!(find_pinned_fields: - // Attributes on the struct itself, these will just be propagated to be put onto the - // struct definition. - @struct_attrs($(#[$($struct_attr)*])*), - // The visibility of the struct. - @vis($vis), - // The name of the struct. - @name($name), - // The 'impl generics', the generics that will need to be specified on the struct inside - // of an `impl<$ty_generics>` block. - @impl_generics($($impl_generics)*), - // The 'ty generics', the generics that will need to be specified on the impl blocks. - @ty_generics($($ty_generics)*), - // The 'decl generics', the generics that need to be specified on the struct - // definition. - @decl_generics($($decl_generics)*), - // The where clause of any impl block and the declaration. - @where($($($whr)*)?), - // The remaining fields tokens that need to be processed. - // We add a `,` at the end to ensure correct parsing. - @fields_munch($($fields)* ,), - // The pinned fields. - @pinned(), - // The not pinned fields. - @not_pinned(), - // All fields. - @fields(), - // The accumulator containing all attributes already parsed. - @accum(), - // Contains `yes` or `` to indicate if `#[pin]` was found on the current field. - @is_pinned(), - // The proc-macro argument, this should be `PinnedDrop` or ``. - @pinned_drop($($pinned_drop)?), - ); - }; - (find_pinned_fields: - @struct_attrs($($struct_attrs:tt)*), - @vis($vis:vis), - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @decl_generics($($decl_generics:tt)*), - @where($($whr:tt)*), - // We found a PhantomPinned field, this should generally be pinned! - @fields_munch($field:ident : $($($(::)?core::)?marker::)?PhantomPinned, $($rest:tt)*), - @pinned($($pinned:tt)*), - @not_pinned($($not_pinned:tt)*), - @fields($($fields:tt)*), - @accum($($accum:tt)*), - // This field is not pinned. 
- @is_pinned(), - @pinned_drop($($pinned_drop:ident)?), - ) => { - ::core::compile_error!(concat!( - "The field `", - stringify!($field), - "` of type `PhantomPinned` only has an effect, if it has the `#[pin]` attribute.", - )); - $crate::__pin_data!(find_pinned_fields: - @struct_attrs($($struct_attrs)*), - @vis($vis), - @name($name), - @impl_generics($($impl_generics)*), - @ty_generics($($ty_generics)*), - @decl_generics($($decl_generics)*), - @where($($whr)*), - @fields_munch($($rest)*), - @pinned($($pinned)* $($accum)* $field: ::core::marker::PhantomPinned,), - @not_pinned($($not_pinned)*), - @fields($($fields)* $($accum)* $field: ::core::marker::PhantomPinned,), - @accum(), - @is_pinned(), - @pinned_drop($($pinned_drop)?), - ); - }; - (find_pinned_fields: - @struct_attrs($($struct_attrs:tt)*), - @vis($vis:vis), - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @decl_generics($($decl_generics:tt)*), - @where($($whr:tt)*), - // We reached the field declaration. - @fields_munch($field:ident : $type:ty, $($rest:tt)*), - @pinned($($pinned:tt)*), - @not_pinned($($not_pinned:tt)*), - @fields($($fields:tt)*), - @accum($($accum:tt)*), - // This field is pinned. - @is_pinned(yes), - @pinned_drop($($pinned_drop:ident)?), - ) => { - $crate::__pin_data!(find_pinned_fields: - @struct_attrs($($struct_attrs)*), - @vis($vis), - @name($name), - @impl_generics($($impl_generics)*), - @ty_generics($($ty_generics)*), - @decl_generics($($decl_generics)*), - @where($($whr)*), - @fields_munch($($rest)*), - @pinned($($pinned)* $($accum)* $field: $type,), - @not_pinned($($not_pinned)*), - @fields($($fields)* $($accum)* $field: $type,), - @accum(), - @is_pinned(), - @pinned_drop($($pinned_drop)?), - ); - }; - (find_pinned_fields: - @struct_attrs($($struct_attrs:tt)*), - @vis($vis:vis), - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @decl_generics($($decl_generics:tt)*), - @where($($whr:tt)*), - // We reached the field declaration. - @fields_munch($field:ident : $type:ty, $($rest:tt)*), - @pinned($($pinned:tt)*), - @not_pinned($($not_pinned:tt)*), - @fields($($fields:tt)*), - @accum($($accum:tt)*), - // This field is not pinned. - @is_pinned(), - @pinned_drop($($pinned_drop:ident)?), - ) => { - $crate::__pin_data!(find_pinned_fields: - @struct_attrs($($struct_attrs)*), - @vis($vis), - @name($name), - @impl_generics($($impl_generics)*), - @ty_generics($($ty_generics)*), - @decl_generics($($decl_generics)*), - @where($($whr)*), - @fields_munch($($rest)*), - @pinned($($pinned)*), - @not_pinned($($not_pinned)* $($accum)* $field: $type,), - @fields($($fields)* $($accum)* $field: $type,), - @accum(), - @is_pinned(), - @pinned_drop($($pinned_drop)?), - ); - }; - (find_pinned_fields: - @struct_attrs($($struct_attrs:tt)*), - @vis($vis:vis), - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @decl_generics($($decl_generics:tt)*), - @where($($whr:tt)*), - // We found the `#[pin]` attr. 
- @fields_munch(#[pin] $($rest:tt)*), - @pinned($($pinned:tt)*), - @not_pinned($($not_pinned:tt)*), - @fields($($fields:tt)*), - @accum($($accum:tt)*), - @is_pinned($($is_pinned:ident)?), - @pinned_drop($($pinned_drop:ident)?), - ) => { - $crate::__pin_data!(find_pinned_fields: - @struct_attrs($($struct_attrs)*), - @vis($vis), - @name($name), - @impl_generics($($impl_generics)*), - @ty_generics($($ty_generics)*), - @decl_generics($($decl_generics)*), - @where($($whr)*), - @fields_munch($($rest)*), - // We do not include `#[pin]` in the list of attributes, since it is not actually an - // attribute that is defined somewhere. - @pinned($($pinned)*), - @not_pinned($($not_pinned)*), - @fields($($fields)*), - @accum($($accum)*), - // Set this to `yes`. - @is_pinned(yes), - @pinned_drop($($pinned_drop)?), - ); - }; - (find_pinned_fields: - @struct_attrs($($struct_attrs:tt)*), - @vis($vis:vis), - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @decl_generics($($decl_generics:tt)*), - @where($($whr:tt)*), - // We reached the field declaration with visibility, for simplicity we only munch the - // visibility and put it into `$accum`. - @fields_munch($fvis:vis $field:ident $($rest:tt)*), - @pinned($($pinned:tt)*), - @not_pinned($($not_pinned:tt)*), - @fields($($fields:tt)*), - @accum($($accum:tt)*), - @is_pinned($($is_pinned:ident)?), - @pinned_drop($($pinned_drop:ident)?), - ) => { - $crate::__pin_data!(find_pinned_fields: - @struct_attrs($($struct_attrs)*), - @vis($vis), - @name($name), - @impl_generics($($impl_generics)*), - @ty_generics($($ty_generics)*), - @decl_generics($($decl_generics)*), - @where($($whr)*), - @fields_munch($field $($rest)*), - @pinned($($pinned)*), - @not_pinned($($not_pinned)*), - @fields($($fields)*), - @accum($($accum)* $fvis), - @is_pinned($($is_pinned)?), - @pinned_drop($($pinned_drop)?), - ); - }; - (find_pinned_fields: - @struct_attrs($($struct_attrs:tt)*), - @vis($vis:vis), - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @decl_generics($($decl_generics:tt)*), - @where($($whr:tt)*), - // Some other attribute, just put it into `$accum`. - @fields_munch(#[$($attr:tt)*] $($rest:tt)*), - @pinned($($pinned:tt)*), - @not_pinned($($not_pinned:tt)*), - @fields($($fields:tt)*), - @accum($($accum:tt)*), - @is_pinned($($is_pinned:ident)?), - @pinned_drop($($pinned_drop:ident)?), - ) => { - $crate::__pin_data!(find_pinned_fields: - @struct_attrs($($struct_attrs)*), - @vis($vis), - @name($name), - @impl_generics($($impl_generics)*), - @ty_generics($($ty_generics)*), - @decl_generics($($decl_generics)*), - @where($($whr)*), - @fields_munch($($rest)*), - @pinned($($pinned)*), - @not_pinned($($not_pinned)*), - @fields($($fields)*), - @accum($($accum)* #[$($attr)*]), - @is_pinned($($is_pinned)?), - @pinned_drop($($pinned_drop)?), - ); - }; - (find_pinned_fields: - @struct_attrs($($struct_attrs:tt)*), - @vis($vis:vis), - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @decl_generics($($decl_generics:tt)*), - @where($($whr:tt)*), - // We reached the end of the fields, plus an optional additional comma, since we added one - // before and the user is also allowed to put a trailing comma. 
- @fields_munch($(,)?), - @pinned($($pinned:tt)*), - @not_pinned($($not_pinned:tt)*), - @fields($($fields:tt)*), - @accum(), - @is_pinned(), - @pinned_drop($($pinned_drop:ident)?), - ) => { - // Declare the struct with all fields in the correct order. - $($struct_attrs)* - $vis struct $name <$($decl_generics)*> - where $($whr)* - { - $($fields)* - } - - $crate::__pin_data!(make_pin_projections: - @vis($vis), - @name($name), - @impl_generics($($impl_generics)*), - @ty_generics($($ty_generics)*), - @decl_generics($($decl_generics)*), - @where($($whr)*), - @pinned($($pinned)*), - @not_pinned($($not_pinned)*), - ); - - // We put the rest into this const item, because it then will not be accessible to anything - // outside. - const _: () = { - // We declare this struct which will host all of the projection function for our type. - // it will be invariant over all generic parameters which are inherited from the - // struct. - $vis struct __ThePinData<$($impl_generics)*> - where $($whr)* - { - __phantom: ::core::marker::PhantomData< - fn($name<$($ty_generics)*>) -> $name<$($ty_generics)*> - >, - } - - impl<$($impl_generics)*> ::core::clone::Clone for __ThePinData<$($ty_generics)*> - where $($whr)* - { - fn clone(&self) -> Self { *self } - } - - impl<$($impl_generics)*> ::core::marker::Copy for __ThePinData<$($ty_generics)*> - where $($whr)* - {} - - // Make all projection functions. - $crate::__pin_data!(make_pin_data: - @pin_data(__ThePinData), - @impl_generics($($impl_generics)*), - @ty_generics($($ty_generics)*), - @where($($whr)*), - @pinned($($pinned)*), - @not_pinned($($not_pinned)*), - ); - - // SAFETY: We have added the correct projection functions above to `__ThePinData` and - // we also use the least restrictive generics possible. - unsafe impl<$($impl_generics)*> - $crate::__internal::HasPinData for $name<$($ty_generics)*> - where $($whr)* - { - type PinData = __ThePinData<$($ty_generics)*>; - - unsafe fn __pin_data() -> Self::PinData { - __ThePinData { __phantom: ::core::marker::PhantomData } - } - } - - // SAFETY: TODO. - unsafe impl<$($impl_generics)*> - $crate::__internal::PinData for __ThePinData<$($ty_generics)*> - where $($whr)* - { - type Datee = $name<$($ty_generics)*>; - } - - // This struct will be used for the unpin analysis. Since only structurally pinned - // fields are relevant whether the struct should implement `Unpin`. - #[allow(dead_code)] - struct __Unpin <'__pin, $($impl_generics)*> - where $($whr)* - { - __phantom_pin: ::core::marker::PhantomData<fn(&'__pin ()) -> &'__pin ()>, - __phantom: ::core::marker::PhantomData< - fn($name<$($ty_generics)*>) -> $name<$($ty_generics)*> - >, - // Only the pinned fields. - $($pinned)* - } - - #[doc(hidden)] - impl<'__pin, $($impl_generics)*> ::core::marker::Unpin for $name<$($ty_generics)*> - where - __Unpin<'__pin, $($ty_generics)*>: ::core::marker::Unpin, - $($whr)* - {} - - // We need to disallow normal `Drop` implementation, the exact behavior depends on - // whether `PinnedDrop` was specified as the parameter. - $crate::__pin_data!(drop_prevention: - @name($name), - @impl_generics($($impl_generics)*), - @ty_generics($($ty_generics)*), - @where($($whr)*), - @pinned_drop($($pinned_drop)?), - ); - }; - }; - // When no `PinnedDrop` was specified, then we have to prevent implementing drop. 
- (drop_prevention: - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @where($($whr:tt)*), - @pinned_drop(), - ) => { - // We prevent this by creating a trait that will be implemented for all types implementing - // `Drop`. Additionally we will implement this trait for the struct leading to a conflict, - // if it also implements `Drop` - trait MustNotImplDrop {} - #[expect(drop_bounds)] - impl<T: ::core::ops::Drop> MustNotImplDrop for T {} - impl<$($impl_generics)*> MustNotImplDrop for $name<$($ty_generics)*> - where $($whr)* {} - // We also take care to prevent users from writing a useless `PinnedDrop` implementation. - // They might implement `PinnedDrop` correctly for the struct, but forget to give - // `PinnedDrop` as the parameter to `#[pin_data]`. - #[expect(non_camel_case_types)] - trait UselessPinnedDropImpl_you_need_to_specify_PinnedDrop {} - impl<T: $crate::PinnedDrop> - UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for T {} - impl<$($impl_generics)*> - UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for $name<$($ty_generics)*> - where $($whr)* {} - }; - // When `PinnedDrop` was specified we just implement `Drop` and delegate. - (drop_prevention: - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @where($($whr:tt)*), - @pinned_drop(PinnedDrop), - ) => { - impl<$($impl_generics)*> ::core::ops::Drop for $name<$($ty_generics)*> - where $($whr)* - { - fn drop(&mut self) { - // SAFETY: Since this is a destructor, `self` will not move after this function - // terminates, since it is inaccessible. - let pinned = unsafe { ::core::pin::Pin::new_unchecked(self) }; - // SAFETY: Since this is a drop function, we can create this token to call the - // pinned destructor of this type. - let token = unsafe { $crate::__internal::OnlyCallFromDrop::new() }; - $crate::PinnedDrop::drop(pinned, token); - } - } - }; - // If some other parameter was specified, we emit a readable error. - (drop_prevention: - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @where($($whr:tt)*), - @pinned_drop($($rest:tt)*), - ) => { - compile_error!( - "Wrong parameters to `#[pin_data]`, expected nothing or `PinnedDrop`, got '{}'.", - stringify!($($rest)*), - ); - }; - (make_pin_projections: - @vis($vis:vis), - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @decl_generics($($decl_generics:tt)*), - @where($($whr:tt)*), - @pinned($($(#[$($p_attr:tt)*])* $pvis:vis $p_field:ident : $p_type:ty),* $(,)?), - @not_pinned($($(#[$($attr:tt)*])* $fvis:vis $field:ident : $type:ty),* $(,)?), - ) => { - $crate::macros::paste! { - #[doc(hidden)] - $vis struct [< $name Projection >] <'__pin, $($decl_generics)*> { - $($(#[$($p_attr)*])* $pvis $p_field : ::core::pin::Pin<&'__pin mut $p_type>,)* - $($(#[$($attr)*])* $fvis $field : &'__pin mut $type,)* - ___pin_phantom_data: ::core::marker::PhantomData<&'__pin mut ()>, - } - - impl<$($impl_generics)*> $name<$($ty_generics)*> - where $($whr)* - { - /// Pin-projects all fields of `Self`. 
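`MustNotImplDrop` above turns "this type must not implement `Drop`" into a coherence error. The same trick in isolation, outside the macro (trait and type names are illustrative):

struct Example;

// Implemented for every type that implements `Drop`...
trait MustNotImplDrop {}
#[expect(drop_bounds)]
impl<T: Drop> MustNotImplDrop for T {}

// ...and unconditionally for `Example`. As long as `Example` has no `Drop`
// impl this compiles; adding `impl Drop for Example { .. }` makes the two
// impls overlap and the build fails, which is exactly the intent.
impl MustNotImplDrop for Example {}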
- /// - /// These fields are structurally pinned: - $(#[doc = ::core::concat!(" - `", ::core::stringify!($p_field), "`")])* - /// - /// These fields are **not** structurally pinned: - $(#[doc = ::core::concat!(" - `", ::core::stringify!($field), "`")])* - #[inline] - $vis fn project<'__pin>( - self: ::core::pin::Pin<&'__pin mut Self>, - ) -> [< $name Projection >] <'__pin, $($ty_generics)*> { - // SAFETY: we only give access to `&mut` for fields not structurally pinned. - let this = unsafe { ::core::pin::Pin::get_unchecked_mut(self) }; - [< $name Projection >] { - $( - // SAFETY: `$p_field` is structurally pinned. - $(#[$($p_attr)*])* - $p_field : unsafe { ::core::pin::Pin::new_unchecked(&mut this.$p_field) }, - )* - $( - $(#[$($attr)*])* - $field : &mut this.$field, - )* - ___pin_phantom_data: ::core::marker::PhantomData, - } - } - } - } - }; - (make_pin_data: - @pin_data($pin_data:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @where($($whr:tt)*), - @pinned($($(#[$($p_attr:tt)*])* $pvis:vis $p_field:ident : $p_type:ty),* $(,)?), - @not_pinned($($(#[$($attr:tt)*])* $fvis:vis $field:ident : $type:ty),* $(,)?), - ) => { - $crate::macros::paste! { - // For every field, we create a projection function according to its projection type. If a - // field is structurally pinned, then it must be initialized via `PinInit`, if it is not - // structurally pinned, then it can be initialized via `Init`. - // - // The functions are `unsafe` to prevent accidentally calling them. - #[allow(dead_code)] - #[expect(clippy::missing_safety_doc)] - impl<$($impl_generics)*> $pin_data<$($ty_generics)*> - where $($whr)* - { - $( - $(#[$($p_attr)*])* - $pvis unsafe fn $p_field<E>( - self, - slot: *mut $p_type, - init: impl $crate::PinInit<$p_type, E>, - ) -> ::core::result::Result<(), E> { - // SAFETY: TODO. - unsafe { $crate::PinInit::__pinned_init(init, slot) } - } - - $(#[$($p_attr)*])* - $pvis unsafe fn [<__project_ $p_field>]<'__slot>( - self, - slot: &'__slot mut $p_type, - ) -> ::core::pin::Pin<&'__slot mut $p_type> { - ::core::pin::Pin::new_unchecked(slot) - } - )* - $( - $(#[$($attr)*])* - $fvis unsafe fn $field<E>( - self, - slot: *mut $type, - init: impl $crate::Init<$type, E>, - ) -> ::core::result::Result<(), E> { - // SAFETY: TODO. - unsafe { $crate::Init::__init(init, slot) } - } - - $(#[$($attr)*])* - $fvis unsafe fn [<__project_ $field>]<'__slot>( - self, - slot: &'__slot mut $type, - ) -> &'__slot mut $type { - slot - } - )* - } - } - }; -} - -/// The internal init macro. Do not call manually! -/// -/// This is called by the `{try_}{pin_}init!` macros with various inputs. -/// -/// This macro has multiple internal call configurations, these are always the very first ident: -/// - nothing: this is the base case and called by the `{try_}{pin_}init!` macros. -/// - `with_update_parsed`: when the `..Zeroable::init_zeroed()` syntax has been handled. -/// - `init_slot`: recursively creates the code that initializes all fields in `slot`. -/// - `make_initializer`: recursively create the struct initializer that guarantees that every -/// field has been initialized exactly once. -#[doc(hidden)] -#[macro_export] -macro_rules! __init_internal { - ( - @this($($this:ident)?), - @typ($t:path), - @fields($($fields:tt)*), - @error($err:ty), - // Either `PinData` or `InitData`, `$use_data` should only be present in the `PinData` - // case. - @data($data:ident, $($use_data:ident)?), - // `HasPinData` or `HasInitData`. 
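These entry arms are reached only through the public `{try_}{pin_}init!` macros. A typical call that bottoms out here, shown for a hypothetical type using the standalone crate paths (in-kernel code would use the prelude and real kernel types instead):

use core::marker::PhantomPinned;
use pin_init::{pin_data, pin_init, PinInit};

#[pin_data]
struct Example {
    #[pin]
    _pinned: PhantomPinned,
    counter: u64,
}

impl Example {
    fn new() -> impl PinInit<Self> {
        // Expands into a `__init_internal!` invocation like the ones above,
        // with `@data(PinData, use_data)` and
        // `@construct_closure(pin_init_from_closure)`.
        pin_init!(Self {
            _pinned: PhantomPinned,
            counter: 0,
        })
    }
}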
- @has_data($has_data:ident, $get_data:ident), - // `pin_init_from_closure` or `init_from_closure`. - @construct_closure($construct_closure:ident), - @munch_fields(), - ) => { - $crate::__init_internal!(with_update_parsed: - @this($($this)?), - @typ($t), - @fields($($fields)*), - @error($err), - @data($data, $($use_data)?), - @has_data($has_data, $get_data), - @construct_closure($construct_closure), - @init_zeroed(), // Nothing means default behavior. - ) - }; - ( - @this($($this:ident)?), - @typ($t:path), - @fields($($fields:tt)*), - @error($err:ty), - // Either `PinData` or `InitData`, `$use_data` should only be present in the `PinData` - // case. - @data($data:ident, $($use_data:ident)?), - // `HasPinData` or `HasInitData`. - @has_data($has_data:ident, $get_data:ident), - // `pin_init_from_closure` or `init_from_closure`. - @construct_closure($construct_closure:ident), - @munch_fields(..Zeroable::init_zeroed()), - ) => { - $crate::__init_internal!(with_update_parsed: - @this($($this)?), - @typ($t), - @fields($($fields)*), - @error($err), - @data($data, $($use_data)?), - @has_data($has_data, $get_data), - @construct_closure($construct_closure), - @init_zeroed(()), // `()` means zero all fields not mentioned. - ) - }; - ( - @this($($this:ident)?), - @typ($t:path), - @fields($($fields:tt)*), - @error($err:ty), - // Either `PinData` or `InitData`, `$use_data` should only be present in the `PinData` - // case. - @data($data:ident, $($use_data:ident)?), - // `HasPinData` or `HasInitData`. - @has_data($has_data:ident, $get_data:ident), - // `pin_init_from_closure` or `init_from_closure`. - @construct_closure($construct_closure:ident), - @munch_fields($ignore:tt $($rest:tt)*), - ) => { - $crate::__init_internal!( - @this($($this)?), - @typ($t), - @fields($($fields)*), - @error($err), - @data($data, $($use_data)?), - @has_data($has_data, $get_data), - @construct_closure($construct_closure), - @munch_fields($($rest)*), - ) - }; - (with_update_parsed: - @this($($this:ident)?), - @typ($t:path), - @fields($($fields:tt)*), - @error($err:ty), - // Either `PinData` or `InitData`, `$use_data` should only be present in the `PinData` - // case. - @data($data:ident, $($use_data:ident)?), - // `HasPinData` or `HasInitData`. - @has_data($has_data:ident, $get_data:ident), - // `pin_init_from_closure` or `init_from_closure`. - @construct_closure($construct_closure:ident), - @init_zeroed($($init_zeroed:expr)?), - ) => {{ - // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return - // type and shadow it later when we insert the arbitrary user code. That way there will be - // no possibility of returning without `unsafe`. - struct __InitOk; - // Get the data about fields from the supplied type. - // - // SAFETY: TODO. - let data = unsafe { - use $crate::__internal::$has_data; - // Here we abuse `paste!` to retokenize `$t`. Declarative macros have some internal - // information that is associated to already parsed fragments, so a path fragment - // cannot be used in this position. Doing the retokenization results in valid rust - // code. - $crate::macros::paste!($t::$get_data()) - }; - // Ensure that `data` really is of type `$data` and help with type inference: - let init = $crate::__internal::$data::make_closure::<_, __InitOk, $err>( - data, - move |slot| { - { - // Shadow the structure so it cannot be used to return early. 
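The `@munch_fields(..Zeroable::init_zeroed())` arm above is what gives that tail its meaning: the slot is zeroed up front and unmentioned fields stop being an error. A hedged usage sketch with made-up types (assuming the `Zeroable` derive is re-exported from the crate root, as it is for in-kernel code):

use pin_init::{init, Init, Zeroable};

#[derive(Zeroable)]
struct Limits {
    depth: u32,
    retries: u32,
    budget: u64,
}

fn default_limits() -> impl Init<Limits> {
    // `depth` is set explicitly; `retries` and `budget` are left to the
    // zeroing performed for `..Zeroable::init_zeroed()` instead of being
    // reported as missing fields.
    init!(Limits {
        depth: 8,
        ..Zeroable::init_zeroed()
    })
}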
- struct __InitOk; - // If `$init_zeroed` is present we should zero the slot now and not emit an - // error when fields are missing (since they will be zeroed). We also have to - // check that the type actually implements `Zeroable`. - $({ - fn assert_zeroable<T: $crate::Zeroable>(_: *mut T) {} - // Ensure that the struct is indeed `Zeroable`. - assert_zeroable(slot); - // SAFETY: The type implements `Zeroable` by the check above. - unsafe { ::core::ptr::write_bytes(slot, 0, 1) }; - $init_zeroed // This will be `()` if set. - })? - // Create the `this` so it can be referenced by the user inside of the - // expressions creating the individual fields. - $(let $this = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };)? - // Initialize every field. - $crate::__init_internal!(init_slot($($use_data)?): - @data(data), - @slot(slot), - @guards(), - @munch_fields($($fields)*,), - ); - // We use unreachable code to ensure that all fields have been mentioned exactly - // once, this struct initializer will still be type-checked and complain with a - // very natural error message if a field is forgotten/mentioned more than once. - #[allow(unreachable_code, clippy::diverging_sub_expression)] - let _ = || { - $crate::__init_internal!(make_initializer: - @slot(slot), - @type_name($t), - @munch_fields($($fields)*,), - @acc(), - ); - }; - } - Ok(__InitOk) - } - ); - let init = move |slot| -> ::core::result::Result<(), $err> { - init(slot).map(|__InitOk| ()) - }; - // SAFETY: TODO. - let init = unsafe { $crate::$construct_closure::<_, $err>(init) }; - init - }}; - (init_slot($($use_data:ident)?): - @data($data:ident), - @slot($slot:ident), - @guards($($guards:ident,)*), - @munch_fields($(..Zeroable::init_zeroed())? $(,)?), - ) => { - // Endpoint of munching, no fields are left. If execution reaches this point, all fields - // have been initialized. Therefore we can now dismiss the guards by forgetting them. - $(::core::mem::forget($guards);)* - }; - (init_slot($($use_data:ident)?): - @data($data:ident), - @slot($slot:ident), - @guards($($guards:ident,)*), - // arbitrary code block - @munch_fields(_: { $($code:tt)* }, $($rest:tt)*), - ) => { - { $($code)* } - $crate::__init_internal!(init_slot($($use_data)?): - @data($data), - @slot($slot), - @guards($($guards,)*), - @munch_fields($($rest)*), - ); - }; - (init_slot($use_data:ident): // `use_data` is present, so we use the `data` to init fields. - @data($data:ident), - @slot($slot:ident), - @guards($($guards:ident,)*), - // In-place initialization syntax. - @munch_fields($field:ident <- $val:expr, $($rest:tt)*), - ) => { - let init = $val; - // Call the initializer. - // - // SAFETY: `slot` is valid, because we are inside of an initializer closure, we - // return when an error/panic occurs. - // We also use the `data` to require the correct trait (`Init` or `PinInit`) for `$field`. - unsafe { $data.$field(::core::ptr::addr_of_mut!((*$slot).$field), init)? }; - // SAFETY: - // - the project function does the correct field projection, - // - the field has been initialized, - // - the reference is only valid until the end of the initializer. - #[allow(unused_variables)] - let $field = $crate::macros::paste!(unsafe { $data.[< __project_ $field >](&mut (*$slot).$field) }); - - // Create the drop guard: - // - // We rely on macro hygiene to make it impossible for users to access this local variable. - // We use `paste!` to create new hygiene for `$field`. - $crate::macros::paste! { - // SAFETY: We forget the guard later when initialization has succeeded. 
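The guard handling above is the usual "roll back partially initialized fields" pattern: every successfully initialized field gets a guard that would drop it in place, and only once all fields are done are the guards forgotten. A simplified standalone sketch of that pattern (this is not the crate's `DropGuard`; the names and the `String` fields are illustrative):

use core::{mem, ptr};

// Drops the pointee when the guard goes out of scope, unless the guard is
// forgotten first.
struct FieldGuard<T>(*mut T);

impl<T> Drop for FieldGuard<T> {
    fn drop(&mut self) {
        // SAFETY: the caller only creates a guard for a field it has just
        // initialized and does not touch that field again after the guard fires.
        unsafe { ptr::drop_in_place(self.0) };
    }
}

struct Pair {
    a: String,
    b: String,
}

// SAFETY contract (sketch): `slot` must point to writable, uninitialized
// memory suitable for a `Pair`.
unsafe fn init_pair(slot: *mut Pair, fail: bool) -> Result<(), ()> {
    unsafe { ptr::write(ptr::addr_of_mut!((*slot).a), String::from("a")) };
    let guard_a = FieldGuard(unsafe { ptr::addr_of_mut!((*slot).a) });

    if fail {
        // Early return: `guard_a` drops the already-initialized `a`.
        return Err(());
    }

    unsafe { ptr::write(ptr::addr_of_mut!((*slot).b), String::from("b")) };
    let guard_b = FieldGuard(unsafe { ptr::addr_of_mut!((*slot).b) });

    // Full success: defuse the guards so the fields stay initialized.
    mem::forget(guard_a);
    mem::forget(guard_b);
    Ok(())
}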
- let [< __ $field _guard >] = unsafe { - $crate::__internal::DropGuard::new(::core::ptr::addr_of_mut!((*$slot).$field)) - }; - - $crate::__init_internal!(init_slot($use_data): - @data($data), - @slot($slot), - @guards([< __ $field _guard >], $($guards,)*), - @munch_fields($($rest)*), - ); - } - }; - (init_slot(): // No `use_data`, so we use `Init::__init` directly. - @data($data:ident), - @slot($slot:ident), - @guards($($guards:ident,)*), - // In-place initialization syntax. - @munch_fields($field:ident <- $val:expr, $($rest:tt)*), - ) => { - let init = $val; - // Call the initializer. - // - // SAFETY: `slot` is valid, because we are inside of an initializer closure, we - // return when an error/panic occurs. - unsafe { $crate::Init::__init(init, ::core::ptr::addr_of_mut!((*$slot).$field))? }; - - // SAFETY: - // - the field is not structurally pinned, since the line above must compile, - // - the field has been initialized, - // - the reference is only valid until the end of the initializer. - #[allow(unused_variables)] - let $field = unsafe { &mut (*$slot).$field }; - - // Create the drop guard: - // - // We rely on macro hygiene to make it impossible for users to access this local variable. - // We use `paste!` to create new hygiene for `$field`. - $crate::macros::paste! { - // SAFETY: We forget the guard later when initialization has succeeded. - let [< __ $field _guard >] = unsafe { - $crate::__internal::DropGuard::new(::core::ptr::addr_of_mut!((*$slot).$field)) - }; - - $crate::__init_internal!(init_slot(): - @data($data), - @slot($slot), - @guards([< __ $field _guard >], $($guards,)*), - @munch_fields($($rest)*), - ); - } - }; - (init_slot(): // No `use_data`, so all fields are not structurally pinned - @data($data:ident), - @slot($slot:ident), - @guards($($guards:ident,)*), - // Init by-value. - @munch_fields($field:ident $(: $val:expr)?, $($rest:tt)*), - ) => { - { - $(let $field = $val;)? - // Initialize the field. - // - // SAFETY: The memory at `slot` is uninitialized. - unsafe { ::core::ptr::write(::core::ptr::addr_of_mut!((*$slot).$field), $field) }; - } - - #[allow(unused_variables)] - // SAFETY: - // - the field is not structurally pinned, since no `use_data` was required to create this - // initializer, - // - the field has been initialized, - // - the reference is only valid until the end of the initializer. - let $field = unsafe { &mut (*$slot).$field }; - - // Create the drop guard: - // - // We rely on macro hygiene to make it impossible for users to access this local variable. - // We use `paste!` to create new hygiene for `$field`. - $crate::macros::paste! { - // SAFETY: We forget the guard later when initialization has succeeded. - let [< __ $field _guard >] = unsafe { - $crate::__internal::DropGuard::new(::core::ptr::addr_of_mut!((*$slot).$field)) - }; - - $crate::__init_internal!(init_slot(): - @data($data), - @slot($slot), - @guards([< __ $field _guard >], $($guards,)*), - @munch_fields($($rest)*), - ); - } - }; - (init_slot($use_data:ident): - @data($data:ident), - @slot($slot:ident), - @guards($($guards:ident,)*), - // Init by-value. - @munch_fields($field:ident $(: $val:expr)?, $($rest:tt)*), - ) => { - { - $(let $field = $val;)? - // Initialize the field. - // - // SAFETY: The memory at `slot` is uninitialized. 
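Both field syntaxes handled above routinely appear in the same initializer: `field <- init` consumes another `PinInit`/`Init` in place, while `field: value` writes a plain value. A small sketch using the standalone crate paths (type and field names are made up):

use core::marker::PhantomPinned;
use pin_init::{pin_data, pin_init, PinInit};

#[pin_data]
struct Inner {
    #[pin]
    _pinned: PhantomPinned,
    value: u32,
}

#[pin_data]
struct Outer {
    #[pin]
    inner: Inner,        // initialized in place via `<-`
    label: &'static str, // initialized by value via `:`
}

fn new_outer() -> impl PinInit<Outer> {
    pin_init!(Outer {
        inner <- pin_init!(Inner {
            _pinned: PhantomPinned,
            value: 42,
        }),
        label: "example",
    })
}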
- unsafe { ::core::ptr::write(::core::ptr::addr_of_mut!((*$slot).$field), $field) }; - } - // SAFETY: - // - the project function does the correct field projection, - // - the field has been initialized, - // - the reference is only valid until the end of the initializer. - #[allow(unused_variables)] - let $field = $crate::macros::paste!(unsafe { $data.[< __project_ $field >](&mut (*$slot).$field) }); - - // Create the drop guard: - // - // We rely on macro hygiene to make it impossible for users to access this local variable. - // We use `paste!` to create new hygiene for `$field`. - $crate::macros::paste! { - // SAFETY: We forget the guard later when initialization has succeeded. - let [< __ $field _guard >] = unsafe { - $crate::__internal::DropGuard::new(::core::ptr::addr_of_mut!((*$slot).$field)) - }; - - $crate::__init_internal!(init_slot($use_data): - @data($data), - @slot($slot), - @guards([< __ $field _guard >], $($guards,)*), - @munch_fields($($rest)*), - ); - } - }; - (make_initializer: - @slot($slot:ident), - @type_name($t:path), - @munch_fields(_: { $($code:tt)* }, $($rest:tt)*), - @acc($($acc:tt)*), - ) => { - // code blocks are ignored for the initializer check - $crate::__init_internal!(make_initializer: - @slot($slot), - @type_name($t), - @munch_fields($($rest)*), - @acc($($acc)*), - ); - }; - (make_initializer: - @slot($slot:ident), - @type_name($t:path), - @munch_fields(..Zeroable::init_zeroed() $(,)?), - @acc($($acc:tt)*), - ) => { - // Endpoint, nothing more to munch, create the initializer. Since the users specified - // `..Zeroable::init_zeroed()`, the slot will already have been zeroed and all field that have - // not been overwritten are thus zero and initialized. We still check that all fields are - // actually accessible by using the struct update syntax ourselves. - // We are inside of a closure that is never executed and thus we can abuse `slot` to - // get the correct type inference here: - #[allow(unused_assignments)] - unsafe { - let mut zeroed = ::core::mem::zeroed(); - // We have to use type inference here to make zeroed have the correct type. This does - // not get executed, so it has no effect. - ::core::ptr::write($slot, zeroed); - zeroed = ::core::mem::zeroed(); - // Here we abuse `paste!` to retokenize `$t`. Declarative macros have some internal - // information that is associated to already parsed fragments, so a path fragment - // cannot be used in this position. Doing the retokenization results in valid rust - // code. - $crate::macros::paste!( - ::core::ptr::write($slot, $t { - $($acc)* - ..zeroed - }); - ); - } - }; - (make_initializer: - @slot($slot:ident), - @type_name($t:path), - @munch_fields($(,)?), - @acc($($acc:tt)*), - ) => { - // Endpoint, nothing more to munch, create the initializer. - // Since we are in the closure that is never called, this will never get executed. - // We abuse `slot` to get the correct type inference here: - // - // SAFETY: TODO. - unsafe { - // Here we abuse `paste!` to retokenize `$t`. Declarative macros have some internal - // information that is associated to already parsed fragments, so a path fragment - // cannot be used in this position. Doing the retokenization results in valid rust - // code. 
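The `make_initializer` arms exist purely for compile-time bookkeeping: they build a struct literal inside a closure that is type-checked but never executed, so a forgotten or duplicated field produces the ordinary rustc diagnostics. The same trick in isolation, for a hypothetical struct:

struct Point {
    x: i32,
    y: i32,
}

#[allow(dead_code)]
fn check_all_fields_mentioned(slot: *mut Point) {
    // Never called; it only has to type-check. Leaving out `y`, or naming
    // `x` twice, fails with the usual struct-literal errors.
    #[allow(unreachable_code, clippy::diverging_sub_expression)]
    let _ = || {
        // SAFETY: never executed; present only so the literal is checked
        // against the real definition of `Point`.
        unsafe {
            core::ptr::write(slot, Point {
                x: core::panic!(),
                y: core::panic!(),
            });
        }
    };
}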
- $crate::macros::paste!( - ::core::ptr::write($slot, $t { - $($acc)* - }); - ); - } - }; - (make_initializer: - @slot($slot:ident), - @type_name($t:path), - @munch_fields($field:ident <- $val:expr, $($rest:tt)*), - @acc($($acc:tt)*), - ) => { - $crate::__init_internal!(make_initializer: - @slot($slot), - @type_name($t), - @munch_fields($($rest)*), - @acc($($acc)* $field: ::core::panic!(),), - ); - }; - (make_initializer: - @slot($slot:ident), - @type_name($t:path), - @munch_fields($field:ident $(: $val:expr)?, $($rest:tt)*), - @acc($($acc:tt)*), - ) => { - $crate::__init_internal!(make_initializer: - @slot($slot), - @type_name($t), - @munch_fields($($rest)*), - @acc($($acc)* $field: ::core::panic!(),), - ); - }; -} - -#[doc(hidden)] -#[macro_export] -macro_rules! __derive_zeroable { - (parse_input: - @sig( - $(#[$($struct_attr:tt)*])* - $vis:vis struct $name:ident - $(where $($whr:tt)*)? - ), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @body({ - $( - $(#[$($field_attr:tt)*])* - $field_vis:vis $field:ident : $field_ty:ty - ),* $(,)? - }), - ) => { - // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero. - #[automatically_derived] - unsafe impl<$($impl_generics)*> $crate::Zeroable for $name<$($ty_generics)*> - where - $($($whr)*)? - {} - const _: () = { - fn assert_zeroable<T: ?::core::marker::Sized + $crate::Zeroable>() {} - fn ensure_zeroable<$($impl_generics)*>() - where $($($whr)*)? - { - $(assert_zeroable::<$field_ty>();)* - } - }; - }; - (parse_input: - @sig( - $(#[$($struct_attr:tt)*])* - $vis:vis union $name:ident - $(where $($whr:tt)*)? - ), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @body({ - $( - $(#[$($field_attr:tt)*])* - $field_vis:vis $field:ident : $field_ty:ty - ),* $(,)? - }), - ) => { - // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero. - #[automatically_derived] - unsafe impl<$($impl_generics)*> $crate::Zeroable for $name<$($ty_generics)*> - where - $($($whr)*)? - {} - const _: () = { - fn assert_zeroable<T: ?::core::marker::Sized + $crate::Zeroable>() {} - fn ensure_zeroable<$($impl_generics)*>() - where $($($whr)*)? - { - $(assert_zeroable::<$field_ty>();)* - } - }; - }; -} - -#[doc(hidden)] -#[macro_export] -macro_rules! __maybe_derive_zeroable { - (parse_input: - @sig( - $(#[$($struct_attr:tt)*])* - $vis:vis struct $name:ident - $(where $($whr:tt)*)? - ), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @body({ - $( - $(#[$($field_attr:tt)*])* - $field_vis:vis $field:ident : $field_ty:ty - ),* $(,)? - }), - ) => { - // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero. - #[automatically_derived] - unsafe impl<$($impl_generics)*> $crate::Zeroable for $name<$($ty_generics)*> - where - $( - // the `for<'__dummy>` HRTB makes this not error without the `trivial_bounds` - // feature <https://github.com/rust-lang/rust/issues/48214#issuecomment-2557829956>. - $field_ty: for<'__dummy> $crate::Zeroable, - )* - $($($whr)*)? - {} - }; - (parse_input: - @sig( - $(#[$($struct_attr:tt)*])* - $vis:vis union $name:ident - $(where $($whr:tt)*)? - ), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @body({ - $( - $(#[$($field_attr:tt)*])* - $field_vis:vis $field:ident : $field_ty:ty - ),* $(,)? - }), - ) => { - // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero. 
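`__derive_zeroable!` is the expansion target of `#[derive(Zeroable)]`: it emits the `unsafe impl Zeroable` together with a compile-time check that every field type is itself `Zeroable`. A usage sketch with made-up types (assuming the derive is imported from the crate root, as in-kernel code does):

use pin_init::Zeroable;

// Fine: all field types are `Zeroable`, so the all-zeroes bit pattern is a
// valid `Counters`.
#[derive(Zeroable)]
struct Counters {
    packets: u64,
    bytes: u64,
    drops: u32,
}

// Would not compile: `&'static str` is not `Zeroable` (a null reference is
// never valid), and the generated field check reports it.
//
// #[derive(Zeroable)]
// struct Broken {
//     name: &'static str,
// }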
- #[automatically_derived] - unsafe impl<$($impl_generics)*> $crate::Zeroable for $name<$($ty_generics)*> - where - $( - // the `for<'__dummy>` HRTB makes this not error without the `trivial_bounds` - // feature <https://github.com/rust-lang/rust/issues/48214#issuecomment-2557829956>. - $field_ty: for<'__dummy> $crate::Zeroable, - )* - $($($whr)*)? - {} - }; -} diff --git a/rust/proc-macro2/lib.rs b/rust/proc-macro2/lib.rs index 7b78d065d51c..5d408943fa0d 100644 --- a/rust/proc-macro2/lib.rs +++ b/rust/proc-macro2/lib.rs @@ -1,5 +1,9 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT +// When fixdep scans this, it will find this string `CONFIG_RUSTC_VERSION_TEXT` +// and thus add a dependency on `include/config/RUSTC_VERSION_TEXT`, which is +// touched by Kconfig when the version string from the compiler changes. + //! [![github]](https://github.com/dtolnay/proc-macro2) [![crates-io]](https://crates.io/crates/proc-macro2) [![docs-rs]](crate) //! //! [github]: https://img.shields.io/badge/github-8da0cb?style=for-the-badge&labelColor=555555&logo=github |
