diff options
Diffstat (limited to 'rust')
91 files changed, 3216 insertions, 3517 deletions
diff --git a/rust/Makefile b/rust/Makefile index 4dcc2eff51cb..757974551359 100644 --- a/rust/Makefile +++ b/rust/Makefile @@ -117,6 +117,23 @@ syn-flags := \ --extern quote \ $(call cfgs-to-flags,$(syn-cfgs)) +pin_init_internal-cfgs := \ + kernel + +pin_init_internal-flags := \ + --extern proc_macro2 \ + --extern quote \ + --extern syn \ + $(call cfgs-to-flags,$(pin_init_internal-cfgs)) + +pin_init-cfgs := \ + kernel + +pin_init-flags := \ + --extern pin_init_internal \ + --extern macros \ + $(call cfgs-to-flags,$(pin_init-cfgs)) + # `rustdoc` did not save the target modifiers, thus workaround for # the time being (https://github.com/rust-lang/rust/issues/144521). rustdoc_modifiers_workaround := $(if $(call rustc-min-version,108800),-Cunsafe-allow-abi-mismatch=fixed-x18) @@ -211,15 +228,15 @@ rustdoc-ffi: $(src)/ffi.rs rustdoc-core FORCE +$(call if_changed,rustdoc) rustdoc-pin_init_internal: private rustdoc_host = yes -rustdoc-pin_init_internal: private rustc_target_flags = --cfg kernel \ +rustdoc-pin_init_internal: private rustc_target_flags = $(pin_init_internal-flags) \ --extern proc_macro --crate-type proc-macro rustdoc-pin_init_internal: $(src)/pin-init/internal/src/lib.rs \ - rustdoc-clean FORCE + rustdoc-clean rustdoc-proc_macro2 rustdoc-quote rustdoc-syn FORCE +$(call if_changed,rustdoc) rustdoc-pin_init: private rustdoc_host = yes -rustdoc-pin_init: private rustc_target_flags = --extern pin_init_internal \ - --extern macros --extern alloc --cfg kernel --cfg feature=\"alloc\" +rustdoc-pin_init: private rustc_target_flags = $(pin_init-flags) \ + --extern alloc --cfg feature=\"alloc\" rustdoc-pin_init: $(src)/pin-init/src/lib.rs rustdoc-pin_init_internal \ rustdoc-macros FORCE +$(call if_changed,rustdoc) @@ -272,14 +289,14 @@ rusttestlib-macros: $(src)/macros/lib.rs \ rusttestlib-proc_macro2 rusttestlib-quote rusttestlib-syn FORCE +$(call if_changed,rustc_test_library) -rusttestlib-pin_init_internal: private rustc_target_flags = --cfg kernel \ +rusttestlib-pin_init_internal: private rustc_target_flags = $(pin_init_internal-flags) \ --extern proc_macro rusttestlib-pin_init_internal: private rustc_test_library_proc = yes -rusttestlib-pin_init_internal: $(src)/pin-init/internal/src/lib.rs FORCE +rusttestlib-pin_init_internal: $(src)/pin-init/internal/src/lib.rs \ + rusttestlib-proc_macro2 rusttestlib-quote rusttestlib-syn FORCE +$(call if_changed,rustc_test_library) -rusttestlib-pin_init: private rustc_target_flags = --extern pin_init_internal \ - --extern macros --cfg kernel +rusttestlib-pin_init: private rustc_target_flags = $(pin_init-flags) rusttestlib-pin_init: $(src)/pin-init/src/lib.rs rusttestlib-macros \ rusttestlib-pin_init_internal $(obj)/$(libpin_init_internal_name) FORCE +$(call if_changed,rustc_test_library) @@ -548,8 +565,9 @@ $(obj)/$(libmacros_name): $(src)/macros/lib.rs $(obj)/libproc_macro2.rlib \ $(obj)/libquote.rlib $(obj)/libsyn.rlib FORCE +$(call if_changed_dep,rustc_procmacro) -$(obj)/$(libpin_init_internal_name): private rustc_target_flags = --cfg kernel -$(obj)/$(libpin_init_internal_name): $(src)/pin-init/internal/src/lib.rs FORCE +$(obj)/$(libpin_init_internal_name): private rustc_target_flags = $(pin_init_internal-flags) +$(obj)/$(libpin_init_internal_name): $(src)/pin-init/internal/src/lib.rs \ + $(obj)/libproc_macro2.rlib $(obj)/libquote.rlib $(obj)/libsyn.rlib FORCE +$(call if_changed_dep,rustc_procmacro) quiet_cmd_rustc_library = $(if $(skip_clippy),RUSTC,$(RUSTC_OR_CLIPPY_QUIET)) L $@ @@ -643,8 +661,7 @@ $(obj)/compiler_builtins.o: $(src)/compiler_builtins.rs $(obj)/core.o FORCE +$(call if_changed_rule,rustc_library) $(obj)/pin_init.o: private skip_gendwarfksyms = 1 -$(obj)/pin_init.o: private rustc_target_flags = --extern pin_init_internal \ - --extern macros --cfg kernel +$(obj)/pin_init.o: private rustc_target_flags = $(pin_init-flags) $(obj)/pin_init.o: $(src)/pin-init/src/lib.rs $(obj)/compiler_builtins.o \ $(obj)/$(libpin_init_internal_name) $(obj)/$(libmacros_name) FORCE +$(call if_changed_rule,rustc_library) diff --git a/rust/helpers/atomic.c b/rust/helpers/atomic.c index cf06b7ef9a1c..4b24eceef5fc 100644 --- a/rust/helpers/atomic.c +++ b/rust/helpers/atomic.c @@ -11,11 +11,6 @@ #include <linux/atomic.h> -// TODO: Remove this after INLINE_HELPERS support is added. -#ifndef __rust_helper -#define __rust_helper -#endif - __rust_helper int rust_helper_atomic_read(const atomic_t *v) { @@ -1037,4 +1032,4 @@ rust_helper_atomic64_dec_if_positive(atomic64_t *v) } #endif /* _RUST_ATOMIC_API_H */ -// 615a0e0c98b5973a47fe4fa65e92935051ca00ed +// e4edb6174dd42a265284958f00a7cea7ddb464b1 diff --git a/rust/helpers/atomic_ext.c b/rust/helpers/atomic_ext.c new file mode 100644 index 000000000000..7d0c2bd340da --- /dev/null +++ b/rust/helpers/atomic_ext.c @@ -0,0 +1,139 @@ +// SPDX-License-Identifier: GPL-2.0 + +#include <asm/barrier.h> +#include <asm/rwonce.h> +#include <linux/atomic.h> + +__rust_helper s8 rust_helper_atomic_i8_read(s8 *ptr) +{ + return READ_ONCE(*ptr); +} + +__rust_helper s8 rust_helper_atomic_i8_read_acquire(s8 *ptr) +{ + return smp_load_acquire(ptr); +} + +__rust_helper s16 rust_helper_atomic_i16_read(s16 *ptr) +{ + return READ_ONCE(*ptr); +} + +__rust_helper s16 rust_helper_atomic_i16_read_acquire(s16 *ptr) +{ + return smp_load_acquire(ptr); +} + +__rust_helper void rust_helper_atomic_i8_set(s8 *ptr, s8 val) +{ + WRITE_ONCE(*ptr, val); +} + +__rust_helper void rust_helper_atomic_i8_set_release(s8 *ptr, s8 val) +{ + smp_store_release(ptr, val); +} + +__rust_helper void rust_helper_atomic_i16_set(s16 *ptr, s16 val) +{ + WRITE_ONCE(*ptr, val); +} + +__rust_helper void rust_helper_atomic_i16_set_release(s16 *ptr, s16 val) +{ + smp_store_release(ptr, val); +} + +/* + * xchg helpers depend on ARCH_SUPPORTS_ATOMIC_RMW and on the + * architecture provding xchg() support for i8 and i16. + * + * The architectures that currently support Rust (x86_64, armv7, + * arm64, riscv, and loongarch) satisfy these requirements. + */ +__rust_helper s8 rust_helper_atomic_i8_xchg(s8 *ptr, s8 new) +{ + return xchg(ptr, new); +} + +__rust_helper s16 rust_helper_atomic_i16_xchg(s16 *ptr, s16 new) +{ + return xchg(ptr, new); +} + +__rust_helper s8 rust_helper_atomic_i8_xchg_acquire(s8 *ptr, s8 new) +{ + return xchg_acquire(ptr, new); +} + +__rust_helper s16 rust_helper_atomic_i16_xchg_acquire(s16 *ptr, s16 new) +{ + return xchg_acquire(ptr, new); +} + +__rust_helper s8 rust_helper_atomic_i8_xchg_release(s8 *ptr, s8 new) +{ + return xchg_release(ptr, new); +} + +__rust_helper s16 rust_helper_atomic_i16_xchg_release(s16 *ptr, s16 new) +{ + return xchg_release(ptr, new); +} + +__rust_helper s8 rust_helper_atomic_i8_xchg_relaxed(s8 *ptr, s8 new) +{ + return xchg_relaxed(ptr, new); +} + +__rust_helper s16 rust_helper_atomic_i16_xchg_relaxed(s16 *ptr, s16 new) +{ + return xchg_relaxed(ptr, new); +} + +/* + * try_cmpxchg helpers depend on ARCH_SUPPORTS_ATOMIC_RMW and on the + * architecture provding try_cmpxchg() support for i8 and i16. + * + * The architectures that currently support Rust (x86_64, armv7, + * arm64, riscv, and loongarch) satisfy these requirements. + */ +__rust_helper bool rust_helper_atomic_i8_try_cmpxchg(s8 *ptr, s8 *old, s8 new) +{ + return try_cmpxchg(ptr, old, new); +} + +__rust_helper bool rust_helper_atomic_i16_try_cmpxchg(s16 *ptr, s16 *old, s16 new) +{ + return try_cmpxchg(ptr, old, new); +} + +__rust_helper bool rust_helper_atomic_i8_try_cmpxchg_acquire(s8 *ptr, s8 *old, s8 new) +{ + return try_cmpxchg_acquire(ptr, old, new); +} + +__rust_helper bool rust_helper_atomic_i16_try_cmpxchg_acquire(s16 *ptr, s16 *old, s16 new) +{ + return try_cmpxchg_acquire(ptr, old, new); +} + +__rust_helper bool rust_helper_atomic_i8_try_cmpxchg_release(s8 *ptr, s8 *old, s8 new) +{ + return try_cmpxchg_release(ptr, old, new); +} + +__rust_helper bool rust_helper_atomic_i16_try_cmpxchg_release(s16 *ptr, s16 *old, s16 new) +{ + return try_cmpxchg_release(ptr, old, new); +} + +__rust_helper bool rust_helper_atomic_i8_try_cmpxchg_relaxed(s8 *ptr, s8 *old, s8 new) +{ + return try_cmpxchg_relaxed(ptr, old, new); +} + +__rust_helper bool rust_helper_atomic_i16_try_cmpxchg_relaxed(s16 *ptr, s16 *old, s16 new) +{ + return try_cmpxchg_relaxed(ptr, old, new); +} diff --git a/rust/helpers/barrier.c b/rust/helpers/barrier.c index cdf28ce8e511..fed8853745c8 100644 --- a/rust/helpers/barrier.c +++ b/rust/helpers/barrier.c @@ -2,17 +2,17 @@ #include <asm/barrier.h> -void rust_helper_smp_mb(void) +__rust_helper void rust_helper_smp_mb(void) { smp_mb(); } -void rust_helper_smp_wmb(void) +__rust_helper void rust_helper_smp_wmb(void) { smp_wmb(); } -void rust_helper_smp_rmb(void) +__rust_helper void rust_helper_smp_rmb(void) { smp_rmb(); } diff --git a/rust/helpers/bitmap.c b/rust/helpers/bitmap.c index a50e2f082e47..e4e9f4361270 100644 --- a/rust/helpers/bitmap.c +++ b/rust/helpers/bitmap.c @@ -2,6 +2,7 @@ #include <linux/bitmap.h> +__rust_helper void rust_helper_bitmap_copy_and_extend(unsigned long *to, const unsigned long *from, unsigned int count, unsigned int size) { diff --git a/rust/helpers/bitops.c b/rust/helpers/bitops.c index e79ef9e6d98f..271b8a712dee 100644 --- a/rust/helpers/bitops.c +++ b/rust/helpers/bitops.c @@ -3,21 +3,25 @@ #include <linux/bitops.h> #include <linux/find.h> +__rust_helper void rust_helper___set_bit(unsigned long nr, unsigned long *addr) { __set_bit(nr, addr); } +__rust_helper void rust_helper___clear_bit(unsigned long nr, unsigned long *addr) { __clear_bit(nr, addr); } +__rust_helper void rust_helper_set_bit(unsigned long nr, volatile unsigned long *addr) { set_bit(nr, addr); } +__rust_helper void rust_helper_clear_bit(unsigned long nr, volatile unsigned long *addr) { clear_bit(nr, addr); diff --git a/rust/helpers/blk.c b/rust/helpers/blk.c index cc9f4e6a2d23..20c512e46a7a 100644 --- a/rust/helpers/blk.c +++ b/rust/helpers/blk.c @@ -3,12 +3,12 @@ #include <linux/blk-mq.h> #include <linux/blkdev.h> -void *rust_helper_blk_mq_rq_to_pdu(struct request *rq) +__rust_helper void *rust_helper_blk_mq_rq_to_pdu(struct request *rq) { return blk_mq_rq_to_pdu(rq); } -struct request *rust_helper_blk_mq_rq_from_pdu(void *pdu) +__rust_helper struct request *rust_helper_blk_mq_rq_from_pdu(void *pdu) { return blk_mq_rq_from_pdu(pdu); } diff --git a/rust/helpers/bug.c b/rust/helpers/bug.c index a62c96f507d1..b51e60772578 100644 --- a/rust/helpers/bug.c +++ b/rust/helpers/bug.c @@ -2,12 +2,12 @@ #include <linux/bug.h> -__noreturn void rust_helper_BUG(void) +__rust_helper __noreturn void rust_helper_BUG(void) { BUG(); } -bool rust_helper_WARN_ON(bool cond) +__rust_helper bool rust_helper_WARN_ON(bool cond) { return WARN_ON(cond); } diff --git a/rust/helpers/build_bug.c b/rust/helpers/build_bug.c index 44e579488037..14dbc55bb539 100644 --- a/rust/helpers/build_bug.c +++ b/rust/helpers/build_bug.c @@ -2,7 +2,7 @@ #include <linux/errname.h> -const char *rust_helper_errname(int err) +__rust_helper const char *rust_helper_errname(int err) { return errname(err); } diff --git a/rust/helpers/completion.c b/rust/helpers/completion.c index b2443262a2ae..0126767cc3be 100644 --- a/rust/helpers/completion.c +++ b/rust/helpers/completion.c @@ -2,7 +2,7 @@ #include <linux/completion.h> -void rust_helper_init_completion(struct completion *x) +__rust_helper void rust_helper_init_completion(struct completion *x) { init_completion(x); } diff --git a/rust/helpers/cpu.c b/rust/helpers/cpu.c index 824e0adb19d4..5759349b2c88 100644 --- a/rust/helpers/cpu.c +++ b/rust/helpers/cpu.c @@ -2,7 +2,7 @@ #include <linux/smp.h> -unsigned int rust_helper_raw_smp_processor_id(void) +__rust_helper unsigned int rust_helper_raw_smp_processor_id(void) { return raw_smp_processor_id(); } diff --git a/rust/helpers/cpufreq.c b/rust/helpers/cpufreq.c index 7c1343c4d65e..0e16aeef2b5a 100644 --- a/rust/helpers/cpufreq.c +++ b/rust/helpers/cpufreq.c @@ -3,7 +3,8 @@ #include <linux/cpufreq.h> #ifdef CONFIG_CPU_FREQ -void rust_helper_cpufreq_register_em_with_opp(struct cpufreq_policy *policy) +__rust_helper void +rust_helper_cpufreq_register_em_with_opp(struct cpufreq_policy *policy) { cpufreq_register_em_with_opp(policy); } diff --git a/rust/helpers/cpumask.c b/rust/helpers/cpumask.c index eb10598a0242..5deced5b975e 100644 --- a/rust/helpers/cpumask.c +++ b/rust/helpers/cpumask.c @@ -2,67 +2,80 @@ #include <linux/cpumask.h> +__rust_helper void rust_helper_cpumask_set_cpu(unsigned int cpu, struct cpumask *dstp) { cpumask_set_cpu(cpu, dstp); } +__rust_helper void rust_helper___cpumask_set_cpu(unsigned int cpu, struct cpumask *dstp) { __cpumask_set_cpu(cpu, dstp); } +__rust_helper void rust_helper_cpumask_clear_cpu(int cpu, struct cpumask *dstp) { cpumask_clear_cpu(cpu, dstp); } +__rust_helper void rust_helper___cpumask_clear_cpu(int cpu, struct cpumask *dstp) { __cpumask_clear_cpu(cpu, dstp); } +__rust_helper bool rust_helper_cpumask_test_cpu(int cpu, struct cpumask *srcp) { return cpumask_test_cpu(cpu, srcp); } +__rust_helper void rust_helper_cpumask_setall(struct cpumask *dstp) { cpumask_setall(dstp); } +__rust_helper bool rust_helper_cpumask_empty(struct cpumask *srcp) { return cpumask_empty(srcp); } +__rust_helper bool rust_helper_cpumask_full(struct cpumask *srcp) { return cpumask_full(srcp); } +__rust_helper unsigned int rust_helper_cpumask_weight(struct cpumask *srcp) { return cpumask_weight(srcp); } +__rust_helper void rust_helper_cpumask_copy(struct cpumask *dstp, const struct cpumask *srcp) { cpumask_copy(dstp, srcp); } +__rust_helper bool rust_helper_alloc_cpumask_var(cpumask_var_t *mask, gfp_t flags) { return alloc_cpumask_var(mask, flags); } +__rust_helper bool rust_helper_zalloc_cpumask_var(cpumask_var_t *mask, gfp_t flags) { return zalloc_cpumask_var(mask, flags); } #ifndef CONFIG_CPUMASK_OFFSTACK +__rust_helper void rust_helper_free_cpumask_var(cpumask_var_t mask) { free_cpumask_var(mask); diff --git a/rust/helpers/cred.c b/rust/helpers/cred.c index fde7ae20cdd1..a56a7b753623 100644 --- a/rust/helpers/cred.c +++ b/rust/helpers/cred.c @@ -2,12 +2,12 @@ #include <linux/cred.h> -const struct cred *rust_helper_get_cred(const struct cred *cred) +__rust_helper const struct cred *rust_helper_get_cred(const struct cred *cred) { return get_cred(cred); } -void rust_helper_put_cred(const struct cred *cred) +__rust_helper void rust_helper_put_cred(const struct cred *cred) { put_cred(cred); } diff --git a/rust/helpers/err.c b/rust/helpers/err.c index 544c7cb86632..2872158e3793 100644 --- a/rust/helpers/err.c +++ b/rust/helpers/err.c @@ -2,17 +2,17 @@ #include <linux/err.h> -__force void *rust_helper_ERR_PTR(long err) +__rust_helper __force void *rust_helper_ERR_PTR(long err) { return ERR_PTR(err); } -bool rust_helper_IS_ERR(__force const void *ptr) +__rust_helper bool rust_helper_IS_ERR(__force const void *ptr) { return IS_ERR(ptr); } -long rust_helper_PTR_ERR(__force const void *ptr) +__rust_helper long rust_helper_PTR_ERR(__force const void *ptr) { return PTR_ERR(ptr); } diff --git a/rust/helpers/fs.c b/rust/helpers/fs.c index a75c96763372..789d60fb8908 100644 --- a/rust/helpers/fs.c +++ b/rust/helpers/fs.c @@ -6,7 +6,7 @@ #include <linux/fs.h> -struct file *rust_helper_get_file(struct file *f) +__rust_helper struct file *rust_helper_get_file(struct file *f) { return get_file(f); } diff --git a/rust/helpers/helpers.c b/rust/helpers/helpers.c index 79c72762ad9c..a3c42e51f00a 100644 --- a/rust/helpers/helpers.c +++ b/rust/helpers/helpers.c @@ -7,7 +7,10 @@ * Sorted alphabetically. */ +#define __rust_helper + #include "atomic.c" +#include "atomic_ext.c" #include "auxiliary.c" #include "barrier.c" #include "binder.c" diff --git a/rust/helpers/kunit.c b/rust/helpers/kunit.c index b85a4d394c11..cafb94b6776c 100644 --- a/rust/helpers/kunit.c +++ b/rust/helpers/kunit.c @@ -2,7 +2,7 @@ #include <kunit/test-bug.h> -struct kunit *rust_helper_kunit_get_current_test(void) +__rust_helper struct kunit *rust_helper_kunit_get_current_test(void) { return kunit_get_current_test(); } diff --git a/rust/helpers/maple_tree.c b/rust/helpers/maple_tree.c index 1dd9ac84a13f..5586486a76e0 100644 --- a/rust/helpers/maple_tree.c +++ b/rust/helpers/maple_tree.c @@ -2,7 +2,8 @@ #include <linux/maple_tree.h> -void rust_helper_mt_init_flags(struct maple_tree *mt, unsigned int flags) +__rust_helper void rust_helper_mt_init_flags(struct maple_tree *mt, + unsigned int flags) { mt_init_flags(mt, flags); } diff --git a/rust/helpers/mm.c b/rust/helpers/mm.c index 81b510c96fd2..b5540997bd20 100644 --- a/rust/helpers/mm.c +++ b/rust/helpers/mm.c @@ -3,48 +3,48 @@ #include <linux/mm.h> #include <linux/sched/mm.h> -void rust_helper_mmgrab(struct mm_struct *mm) +__rust_helper void rust_helper_mmgrab(struct mm_struct *mm) { mmgrab(mm); } -void rust_helper_mmdrop(struct mm_struct *mm) +__rust_helper void rust_helper_mmdrop(struct mm_struct *mm) { mmdrop(mm); } -void rust_helper_mmget(struct mm_struct *mm) +__rust_helper void rust_helper_mmget(struct mm_struct *mm) { mmget(mm); } -bool rust_helper_mmget_not_zero(struct mm_struct *mm) +__rust_helper bool rust_helper_mmget_not_zero(struct mm_struct *mm) { return mmget_not_zero(mm); } -void rust_helper_mmap_read_lock(struct mm_struct *mm) +__rust_helper void rust_helper_mmap_read_lock(struct mm_struct *mm) { mmap_read_lock(mm); } -bool rust_helper_mmap_read_trylock(struct mm_struct *mm) +__rust_helper bool rust_helper_mmap_read_trylock(struct mm_struct *mm) { return mmap_read_trylock(mm); } -void rust_helper_mmap_read_unlock(struct mm_struct *mm) +__rust_helper void rust_helper_mmap_read_unlock(struct mm_struct *mm) { mmap_read_unlock(mm); } -struct vm_area_struct *rust_helper_vma_lookup(struct mm_struct *mm, - unsigned long addr) +__rust_helper struct vm_area_struct * +rust_helper_vma_lookup(struct mm_struct *mm, unsigned long addr) { return vma_lookup(mm, addr); } -void rust_helper_vma_end_read(struct vm_area_struct *vma) +__rust_helper void rust_helper_vma_end_read(struct vm_area_struct *vma) { vma_end_read(vma); } diff --git a/rust/helpers/mutex.c b/rust/helpers/mutex.c index e487819125f0..1b07d6e64299 100644 --- a/rust/helpers/mutex.c +++ b/rust/helpers/mutex.c @@ -2,28 +2,29 @@ #include <linux/mutex.h> -void rust_helper_mutex_lock(struct mutex *lock) +__rust_helper void rust_helper_mutex_lock(struct mutex *lock) { mutex_lock(lock); } -int rust_helper_mutex_trylock(struct mutex *lock) +__rust_helper int rust_helper_mutex_trylock(struct mutex *lock) { return mutex_trylock(lock); } -void rust_helper___mutex_init(struct mutex *mutex, const char *name, - struct lock_class_key *key) +__rust_helper void rust_helper___mutex_init(struct mutex *mutex, + const char *name, + struct lock_class_key *key) { __mutex_init(mutex, name, key); } -void rust_helper_mutex_assert_is_held(struct mutex *mutex) +__rust_helper void rust_helper_mutex_assert_is_held(struct mutex *mutex) { lockdep_assert_held(mutex); } -void rust_helper_mutex_destroy(struct mutex *lock) +__rust_helper void rust_helper_mutex_destroy(struct mutex *lock) { mutex_destroy(lock); } diff --git a/rust/helpers/of.c b/rust/helpers/of.c index 86b51167c913..8f62ca69e8ba 100644 --- a/rust/helpers/of.c +++ b/rust/helpers/of.c @@ -2,7 +2,7 @@ #include <linux/of.h> -bool rust_helper_is_of_node(const struct fwnode_handle *fwnode) +__rust_helper bool rust_helper_is_of_node(const struct fwnode_handle *fwnode) { return is_of_node(fwnode); } diff --git a/rust/helpers/page.c b/rust/helpers/page.c index 7144de5a61db..f8463fbed2a2 100644 --- a/rust/helpers/page.c +++ b/rust/helpers/page.c @@ -4,23 +4,24 @@ #include <linux/highmem.h> #include <linux/mm.h> -struct page *rust_helper_alloc_pages(gfp_t gfp_mask, unsigned int order) +__rust_helper struct page *rust_helper_alloc_pages(gfp_t gfp_mask, + unsigned int order) { return alloc_pages(gfp_mask, order); } -void *rust_helper_kmap_local_page(struct page *page) +__rust_helper void *rust_helper_kmap_local_page(struct page *page) { return kmap_local_page(page); } -void rust_helper_kunmap_local(const void *addr) +__rust_helper void rust_helper_kunmap_local(const void *addr) { kunmap_local(addr); } #ifndef NODE_NOT_IN_PAGE_FLAGS -int rust_helper_page_to_nid(const struct page *page) +__rust_helper int rust_helper_page_to_nid(const struct page *page) { return page_to_nid(page); } diff --git a/rust/helpers/pid_namespace.c b/rust/helpers/pid_namespace.c index f41482bdec9a..f46ab779b527 100644 --- a/rust/helpers/pid_namespace.c +++ b/rust/helpers/pid_namespace.c @@ -3,18 +3,20 @@ #include <linux/pid_namespace.h> #include <linux/cleanup.h> -struct pid_namespace *rust_helper_get_pid_ns(struct pid_namespace *ns) +__rust_helper struct pid_namespace * +rust_helper_get_pid_ns(struct pid_namespace *ns) { return get_pid_ns(ns); } -void rust_helper_put_pid_ns(struct pid_namespace *ns) +__rust_helper void rust_helper_put_pid_ns(struct pid_namespace *ns) { put_pid_ns(ns); } /* Get a reference on a task's pid namespace. */ -struct pid_namespace *rust_helper_task_get_pid_ns(struct task_struct *task) +__rust_helper struct pid_namespace * +rust_helper_task_get_pid_ns(struct task_struct *task) { struct pid_namespace *pid_ns; diff --git a/rust/helpers/poll.c b/rust/helpers/poll.c index 7e5b1751c2d5..78b3839b50f0 100644 --- a/rust/helpers/poll.c +++ b/rust/helpers/poll.c @@ -3,8 +3,9 @@ #include <linux/export.h> #include <linux/poll.h> -void rust_helper_poll_wait(struct file *filp, wait_queue_head_t *wait_address, - poll_table *p) +__rust_helper void rust_helper_poll_wait(struct file *filp, + wait_queue_head_t *wait_address, + poll_table *p) { poll_wait(filp, wait_address, p); } diff --git a/rust/helpers/processor.c b/rust/helpers/processor.c index d41355e14d6e..76fadbb647c5 100644 --- a/rust/helpers/processor.c +++ b/rust/helpers/processor.c @@ -2,7 +2,7 @@ #include <linux/processor.h> -void rust_helper_cpu_relax(void) +__rust_helper void rust_helper_cpu_relax(void) { cpu_relax(); } diff --git a/rust/helpers/rbtree.c b/rust/helpers/rbtree.c index 2a0eabbb4160..a85defb22ff7 100644 --- a/rust/helpers/rbtree.c +++ b/rust/helpers/rbtree.c @@ -2,18 +2,19 @@ #include <linux/rbtree.h> -void rust_helper_rb_link_node(struct rb_node *node, struct rb_node *parent, - struct rb_node **rb_link) +__rust_helper void rust_helper_rb_link_node(struct rb_node *node, + struct rb_node *parent, + struct rb_node **rb_link) { rb_link_node(node, parent, rb_link); } -struct rb_node *rust_helper_rb_first(const struct rb_root *root) +__rust_helper struct rb_node *rust_helper_rb_first(const struct rb_root *root) { return rb_first(root); } -struct rb_node *rust_helper_rb_last(const struct rb_root *root) +__rust_helper struct rb_node *rust_helper_rb_last(const struct rb_root *root) { return rb_last(root); } diff --git a/rust/helpers/rcu.c b/rust/helpers/rcu.c index f1cec6583513..481274c05857 100644 --- a/rust/helpers/rcu.c +++ b/rust/helpers/rcu.c @@ -2,12 +2,12 @@ #include <linux/rcupdate.h> -void rust_helper_rcu_read_lock(void) +__rust_helper void rust_helper_rcu_read_lock(void) { rcu_read_lock(); } -void rust_helper_rcu_read_unlock(void) +__rust_helper void rust_helper_rcu_read_unlock(void) { rcu_read_unlock(); } diff --git a/rust/helpers/refcount.c b/rust/helpers/refcount.c index d175898ad7b8..36334a674ee4 100644 --- a/rust/helpers/refcount.c +++ b/rust/helpers/refcount.c @@ -2,27 +2,27 @@ #include <linux/refcount.h> -refcount_t rust_helper_REFCOUNT_INIT(int n) +__rust_helper refcount_t rust_helper_REFCOUNT_INIT(int n) { return (refcount_t)REFCOUNT_INIT(n); } -void rust_helper_refcount_set(refcount_t *r, int n) +__rust_helper void rust_helper_refcount_set(refcount_t *r, int n) { refcount_set(r, n); } -void rust_helper_refcount_inc(refcount_t *r) +__rust_helper void rust_helper_refcount_inc(refcount_t *r) { refcount_inc(r); } -void rust_helper_refcount_dec(refcount_t *r) +__rust_helper void rust_helper_refcount_dec(refcount_t *r) { refcount_dec(r); } -bool rust_helper_refcount_dec_and_test(refcount_t *r) +__rust_helper bool rust_helper_refcount_dec_and_test(refcount_t *r) { return refcount_dec_and_test(r); } diff --git a/rust/helpers/security.c b/rust/helpers/security.c index ca22da09548d..8d0a25fcf931 100644 --- a/rust/helpers/security.c +++ b/rust/helpers/security.c @@ -3,41 +3,45 @@ #include <linux/security.h> #ifndef CONFIG_SECURITY -void rust_helper_security_cred_getsecid(const struct cred *c, u32 *secid) +__rust_helper void rust_helper_security_cred_getsecid(const struct cred *c, + u32 *secid) { security_cred_getsecid(c, secid); } -int rust_helper_security_secid_to_secctx(u32 secid, struct lsm_context *cp) +__rust_helper int rust_helper_security_secid_to_secctx(u32 secid, + struct lsm_context *cp) { return security_secid_to_secctx(secid, cp); } -void rust_helper_security_release_secctx(struct lsm_context *cp) +__rust_helper void rust_helper_security_release_secctx(struct lsm_context *cp) { security_release_secctx(cp); } -int rust_helper_security_binder_set_context_mgr(const struct cred *mgr) +__rust_helper int +rust_helper_security_binder_set_context_mgr(const struct cred *mgr) { return security_binder_set_context_mgr(mgr); } -int rust_helper_security_binder_transaction(const struct cred *from, - const struct cred *to) +__rust_helper int +rust_helper_security_binder_transaction(const struct cred *from, + const struct cred *to) { return security_binder_transaction(from, to); } -int rust_helper_security_binder_transfer_binder(const struct cred *from, - const struct cred *to) +__rust_helper int +rust_helper_security_binder_transfer_binder(const struct cred *from, + const struct cred *to) { return security_binder_transfer_binder(from, to); } -int rust_helper_security_binder_transfer_file(const struct cred *from, - const struct cred *to, - const struct file *file) +__rust_helper int rust_helper_security_binder_transfer_file( + const struct cred *from, const struct cred *to, const struct file *file) { return security_binder_transfer_file(from, to, file); } diff --git a/rust/helpers/signal.c b/rust/helpers/signal.c index 1a6bbe9438e2..85111186cf3d 100644 --- a/rust/helpers/signal.c +++ b/rust/helpers/signal.c @@ -2,7 +2,7 @@ #include <linux/sched/signal.h> -int rust_helper_signal_pending(struct task_struct *t) +__rust_helper int rust_helper_signal_pending(struct task_struct *t) { return signal_pending(t); } diff --git a/rust/helpers/slab.c b/rust/helpers/slab.c index 7fac958907b0..9279f082467d 100644 --- a/rust/helpers/slab.c +++ b/rust/helpers/slab.c @@ -2,14 +2,14 @@ #include <linux/slab.h> -void * __must_check __realloc_size(2) +__rust_helper void *__must_check __realloc_size(2) rust_helper_krealloc_node_align(const void *objp, size_t new_size, unsigned long align, gfp_t flags, int node) { return krealloc_node_align(objp, new_size, align, flags, node); } -void * __must_check __realloc_size(2) +__rust_helper void *__must_check __realloc_size(2) rust_helper_kvrealloc_node_align(const void *p, size_t size, unsigned long align, gfp_t flags, int node) { diff --git a/rust/helpers/spinlock.c b/rust/helpers/spinlock.c index 42c4bf01a23e..4d13062cf253 100644 --- a/rust/helpers/spinlock.c +++ b/rust/helpers/spinlock.c @@ -2,8 +2,9 @@ #include <linux/spinlock.h> -void rust_helper___spin_lock_init(spinlock_t *lock, const char *name, - struct lock_class_key *key) +__rust_helper void rust_helper___spin_lock_init(spinlock_t *lock, + const char *name, + struct lock_class_key *key) { #ifdef CONFIG_DEBUG_SPINLOCK # if defined(CONFIG_PREEMPT_RT) @@ -16,22 +17,22 @@ void rust_helper___spin_lock_init(spinlock_t *lock, const char *name, #endif /* CONFIG_DEBUG_SPINLOCK */ } -void rust_helper_spin_lock(spinlock_t *lock) +__rust_helper void rust_helper_spin_lock(spinlock_t *lock) { spin_lock(lock); } -void rust_helper_spin_unlock(spinlock_t *lock) +__rust_helper void rust_helper_spin_unlock(spinlock_t *lock) { spin_unlock(lock); } -int rust_helper_spin_trylock(spinlock_t *lock) +__rust_helper int rust_helper_spin_trylock(spinlock_t *lock) { return spin_trylock(lock); } -void rust_helper_spin_assert_is_held(spinlock_t *lock) +__rust_helper void rust_helper_spin_assert_is_held(spinlock_t *lock) { lockdep_assert_held(lock); } diff --git a/rust/helpers/sync.c b/rust/helpers/sync.c index ff7e68b48810..82d6aff73b04 100644 --- a/rust/helpers/sync.c +++ b/rust/helpers/sync.c @@ -2,12 +2,12 @@ #include <linux/lockdep.h> -void rust_helper_lockdep_register_key(struct lock_class_key *k) +__rust_helper void rust_helper_lockdep_register_key(struct lock_class_key *k) { lockdep_register_key(k); } -void rust_helper_lockdep_unregister_key(struct lock_class_key *k) +__rust_helper void rust_helper_lockdep_unregister_key(struct lock_class_key *k) { lockdep_unregister_key(k); } diff --git a/rust/helpers/task.c b/rust/helpers/task.c index 2c85bbc2727e..c0e1a06ede78 100644 --- a/rust/helpers/task.c +++ b/rust/helpers/task.c @@ -3,60 +3,60 @@ #include <linux/kernel.h> #include <linux/sched/task.h> -void rust_helper_might_resched(void) +__rust_helper void rust_helper_might_resched(void) { might_resched(); } -struct task_struct *rust_helper_get_current(void) +__rust_helper struct task_struct *rust_helper_get_current(void) { return current; } -void rust_helper_get_task_struct(struct task_struct *t) +__rust_helper void rust_helper_get_task_struct(struct task_struct *t) { get_task_struct(t); } -void rust_helper_put_task_struct(struct task_struct *t) +__rust_helper void rust_helper_put_task_struct(struct task_struct *t) { put_task_struct(t); } -kuid_t rust_helper_task_uid(struct task_struct *task) +__rust_helper kuid_t rust_helper_task_uid(struct task_struct *task) { return task_uid(task); } -kuid_t rust_helper_task_euid(struct task_struct *task) +__rust_helper kuid_t rust_helper_task_euid(struct task_struct *task) { return task_euid(task); } #ifndef CONFIG_USER_NS -uid_t rust_helper_from_kuid(struct user_namespace *to, kuid_t uid) +__rust_helper uid_t rust_helper_from_kuid(struct user_namespace *to, kuid_t uid) { return from_kuid(to, uid); } #endif /* CONFIG_USER_NS */ -bool rust_helper_uid_eq(kuid_t left, kuid_t right) +__rust_helper bool rust_helper_uid_eq(kuid_t left, kuid_t right) { return uid_eq(left, right); } -kuid_t rust_helper_current_euid(void) +__rust_helper kuid_t rust_helper_current_euid(void) { return current_euid(); } -struct user_namespace *rust_helper_current_user_ns(void) +__rust_helper struct user_namespace *rust_helper_current_user_ns(void) { return current_user_ns(); } -pid_t rust_helper_task_tgid_nr_ns(struct task_struct *tsk, - struct pid_namespace *ns) +__rust_helper pid_t rust_helper_task_tgid_nr_ns(struct task_struct *tsk, + struct pid_namespace *ns) { return task_tgid_nr_ns(tsk, ns); } diff --git a/rust/helpers/time.c b/rust/helpers/time.c index 67a36ccc3ec4..32f495970493 100644 --- a/rust/helpers/time.c +++ b/rust/helpers/time.c @@ -4,37 +4,37 @@ #include <linux/ktime.h> #include <linux/timekeeping.h> -void rust_helper_fsleep(unsigned long usecs) +__rust_helper void rust_helper_fsleep(unsigned long usecs) { fsleep(usecs); } -ktime_t rust_helper_ktime_get_real(void) +__rust_helper ktime_t rust_helper_ktime_get_real(void) { return ktime_get_real(); } -ktime_t rust_helper_ktime_get_boottime(void) +__rust_helper ktime_t rust_helper_ktime_get_boottime(void) { return ktime_get_boottime(); } -ktime_t rust_helper_ktime_get_clocktai(void) +__rust_helper ktime_t rust_helper_ktime_get_clocktai(void) { return ktime_get_clocktai(); } -s64 rust_helper_ktime_to_us(const ktime_t kt) +__rust_helper s64 rust_helper_ktime_to_us(const ktime_t kt) { return ktime_to_us(kt); } -s64 rust_helper_ktime_to_ms(const ktime_t kt) +__rust_helper s64 rust_helper_ktime_to_ms(const ktime_t kt) { return ktime_to_ms(kt); } -void rust_helper_udelay(unsigned long usec) +__rust_helper void rust_helper_udelay(unsigned long usec) { udelay(usec); } diff --git a/rust/helpers/uaccess.c b/rust/helpers/uaccess.c index 4629b2d15529..d9625b9ee046 100644 --- a/rust/helpers/uaccess.c +++ b/rust/helpers/uaccess.c @@ -2,24 +2,26 @@ #include <linux/uaccess.h> -unsigned long rust_helper_copy_from_user(void *to, const void __user *from, - unsigned long n) +__rust_helper unsigned long +rust_helper_copy_from_user(void *to, const void __user *from, unsigned long n) { return copy_from_user(to, from, n); } -unsigned long rust_helper_copy_to_user(void __user *to, const void *from, - unsigned long n) +__rust_helper unsigned long +rust_helper_copy_to_user(void __user *to, const void *from, unsigned long n) { return copy_to_user(to, from, n); } #ifdef INLINE_COPY_FROM_USER +__rust_helper unsigned long rust_helper__copy_from_user(void *to, const void __user *from, unsigned long n) { return _inline_copy_from_user(to, from, n); } +__rust_helper unsigned long rust_helper__copy_to_user(void __user *to, const void *from, unsigned long n) { return _inline_copy_to_user(to, from, n); diff --git a/rust/helpers/vmalloc.c b/rust/helpers/vmalloc.c index 7d7f7336b3d2..326b030487a2 100644 --- a/rust/helpers/vmalloc.c +++ b/rust/helpers/vmalloc.c @@ -2,7 +2,7 @@ #include <linux/vmalloc.h> -void * __must_check __realloc_size(2) +__rust_helper void *__must_check __realloc_size(2) rust_helper_vrealloc_node_align(const void *p, size_t size, unsigned long align, gfp_t flags, int node) { diff --git a/rust/helpers/wait.c b/rust/helpers/wait.c index ae48e33d9da3..2dde1e451780 100644 --- a/rust/helpers/wait.c +++ b/rust/helpers/wait.c @@ -2,7 +2,7 @@ #include <linux/wait.h> -void rust_helper_init_wait(struct wait_queue_entry *wq_entry) +__rust_helper void rust_helper_init_wait(struct wait_queue_entry *wq_entry) { init_wait(wq_entry); } diff --git a/rust/helpers/workqueue.c b/rust/helpers/workqueue.c index b2b82753509b..ce1c3a5b2150 100644 --- a/rust/helpers/workqueue.c +++ b/rust/helpers/workqueue.c @@ -2,9 +2,11 @@ #include <linux/workqueue.h> -void rust_helper_init_work_with_key(struct work_struct *work, work_func_t func, - bool onstack, const char *name, - struct lock_class_key *key) +__rust_helper void rust_helper_init_work_with_key(struct work_struct *work, + work_func_t func, + bool onstack, + const char *name, + struct lock_class_key *key) { __init_work(work, onstack); work->data = (atomic_long_t)WORK_DATA_INIT(); diff --git a/rust/helpers/xarray.c b/rust/helpers/xarray.c index 60b299f11451..08979b304341 100644 --- a/rust/helpers/xarray.c +++ b/rust/helpers/xarray.c @@ -2,27 +2,27 @@ #include <linux/xarray.h> -int rust_helper_xa_err(void *entry) +__rust_helper int rust_helper_xa_err(void *entry) { return xa_err(entry); } -void rust_helper_xa_init_flags(struct xarray *xa, gfp_t flags) +__rust_helper void rust_helper_xa_init_flags(struct xarray *xa, gfp_t flags) { return xa_init_flags(xa, flags); } -int rust_helper_xa_trylock(struct xarray *xa) +__rust_helper int rust_helper_xa_trylock(struct xarray *xa) { return xa_trylock(xa); } -void rust_helper_xa_lock(struct xarray *xa) +__rust_helper void rust_helper_xa_lock(struct xarray *xa) { return xa_lock(xa); } -void rust_helper_xa_unlock(struct xarray *xa) +__rust_helper void rust_helper_xa_unlock(struct xarray *xa) { return xa_unlock(xa); } diff --git a/rust/kernel/block/mq/gen_disk.rs b/rust/kernel/block/mq/gen_disk.rs index 1ce815c8cdab..c8b0ecb17082 100644 --- a/rust/kernel/block/mq/gen_disk.rs +++ b/rust/kernel/block/mq/gen_disk.rs @@ -107,8 +107,7 @@ impl GenDiskBuilder { drop(unsafe { T::QueueData::from_foreign(data) }); }); - // SAFETY: `bindings::queue_limits` contain only fields that are valid when zeroed. - let mut lim: bindings::queue_limits = unsafe { core::mem::zeroed() }; + let mut lim: bindings::queue_limits = pin_init::zeroed(); lim.logical_block_size = self.logical_block_size; lim.physical_block_size = self.physical_block_size; diff --git a/rust/kernel/block/mq/tag_set.rs b/rust/kernel/block/mq/tag_set.rs index c3cf56d52bee..dae9df408a86 100644 --- a/rust/kernel/block/mq/tag_set.rs +++ b/rust/kernel/block/mq/tag_set.rs @@ -38,9 +38,7 @@ impl<T: Operations> TagSet<T> { num_tags: u32, num_maps: u32, ) -> impl PinInit<Self, error::Error> { - // SAFETY: `blk_mq_tag_set` only contains integers and pointers, which - // all are allowed to be 0. - let tag_set: bindings::blk_mq_tag_set = unsafe { core::mem::zeroed() }; + let tag_set: bindings::blk_mq_tag_set = pin_init::zeroed(); let tag_set: Result<_> = core::mem::size_of::<RequestDataWrapper>() .try_into() .map(|cmd_size| { diff --git a/rust/kernel/bug.rs b/rust/kernel/bug.rs index 36aef43e5ebe..ed943960f851 100644 --- a/rust/kernel/bug.rs +++ b/rust/kernel/bug.rs @@ -11,9 +11,9 @@ #[cfg(all(CONFIG_BUG, not(CONFIG_UML), not(CONFIG_LOONGARCH), not(CONFIG_ARM)))] #[cfg(CONFIG_DEBUG_BUGVERBOSE)] macro_rules! warn_flags { - ($flags:expr) => { + ($file:expr, $flags:expr) => { const FLAGS: u32 = $crate::bindings::BUGFLAG_WARNING | $flags; - const _FILE: &[u8] = file!().as_bytes(); + const _FILE: &[u8] = $file.as_bytes(); // Plus one for null-terminator. static FILE: [u8; _FILE.len() + 1] = { let mut bytes = [0; _FILE.len() + 1]; @@ -50,7 +50,7 @@ macro_rules! warn_flags { #[cfg(all(CONFIG_BUG, not(CONFIG_UML), not(CONFIG_LOONGARCH), not(CONFIG_ARM)))] #[cfg(not(CONFIG_DEBUG_BUGVERBOSE))] macro_rules! warn_flags { - ($flags:expr) => { + ($file:expr, $flags:expr) => { const FLAGS: u32 = $crate::bindings::BUGFLAG_WARNING | $flags; // SAFETY: @@ -75,7 +75,7 @@ macro_rules! warn_flags { #[doc(hidden)] #[cfg(all(CONFIG_BUG, CONFIG_UML))] macro_rules! warn_flags { - ($flags:expr) => { + ($file:expr, $flags:expr) => { // SAFETY: It is always safe to call `warn_slowpath_fmt()` // with a valid null-terminated string. unsafe { @@ -93,7 +93,7 @@ macro_rules! warn_flags { #[doc(hidden)] #[cfg(all(CONFIG_BUG, any(CONFIG_LOONGARCH, CONFIG_ARM)))] macro_rules! warn_flags { - ($flags:expr) => { + ($file:expr, $flags:expr) => { // SAFETY: It is always safe to call `WARN_ON()`. unsafe { $crate::bindings::WARN_ON(true) } }; @@ -103,7 +103,7 @@ macro_rules! warn_flags { #[doc(hidden)] #[cfg(not(CONFIG_BUG))] macro_rules! warn_flags { - ($flags:expr) => {}; + ($file:expr, $flags:expr) => {}; } #[doc(hidden)] @@ -116,10 +116,16 @@ pub const fn bugflag_taint(value: u32) -> u32 { macro_rules! warn_on { ($cond:expr) => {{ let cond = $cond; + + #[cfg(CONFIG_DEBUG_BUGVERBOSE_DETAILED)] + const _COND_STR: &str = concat!("[", stringify!($cond), "] ", file!()); + #[cfg(not(CONFIG_DEBUG_BUGVERBOSE_DETAILED))] + const _COND_STR: &str = file!(); + if cond { const WARN_ON_FLAGS: u32 = $crate::bug::bugflag_taint($crate::bindings::TAINT_WARN); - $crate::warn_flags!(WARN_ON_FLAGS); + $crate::warn_flags!(_COND_STR, WARN_ON_FLAGS); } cond }}; diff --git a/rust/kernel/build_assert.rs b/rust/kernel/build_assert.rs index 6331b15d7c4d..f8124dbc663f 100644 --- a/rust/kernel/build_assert.rs +++ b/rust/kernel/build_assert.rs @@ -61,8 +61,13 @@ macro_rules! build_error { /// build_assert!(N > 1); // Build-time check /// assert!(N > 1); // Run-time check /// } +/// ``` /// -/// #[inline] +/// When a condition depends on a function argument, the function must be annotated with +/// `#[inline(always)]`. Without this attribute, the compiler may choose to not inline the +/// function, preventing it from optimizing out the error path. +/// ``` +/// #[inline(always)] /// fn bar(n: usize) { /// // `static_assert!(n > 1);` is not allowed /// build_assert!(n > 1); // Build-time check diff --git a/rust/kernel/clk.rs b/rust/kernel/clk.rs index c1cfaeaa36a2..4059aff34d09 100644 --- a/rust/kernel/clk.rs +++ b/rust/kernel/clk.rs @@ -94,7 +94,7 @@ mod common_clk { /// # Invariants /// /// A [`Clk`] instance holds either a pointer to a valid [`struct clk`] created by the C - /// portion of the kernel or a NULL pointer. + /// portion of the kernel or a `NULL` pointer. /// /// Instances of this type are reference-counted. Calling [`Clk::get`] ensures that the /// allocation remains valid for the lifetime of the [`Clk`]. @@ -104,13 +104,12 @@ mod common_clk { /// The following example demonstrates how to obtain and configure a clock for a device. /// /// ``` - /// use kernel::c_str; /// use kernel::clk::{Clk, Hertz}; /// use kernel::device::Device; /// use kernel::error::Result; /// /// fn configure_clk(dev: &Device) -> Result { - /// let clk = Clk::get(dev, Some(c_str!("apb_clk")))?; + /// let clk = Clk::get(dev, Some(c"apb_clk"))?; /// /// clk.prepare_enable()?; /// @@ -272,13 +271,12 @@ mod common_clk { /// device. The code functions correctly whether or not the clock is available. /// /// ``` - /// use kernel::c_str; /// use kernel::clk::{OptionalClk, Hertz}; /// use kernel::device::Device; /// use kernel::error::Result; /// /// fn configure_clk(dev: &Device) -> Result { - /// let clk = OptionalClk::get(dev, Some(c_str!("apb_clk")))?; + /// let clk = OptionalClk::get(dev, Some(c"apb_clk"))?; /// /// clk.prepare_enable()?; /// diff --git a/rust/kernel/cpufreq.rs b/rust/kernel/cpufreq.rs index f968fbd22890..76faa1ac8501 100644 --- a/rust/kernel/cpufreq.rs +++ b/rust/kernel/cpufreq.rs @@ -840,7 +840,6 @@ pub trait Driver { /// ``` /// use kernel::{ /// cpufreq, -/// c_str, /// device::{Core, Device}, /// macros::vtable, /// of, platform, @@ -853,7 +852,7 @@ pub trait Driver { /// /// #[vtable] /// impl cpufreq::Driver for SampleDriver { -/// const NAME: &'static CStr = c_str!("cpufreq-sample"); +/// const NAME: &'static CStr = c"cpufreq-sample"; /// const FLAGS: u16 = cpufreq::flags::NEED_INITIAL_FREQ_CHECK | cpufreq::flags::IS_COOLING_DEV; /// const BOOST_ENABLED: bool = true; /// @@ -1015,6 +1014,8 @@ impl<T: Driver> Registration<T> { ..pin_init::zeroed() }; + // Always inline to optimize out error path of `build_assert`. + #[inline(always)] const fn copy_name(name: &'static CStr) -> [c_char; CPUFREQ_NAME_LEN] { let src = name.to_bytes_with_nul(); let mut dst = [0; CPUFREQ_NAME_LEN]; diff --git a/rust/kernel/cpumask.rs b/rust/kernel/cpumask.rs index c1d17826ae7b..44bb36636ee3 100644 --- a/rust/kernel/cpumask.rs +++ b/rust/kernel/cpumask.rs @@ -39,7 +39,7 @@ use core::ops::{Deref, DerefMut}; /// fn set_clear_cpu(ptr: *mut bindings::cpumask, set_cpu: CpuId, clear_cpu: CpuId) { /// // SAFETY: The `ptr` is valid for writing and remains valid for the lifetime of the /// // returned reference. -/// let mask = unsafe { Cpumask::as_mut_ref(ptr) }; +/// let mask = unsafe { Cpumask::from_raw_mut(ptr) }; /// /// mask.set(set_cpu); /// mask.clear(clear_cpu); @@ -49,13 +49,13 @@ use core::ops::{Deref, DerefMut}; pub struct Cpumask(Opaque<bindings::cpumask>); impl Cpumask { - /// Creates a mutable reference to an existing `struct cpumask` pointer. + /// Creates a mutable reference from an existing `struct cpumask` pointer. /// /// # Safety /// /// The caller must ensure that `ptr` is valid for writing and remains valid for the lifetime /// of the returned reference. - pub unsafe fn as_mut_ref<'a>(ptr: *mut bindings::cpumask) -> &'a mut Self { + pub unsafe fn from_raw_mut<'a>(ptr: *mut bindings::cpumask) -> &'a mut Self { // SAFETY: Guaranteed by the safety requirements of the function. // // INVARIANT: The caller ensures that `ptr` is valid for writing and remains valid for the @@ -63,13 +63,13 @@ impl Cpumask { unsafe { &mut *ptr.cast() } } - /// Creates a reference to an existing `struct cpumask` pointer. + /// Creates a reference from an existing `struct cpumask` pointer. /// /// # Safety /// /// The caller must ensure that `ptr` is valid for reading and remains valid for the lifetime /// of the returned reference. - pub unsafe fn as_ref<'a>(ptr: *const bindings::cpumask) -> &'a Self { + pub unsafe fn from_raw<'a>(ptr: *const bindings::cpumask) -> &'a Self { // SAFETY: Guaranteed by the safety requirements of the function. // // INVARIANT: The caller ensures that `ptr` is valid for reading and remains valid for the diff --git a/rust/kernel/debugfs/entry.rs b/rust/kernel/debugfs/entry.rs index 706cb7f73d6c..a30bf8f29679 100644 --- a/rust/kernel/debugfs/entry.rs +++ b/rust/kernel/debugfs/entry.rs @@ -148,7 +148,7 @@ impl Entry<'_> { /// # Guarantees /// /// Due to the type invariant, the value returned from this function will always be an error - /// code, NULL, or a live DebugFS directory. If it is live, it will remain live at least as + /// code, `NULL`, or a live DebugFS directory. If it is live, it will remain live at least as /// long as this entry lives. pub(crate) fn as_ptr(&self) -> *mut bindings::dentry { self.entry diff --git a/rust/kernel/i2c.rs b/rust/kernel/i2c.rs index 39b0a9a207fd..bb5b830f48c3 100644 --- a/rust/kernel/i2c.rs +++ b/rust/kernel/i2c.rs @@ -262,7 +262,7 @@ macro_rules! module_i2c_driver { /// # Example /// ///``` -/// # use kernel::{acpi, bindings, c_str, device::Core, i2c, of}; +/// # use kernel::{acpi, bindings, device::Core, i2c, of}; /// /// struct MyDriver; /// @@ -271,7 +271,7 @@ macro_rules! module_i2c_driver { /// MODULE_ACPI_TABLE, /// <MyDriver as i2c::Driver>::IdInfo, /// [ -/// (acpi::DeviceId::new(c_str!("LNUXBEEF")), ()) +/// (acpi::DeviceId::new(c"LNUXBEEF"), ()) /// ] /// ); /// @@ -280,7 +280,7 @@ macro_rules! module_i2c_driver { /// MODULE_I2C_TABLE, /// <MyDriver as i2c::Driver>::IdInfo, /// [ -/// (i2c::DeviceId::new(c_str!("rust_driver_i2c")), ()) +/// (i2c::DeviceId::new(c"rust_driver_i2c"), ()) /// ] /// ); /// @@ -289,7 +289,7 @@ macro_rules! module_i2c_driver { /// MODULE_OF_TABLE, /// <MyDriver as i2c::Driver>::IdInfo, /// [ -/// (of::DeviceId::new(c_str!("test,device")), ()) +/// (of::DeviceId::new(c"test,device"), ()) /// ] /// ); /// diff --git a/rust/kernel/impl_flags.rs b/rust/kernel/impl_flags.rs new file mode 100644 index 000000000000..e2bd7639da12 --- /dev/null +++ b/rust/kernel/impl_flags.rs @@ -0,0 +1,272 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Bitflag type generator. + +/// Common helper for declaring bitflag and bitmask types. +/// +/// This macro takes as input: +/// - A struct declaration representing a bitmask type +/// (e.g., `pub struct Permissions(u32)`). +/// - An enumeration declaration representing individual bit flags +/// (e.g., `pub enum Permission { ... }`). +/// +/// And generates: +/// - The struct and enum types with appropriate `#[repr]` attributes. +/// - Implementations of common bitflag operators +/// ([`::core::ops::BitOr`], [`::core::ops::BitAnd`], etc.). +/// - Utility methods such as `.contains()` to check flags. +/// +/// # Examples +/// +/// ``` +/// use kernel::impl_flags; +/// +/// impl_flags!( +/// /// Represents multiple permissions. +/// #[derive(Debug, Clone, Default, Copy, PartialEq, Eq)] +/// pub struct Permissions(u32); +/// +/// /// Represents a single permission. +/// #[derive(Debug, Clone, Copy, PartialEq, Eq)] +/// pub enum Permission { +/// /// Read permission. +/// Read = 1 << 0, +/// +/// /// Write permission. +/// Write = 1 << 1, +/// +/// /// Execute permission. +/// Execute = 1 << 2, +/// } +/// ); +/// +/// // Combine multiple permissions using the bitwise OR (`|`) operator. +/// let mut read_write: Permissions = Permission::Read | Permission::Write; +/// assert!(read_write.contains(Permission::Read)); +/// assert!(read_write.contains(Permission::Write)); +/// assert!(!read_write.contains(Permission::Execute)); +/// assert!(read_write.contains_any(Permission::Read | Permission::Execute)); +/// assert!(read_write.contains_all(Permission::Read | Permission::Write)); +/// +/// // Using the bitwise OR assignment (`|=`) operator. +/// read_write |= Permission::Execute; +/// assert!(read_write.contains(Permission::Execute)); +/// +/// // Masking a permission with the bitwise AND (`&`) operator. +/// let read_only: Permissions = read_write & Permission::Read; +/// assert!(read_only.contains(Permission::Read)); +/// assert!(!read_only.contains(Permission::Write)); +/// +/// // Toggling permissions with the bitwise XOR (`^`) operator. +/// let toggled: Permissions = read_only ^ Permission::Read; +/// assert!(!toggled.contains(Permission::Read)); +/// +/// // Inverting permissions with the bitwise NOT (`!`) operator. +/// let negated = !read_only; +/// assert!(negated.contains(Permission::Write)); +/// assert!(!negated.contains(Permission::Read)); +/// ``` +#[macro_export] +macro_rules! impl_flags { + ( + $(#[$outer_flags:meta])* + $vis_flags:vis struct $flags:ident($ty:ty); + + $(#[$outer_flag:meta])* + $vis_flag:vis enum $flag:ident { + $( + $(#[$inner_flag:meta])* + $name:ident = $value:expr + ),+ $( , )? + } + ) => { + $(#[$outer_flags])* + #[repr(transparent)] + $vis_flags struct $flags($ty); + + $(#[$outer_flag])* + #[repr($ty)] + $vis_flag enum $flag { + $( + $(#[$inner_flag])* + $name = $value + ),+ + } + + impl ::core::convert::From<$flag> for $flags { + #[inline] + fn from(value: $flag) -> Self { + Self(value as $ty) + } + } + + impl ::core::convert::From<$flags> for $ty { + #[inline] + fn from(value: $flags) -> Self { + value.0 + } + } + + impl ::core::ops::BitOr for $flags { + type Output = Self; + #[inline] + fn bitor(self, rhs: Self) -> Self::Output { + Self(self.0 | rhs.0) + } + } + + impl ::core::ops::BitOrAssign for $flags { + #[inline] + fn bitor_assign(&mut self, rhs: Self) { + *self = *self | rhs; + } + } + + impl ::core::ops::BitOr<$flag> for $flags { + type Output = Self; + #[inline] + fn bitor(self, rhs: $flag) -> Self::Output { + self | Self::from(rhs) + } + } + + impl ::core::ops::BitOrAssign<$flag> for $flags { + #[inline] + fn bitor_assign(&mut self, rhs: $flag) { + *self = *self | rhs; + } + } + + impl ::core::ops::BitAnd for $flags { + type Output = Self; + #[inline] + fn bitand(self, rhs: Self) -> Self::Output { + Self(self.0 & rhs.0) + } + } + + impl ::core::ops::BitAndAssign for $flags { + #[inline] + fn bitand_assign(&mut self, rhs: Self) { + *self = *self & rhs; + } + } + + impl ::core::ops::BitAnd<$flag> for $flags { + type Output = Self; + #[inline] + fn bitand(self, rhs: $flag) -> Self::Output { + self & Self::from(rhs) + } + } + + impl ::core::ops::BitAndAssign<$flag> for $flags { + #[inline] + fn bitand_assign(&mut self, rhs: $flag) { + *self = *self & rhs; + } + } + + impl ::core::ops::BitXor for $flags { + type Output = Self; + #[inline] + fn bitxor(self, rhs: Self) -> Self::Output { + Self((self.0 ^ rhs.0) & Self::all_bits()) + } + } + + impl ::core::ops::BitXorAssign for $flags { + #[inline] + fn bitxor_assign(&mut self, rhs: Self) { + *self = *self ^ rhs; + } + } + + impl ::core::ops::BitXor<$flag> for $flags { + type Output = Self; + #[inline] + fn bitxor(self, rhs: $flag) -> Self::Output { + self ^ Self::from(rhs) + } + } + + impl ::core::ops::BitXorAssign<$flag> for $flags { + #[inline] + fn bitxor_assign(&mut self, rhs: $flag) { + *self = *self ^ rhs; + } + } + + impl ::core::ops::Not for $flags { + type Output = Self; + #[inline] + fn not(self) -> Self::Output { + Self((!self.0) & Self::all_bits()) + } + } + + impl ::core::ops::BitOr for $flag { + type Output = $flags; + #[inline] + fn bitor(self, rhs: Self) -> Self::Output { + $flags(self as $ty | rhs as $ty) + } + } + + impl ::core::ops::BitAnd for $flag { + type Output = $flags; + #[inline] + fn bitand(self, rhs: Self) -> Self::Output { + $flags(self as $ty & rhs as $ty) + } + } + + impl ::core::ops::BitXor for $flag { + type Output = $flags; + #[inline] + fn bitxor(self, rhs: Self) -> Self::Output { + $flags((self as $ty ^ rhs as $ty) & $flags::all_bits()) + } + } + + impl ::core::ops::Not for $flag { + type Output = $flags; + #[inline] + fn not(self) -> Self::Output { + $flags((!(self as $ty)) & $flags::all_bits()) + } + } + + impl $flags { + /// Returns an empty instance where no flags are set. + #[inline] + pub const fn empty() -> Self { + Self(0) + } + + /// Returns a mask containing all valid flag bits. + #[inline] + pub const fn all_bits() -> $ty { + 0 $( | $value )+ + } + + /// Checks if a specific flag is set. + #[inline] + pub fn contains(self, flag: $flag) -> bool { + (self.0 & flag as $ty) == flag as $ty + } + + /// Checks if at least one of the provided flags is set. + #[inline] + pub fn contains_any(self, flags: $flags) -> bool { + (self.0 & flags.0) != 0 + } + + /// Checks if all of the provided flags are set. + #[inline] + pub fn contains_all(self, flags: $flags) -> bool { + (self.0 & flags.0) == flags.0 + } + } + }; +} diff --git a/rust/kernel/init.rs b/rust/kernel/init.rs index 899b9a962762..7a0d4559d7b5 100644 --- a/rust/kernel/init.rs +++ b/rust/kernel/init.rs @@ -219,20 +219,12 @@ pub trait InPlaceInit<T>: Sized { /// [`Error`]: crate::error::Error #[macro_export] macro_rules! try_init { - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }) => { - ::pin_init::try_init!($(&$this in)? $t $(::<$($generics),*>)? { - $($fields)* - }? $crate::error::Error) - }; - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }? $err:ty) => { - ::pin_init::try_init!($(&$this in)? $t $(::<$($generics),*>)? { - $($fields)* - }? $err) - }; + ($($args:tt)*) => { + ::pin_init::init!( + #[default_error($crate::error::Error)] + $($args)* + ) + } } /// Construct an in-place, fallible pinned initializer for `struct`s. @@ -279,18 +271,10 @@ macro_rules! try_init { /// [`Error`]: crate::error::Error #[macro_export] macro_rules! try_pin_init { - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }) => { - ::pin_init::try_pin_init!($(&$this in)? $t $(::<$($generics),*>)? { - $($fields)* - }? $crate::error::Error) - }; - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }? $err:ty) => { - ::pin_init::try_pin_init!($(&$this in)? $t $(::<$($generics),*>)? { - $($fields)* - }? $err) - }; + ($($args:tt)*) => { + ::pin_init::pin_init!( + #[default_error($crate::error::Error)] + $($args)* + ) + } } diff --git a/rust/kernel/kunit.rs b/rust/kernel/kunit.rs index 79436509dd73..f93f24a60bdd 100644 --- a/rust/kernel/kunit.rs +++ b/rust/kernel/kunit.rs @@ -9,9 +9,6 @@ use crate::fmt; use crate::prelude::*; -#[cfg(CONFIG_PRINTK)] -use crate::c_str; - /// Prints a KUnit error-level message. /// /// Public but hidden since it should only be used from KUnit generated code. @@ -22,7 +19,7 @@ pub fn err(args: fmt::Arguments<'_>) { #[cfg(CONFIG_PRINTK)] unsafe { bindings::_printk( - c_str!("\x013%pA").as_char_ptr(), + c"\x013%pA".as_char_ptr(), core::ptr::from_ref(&args).cast::<c_void>(), ); } @@ -38,7 +35,7 @@ pub fn info(args: fmt::Arguments<'_>) { #[cfg(CONFIG_PRINTK)] unsafe { bindings::_printk( - c_str!("\x016%pA").as_char_ptr(), + c"\x016%pA".as_char_ptr(), core::ptr::from_ref(&args).cast::<c_void>(), ); } @@ -60,7 +57,7 @@ macro_rules! kunit_assert { break 'out; } - static FILE: &'static $crate::str::CStr = $crate::c_str!($file); + static FILE: &'static $crate::str::CStr = $file; static LINE: i32 = ::core::line!() as i32 - $diff; static CONDITION: &'static $crate::str::CStr = $crate::c_str!(stringify!($condition)); @@ -192,9 +189,6 @@ pub fn is_test_result_ok(t: impl TestResult) -> bool { } /// Represents an individual test case. -/// -/// The [`kunit_unsafe_test_suite!`] macro expects a NULL-terminated list of valid test cases. -/// Use [`kunit_case_null`] to generate such a delimiter. #[doc(hidden)] pub const fn kunit_case( name: &'static kernel::str::CStr, @@ -215,32 +209,11 @@ pub const fn kunit_case( } } -/// Represents the NULL test case delimiter. -/// -/// The [`kunit_unsafe_test_suite!`] macro expects a NULL-terminated list of test cases. This -/// function returns such a delimiter. -#[doc(hidden)] -pub const fn kunit_case_null() -> kernel::bindings::kunit_case { - kernel::bindings::kunit_case { - run_case: None, - name: core::ptr::null_mut(), - generate_params: None, - attr: kernel::bindings::kunit_attributes { - speed: kernel::bindings::kunit_speed_KUNIT_SPEED_NORMAL, - }, - status: kernel::bindings::kunit_status_KUNIT_SUCCESS, - module_name: core::ptr::null_mut(), - log: core::ptr::null_mut(), - param_init: None, - param_exit: None, - } -} - /// Registers a KUnit test suite. /// /// # Safety /// -/// `test_cases` must be a NULL terminated array of valid test cases, +/// `test_cases` must be a `NULL` terminated array of valid test cases, /// whose lifetime is at least that of the test suite (i.e., static). /// /// # Examples @@ -253,8 +226,8 @@ pub const fn kunit_case_null() -> kernel::bindings::kunit_case { /// } /// /// static mut KUNIT_TEST_CASES: [kernel::bindings::kunit_case; 2] = [ -/// kernel::kunit::kunit_case(kernel::c_str!("name"), test_fn), -/// kernel::kunit::kunit_case_null(), +/// kernel::kunit::kunit_case(c"name", test_fn), +/// pin_init::zeroed(), /// ]; /// kernel::kunit_unsafe_test_suite!(suite_name, KUNIT_TEST_CASES); /// ``` diff --git a/rust/kernel/lib.rs b/rust/kernel/lib.rs index f812cf120042..696f62f85eb5 100644 --- a/rust/kernel/lib.rs +++ b/rust/kernel/lib.rs @@ -100,6 +100,8 @@ pub mod fs; #[cfg(CONFIG_I2C = "y")] pub mod i2c; pub mod id_pool; +#[doc(hidden)] +pub mod impl_flags; pub mod init; pub mod io; pub mod ioctl; @@ -133,6 +135,7 @@ pub mod pwm; pub mod rbtree; pub mod regulator; pub mod revocable; +pub mod safety; pub mod scatterlist; pub mod security; pub mod seq_file; diff --git a/rust/kernel/list/arc.rs b/rust/kernel/list/arc.rs index d92bcf665c89..2282f33913ee 100644 --- a/rust/kernel/list/arc.rs +++ b/rust/kernel/list/arc.rs @@ -6,11 +6,11 @@ use crate::alloc::{AllocError, Flags}; use crate::prelude::*; +use crate::sync::atomic::{ordering, Atomic}; use crate::sync::{Arc, ArcBorrow, UniqueArc}; use core::marker::PhantomPinned; use core::ops::Deref; use core::pin::Pin; -use core::sync::atomic::{AtomicBool, Ordering}; /// Declares that this type has some way to ensure that there is exactly one `ListArc` instance for /// this id. @@ -469,7 +469,7 @@ where /// If the boolean is `false`, then there is no [`ListArc`] for this value. #[repr(transparent)] pub struct AtomicTracker<const ID: u64 = 0> { - inner: AtomicBool, + inner: Atomic<bool>, // This value needs to be pinned to justify the INVARIANT: comment in `AtomicTracker::new`. _pin: PhantomPinned, } @@ -480,12 +480,12 @@ impl<const ID: u64> AtomicTracker<ID> { // INVARIANT: Pin-init initializers can't be used on an existing `Arc`, so this value will // not be constructed in an `Arc` that already has a `ListArc`. Self { - inner: AtomicBool::new(false), + inner: Atomic::new(false), _pin: PhantomPinned, } } - fn project_inner(self: Pin<&mut Self>) -> &mut AtomicBool { + fn project_inner(self: Pin<&mut Self>) -> &mut Atomic<bool> { // SAFETY: The `inner` field is not structurally pinned, so we may obtain a mutable // reference to it even if we only have a pinned reference to `self`. unsafe { &mut Pin::into_inner_unchecked(self).inner } @@ -500,7 +500,7 @@ impl<const ID: u64> ListArcSafe<ID> for AtomicTracker<ID> { unsafe fn on_drop_list_arc(&self) { // INVARIANT: We just dropped a ListArc, so the boolean should be false. - self.inner.store(false, Ordering::Release); + self.inner.store(false, ordering::Release); } } @@ -514,8 +514,6 @@ unsafe impl<const ID: u64> TryNewListArc<ID> for AtomicTracker<ID> { fn try_new_list_arc(&self) -> bool { // INVARIANT: If this method returns true, then the boolean used to be false, and is no // longer false, so it is okay for the caller to create a new [`ListArc`]. - self.inner - .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed) - .is_ok() + self.inner.cmpxchg(false, true, ordering::Acquire).is_ok() } } diff --git a/rust/kernel/print.rs b/rust/kernel/print.rs index 2d743d78d220..6fd84389a858 100644 --- a/rust/kernel/print.rs +++ b/rust/kernel/print.rs @@ -11,6 +11,11 @@ use crate::{ fmt, prelude::*, str::RawFormatter, + sync::atomic::{ + Atomic, + AtomicType, + Relaxed, // + }, }; // Called from `vsprintf` with format specifier `%pA`. @@ -423,3 +428,151 @@ macro_rules! pr_cont ( $crate::print_macro!($crate::print::format_strings::CONT, true, $($arg)*) ) ); + +/// A lightweight `call_once` primitive. +/// +/// This structure provides the Rust equivalent of the kernel's `DO_ONCE_LITE` macro. +/// While it would be possible to implement the feature entirely as a Rust macro, +/// the functionality that can be implemented as regular functions has been +/// extracted and implemented as the `OnceLite` struct for better code maintainability. +pub struct OnceLite(Atomic<State>); + +#[derive(Clone, Copy, PartialEq, Eq)] +#[repr(i32)] +enum State { + Incomplete = 0, + Complete = 1, +} + +// SAFETY: `State` and `i32` has the same size and alignment, and it's round-trip +// transmutable to `i32`. +unsafe impl AtomicType for State { + type Repr = i32; +} + +impl OnceLite { + /// Creates a new [`OnceLite`] in the incomplete state. + #[inline(always)] + #[allow(clippy::new_without_default)] + pub const fn new() -> Self { + OnceLite(Atomic::new(State::Incomplete)) + } + + /// Calls the provided function exactly once. + /// + /// There is no other synchronization between two `call_once()`s + /// except that only one will execute `f`, in other words, callers + /// should not use a failed `call_once()` as a proof that another + /// `call_once()` has already finished and the effect is observable + /// to this thread. + pub fn call_once<F>(&self, f: F) -> bool + where + F: FnOnce(), + { + // Avoid expensive cmpxchg if already completed. + // ORDERING: `Relaxed` is used here since no synchronization is required. + let old = self.0.load(Relaxed); + if old == State::Complete { + return false; + } + + // ORDERING: `Relaxed` is used here since no synchronization is required. + let old = self.0.xchg(State::Complete, Relaxed); + if old == State::Complete { + return false; + } + + f(); + true + } +} + +/// Run the given function exactly once. +/// +/// This is equivalent to the kernel's `DO_ONCE_LITE` macro. +/// +/// # Examples +/// +/// ``` +/// kernel::do_once_lite! { +/// kernel::pr_info!("This will be printed only once\n"); +/// }; +/// ``` +#[macro_export] +macro_rules! do_once_lite { + { $($e:tt)* } => {{ + #[link_section = ".data..once"] + static ONCE: $crate::print::OnceLite = $crate::print::OnceLite::new(); + ONCE.call_once(|| { $($e)* }); + }}; +} + +/// Prints an emergency-level message (level 0) only once. +/// +/// Equivalent to the kernel's `pr_emerg_once` macro. +#[macro_export] +macro_rules! pr_emerg_once ( + ($($arg:tt)*) => ( + $crate::do_once_lite! { $crate::pr_emerg!($($arg)*) } + ) +); + +/// Prints an alert-level message (level 1) only once. +/// +/// Equivalent to the kernel's `pr_alert_once` macro. +#[macro_export] +macro_rules! pr_alert_once ( + ($($arg:tt)*) => ( + $crate::do_once_lite! { $crate::pr_alert!($($arg)*) } + ) +); + +/// Prints a critical-level message (level 2) only once. +/// +/// Equivalent to the kernel's `pr_crit_once` macro. +#[macro_export] +macro_rules! pr_crit_once ( + ($($arg:tt)*) => ( + $crate::do_once_lite! { $crate::pr_crit!($($arg)*) } + ) +); + +/// Prints an error-level message (level 3) only once. +/// +/// Equivalent to the kernel's `pr_err_once` macro. +#[macro_export] +macro_rules! pr_err_once ( + ($($arg:tt)*) => ( + $crate::do_once_lite! { $crate::pr_err!($($arg)*) } + ) +); + +/// Prints a warning-level message (level 4) only once. +/// +/// Equivalent to the kernel's `pr_warn_once` macro. +#[macro_export] +macro_rules! pr_warn_once ( + ($($arg:tt)*) => ( + $crate::do_once_lite! { $crate::pr_warn!($($arg)*) } + ) +); + +/// Prints a notice-level message (level 5) only once. +/// +/// Equivalent to the kernel's `pr_notice_once` macro. +#[macro_export] +macro_rules! pr_notice_once ( + ($($arg:tt)*) => ( + $crate::do_once_lite! { $crate::pr_notice!($($arg)*) } + ) +); + +/// Prints an info-level message (level 6) only once. +/// +/// Equivalent to the kernel's `pr_info_once` macro. +#[macro_export] +macro_rules! pr_info_once ( + ($($arg:tt)*) => ( + $crate::do_once_lite! { $crate::pr_info!($($arg)*) } + ) +); diff --git a/rust/kernel/ptr.rs b/rust/kernel/ptr.rs index e3893ed04049..5b6a382637fe 100644 --- a/rust/kernel/ptr.rs +++ b/rust/kernel/ptr.rs @@ -5,8 +5,6 @@ use core::mem::align_of; use core::num::NonZero; -use crate::build_assert; - /// Type representing an alignment, which is always a power of two. /// /// It is used to validate that a given value is a valid alignment, and to perform masking and @@ -40,10 +38,12 @@ impl Alignment { /// ``` #[inline(always)] pub const fn new<const ALIGN: usize>() -> Self { - build_assert!( - ALIGN.is_power_of_two(), - "Provided alignment is not a power of two." - ); + const { + assert!( + ALIGN.is_power_of_two(), + "Provided alignment is not a power of two." + ); + } // INVARIANT: `align` is a power of two. // SAFETY: `align` is a power of two, and thus non-zero. diff --git a/rust/kernel/rbtree.rs b/rust/kernel/rbtree.rs index 312cecab72e7..6fbd579d4a43 100644 --- a/rust/kernel/rbtree.rs +++ b/rust/kernel/rbtree.rs @@ -414,14 +414,17 @@ where // SAFETY: By the type invariant of `Self`, all non-null `rb_node` pointers stored in `self` // point to the links field of `Node<K, V>` objects. let this = unsafe { container_of!(node, Node<K, V>, links) }; + // SAFETY: `this` is a non-null node so it is valid by the type invariants. - node = match key.cmp(unsafe { &(*this).key }) { - // SAFETY: `node` is a non-null node so it is valid by the type invariants. - Ordering::Less => unsafe { (*node).rb_left }, - // SAFETY: `node` is a non-null node so it is valid by the type invariants. - Ordering::Greater => unsafe { (*node).rb_right }, - // SAFETY: `node` is a non-null node so it is valid by the type invariants. - Ordering::Equal => return Some(unsafe { &(*this).value }), + let this_ref = unsafe { &*this }; + + // SAFETY: `node` is a non-null node so it is valid by the type invariants. + let node_ref = unsafe { &*node }; + + node = match key.cmp(&this_ref.key) { + Ordering::Less => node_ref.rb_left, + Ordering::Greater => node_ref.rb_right, + Ordering::Equal => return Some(&this_ref.value), } } None @@ -498,10 +501,10 @@ where let this = unsafe { container_of!(node, Node<K, V>, links) }; // SAFETY: `this` is a non-null node so it is valid by the type invariants. let this_key = unsafe { &(*this).key }; + // SAFETY: `node` is a non-null node so it is valid by the type invariants. - let left_child = unsafe { (*node).rb_left }; - // SAFETY: `node` is a non-null node so it is valid by the type invariants. - let right_child = unsafe { (*node).rb_right }; + let node_ref = unsafe { &*node }; + match key.cmp(this_key) { Ordering::Equal => { // SAFETY: `this` is a non-null node so it is valid by the type invariants. @@ -509,7 +512,7 @@ where break; } Ordering::Greater => { - node = right_child; + node = node_ref.rb_right; } Ordering::Less => { let is_better_match = match best_key { @@ -521,7 +524,7 @@ where // SAFETY: `this` is a non-null node so it is valid by the type invariants. best_links = Some(unsafe { NonNull::new_unchecked(&mut (*this).links) }); } - node = left_child; + node = node_ref.rb_left; } }; } diff --git a/rust/kernel/safety.rs b/rust/kernel/safety.rs new file mode 100644 index 000000000000..c1c6bd0fa2cc --- /dev/null +++ b/rust/kernel/safety.rs @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Safety related APIs. + +/// Checks that a precondition of an unsafe function is followed. +/// +/// The check is enabled at runtime if debug assertions (`CONFIG_RUST_DEBUG_ASSERTIONS`) +/// are enabled. Otherwise, this macro is a no-op. +/// +/// # Examples +/// +/// ```no_run +/// use kernel::unsafe_precondition_assert; +/// +/// struct RawBuffer<T: Copy, const N: usize> { +/// data: [T; N], +/// } +/// +/// impl<T: Copy, const N: usize> RawBuffer<T, N> { +/// /// # Safety +/// /// +/// /// The caller must ensure that `index` is less than `N`. +/// unsafe fn set_unchecked(&mut self, index: usize, value: T) { +/// unsafe_precondition_assert!( +/// index < N, +/// "RawBuffer::set_unchecked() requires index ({index}) < N ({N})" +/// ); +/// +/// // SAFETY: By the safety requirements of this function, `index` is valid. +/// unsafe { +/// *self.data.get_unchecked_mut(index) = value; +/// } +/// } +/// } +/// ``` +/// +/// # Panics +/// +/// Panics if the expression is evaluated to [`false`] at runtime. +#[macro_export] +macro_rules! unsafe_precondition_assert { + ($cond:expr $(,)?) => { + $crate::unsafe_precondition_assert!(@inner $cond, ::core::stringify!($cond)) + }; + + ($cond:expr, $($arg:tt)+) => { + $crate::unsafe_precondition_assert!(@inner $cond, $crate::prelude::fmt!($($arg)+)) + }; + + (@inner $cond:expr, $msg:expr) => { + ::core::debug_assert!($cond, "unsafe precondition violated: {}", $msg) + }; +} diff --git a/rust/kernel/sync.rs b/rust/kernel/sync.rs index 5df87e2bd212..993dbf2caa0e 100644 --- a/rust/kernel/sync.rs +++ b/rust/kernel/sync.rs @@ -32,7 +32,9 @@ pub use locked_by::LockedBy; pub use refcount::Refcount; pub use set_once::SetOnce; -/// Represents a lockdep class. It's a wrapper around C's `lock_class_key`. +/// Represents a lockdep class. +/// +/// Wraps the kernel's `struct lock_class_key`. #[repr(transparent)] #[pin_data(PinnedDrop)] pub struct LockClassKey { @@ -40,20 +42,42 @@ pub struct LockClassKey { inner: Opaque<bindings::lock_class_key>, } +// SAFETY: Unregistering a lock class key from a different thread than where it was registered is +// allowed. +unsafe impl Send for LockClassKey {} + // SAFETY: `bindings::lock_class_key` is designed to be used concurrently from multiple threads and // provides its own synchronization. unsafe impl Sync for LockClassKey {} impl LockClassKey { - /// Initializes a dynamically allocated lock class key. In the common case of using a - /// statically allocated lock class key, the static_lock_class! macro should be used instead. + /// Initializes a statically allocated lock class key. + /// + /// This is usually used indirectly through the [`static_lock_class!`] macro. See its + /// documentation for more information. + /// + /// # Safety + /// + /// * Before using the returned value, it must be pinned in a static memory location. + /// * The destructor must never run on the returned `LockClassKey`. + pub const unsafe fn new_static() -> Self { + LockClassKey { + inner: Opaque::uninit(), + } + } + + /// Initializes a dynamically allocated lock class key. + /// + /// In the common case of using a statically allocated lock class key, the + /// [`static_lock_class!`] macro should be used instead. /// /// # Examples + /// /// ``` - /// # use kernel::alloc::KBox; - /// # use kernel::types::ForeignOwnable; - /// # use kernel::sync::{LockClassKey, SpinLock}; - /// # use pin_init::stack_pin_init; + /// use kernel::alloc::KBox; + /// use kernel::types::ForeignOwnable; + /// use kernel::sync::{LockClassKey, SpinLock}; + /// use pin_init::stack_pin_init; /// /// let key = KBox::pin_init(LockClassKey::new_dynamic(), GFP_KERNEL)?; /// let key_ptr = key.into_foreign(); @@ -71,7 +95,6 @@ impl LockClassKey { /// // SAFETY: We dropped `num`, the only use of the key, so the result of the previous /// // `borrow` has also been dropped. Thus, it's safe to use from_foreign. /// unsafe { drop(<Pin<KBox<LockClassKey>> as ForeignOwnable>::from_foreign(key_ptr)) }; - /// /// # Ok::<(), Error>(()) /// ``` pub fn new_dynamic() -> impl PinInit<Self> { @@ -81,7 +104,10 @@ impl LockClassKey { }) } - pub(crate) fn as_ptr(&self) -> *mut bindings::lock_class_key { + /// Returns a raw pointer to the inner C struct. + /// + /// It is up to the caller to use the raw pointer correctly. + pub fn as_ptr(&self) -> *mut bindings::lock_class_key { self.inner.get() } } @@ -89,27 +115,38 @@ impl LockClassKey { #[pinned_drop] impl PinnedDrop for LockClassKey { fn drop(self: Pin<&mut Self>) { - // SAFETY: self.as_ptr was registered with lockdep and self is pinned, so the address - // hasn't changed. Thus, it's safe to pass to unregister. + // SAFETY: `self.as_ptr()` was registered with lockdep and `self` is pinned, so the address + // hasn't changed. Thus, it's safe to pass it to unregister. unsafe { bindings::lockdep_unregister_key(self.as_ptr()) } } } /// Defines a new static lock class and returns a pointer to it. -#[doc(hidden)] +/// +/// # Examples +/// +/// ``` +/// use kernel::sync::{static_lock_class, Arc, SpinLock}; +/// +/// fn new_locked_int() -> Result<Arc<SpinLock<u32>>> { +/// Arc::pin_init(SpinLock::new( +/// 42, +/// c"new_locked_int", +/// static_lock_class!(), +/// ), GFP_KERNEL) +/// } +/// ``` #[macro_export] macro_rules! static_lock_class { () => {{ static CLASS: $crate::sync::LockClassKey = - // Lockdep expects uninitialized memory when it's handed a statically allocated `struct - // lock_class_key`. - // - // SAFETY: `LockClassKey` transparently wraps `Opaque` which permits uninitialized - // memory. - unsafe { ::core::mem::MaybeUninit::uninit().assume_init() }; + // SAFETY: The returned `LockClassKey` is stored in static memory and we pin it. Drop + // never runs on a static global. + unsafe { $crate::sync::LockClassKey::new_static() }; $crate::prelude::Pin::static_ref(&CLASS) }}; } +pub use static_lock_class; /// Returns the given string, if one is provided, otherwise generates one based on the source code /// location. diff --git a/rust/kernel/sync/aref.rs b/rust/kernel/sync/aref.rs index 0d24a0432015..0616c0353c2b 100644 --- a/rust/kernel/sync/aref.rs +++ b/rust/kernel/sync/aref.rs @@ -83,6 +83,9 @@ unsafe impl<T: AlwaysRefCounted + Sync + Send> Send for ARef<T> {} // example, when the reference count reaches zero and `T` is dropped. unsafe impl<T: AlwaysRefCounted + Sync + Send> Sync for ARef<T> {} +// Even if `T` is pinned, pointers to `T` can still move. +impl<T: AlwaysRefCounted> Unpin for ARef<T> {} + impl<T: AlwaysRefCounted> ARef<T> { /// Creates a new instance of [`ARef`]. /// diff --git a/rust/kernel/sync/atomic/internal.rs b/rust/kernel/sync/atomic/internal.rs index 6fdd8e59f45b..0dac58bca2b3 100644 --- a/rust/kernel/sync/atomic/internal.rs +++ b/rust/kernel/sync/atomic/internal.rs @@ -13,17 +13,22 @@ mod private { pub trait Sealed {} } -// `i32` and `i64` are only supported atomic implementations. +// The C side supports atomic primitives only for `i32` and `i64` (`atomic_t` and `atomic64_t`), +// while the Rust side also layers provides atomic support for `i8` and `i16` +// on top of lower-level C primitives. +impl private::Sealed for i8 {} +impl private::Sealed for i16 {} impl private::Sealed for i32 {} impl private::Sealed for i64 {} /// A marker trait for types that implement atomic operations with C side primitives. /// -/// This trait is sealed, and only types that have directly mapping to the C side atomics should -/// impl this: +/// This trait is sealed, and only types that map directly to the C side atomics +/// or can be implemented with lower-level C primitives are allowed to implement this: /// -/// - `i32` maps to `atomic_t`. -/// - `i64` maps to `atomic64_t`. +/// - `i8` and `i16` are implemented with lower-level C primitives. +/// - `i32` map to `atomic_t` +/// - `i64` map to `atomic64_t` pub trait AtomicImpl: Sized + Send + Copy + private::Sealed { /// The type of the delta in arithmetic or logical operations. /// @@ -32,6 +37,20 @@ pub trait AtomicImpl: Sized + Send + Copy + private::Sealed { type Delta; } +// The current helpers of load/store uses `{WRITE,READ}_ONCE()` hence the atomicity is only +// guaranteed against read-modify-write operations if the architecture supports native atomic RmW. +#[cfg(CONFIG_ARCH_SUPPORTS_ATOMIC_RMW)] +impl AtomicImpl for i8 { + type Delta = Self; +} + +// The current helpers of load/store uses `{WRITE,READ}_ONCE()` hence the atomicity is only +// guaranteed against read-modify-write operations if the architecture supports native atomic RmW. +#[cfg(CONFIG_ARCH_SUPPORTS_ATOMIC_RMW)] +impl AtomicImpl for i16 { + type Delta = Self; +} + // `atomic_t` implements atomic operations on `i32`. impl AtomicImpl for i32 { type Delta = Self; @@ -156,16 +175,17 @@ macro_rules! impl_atomic_method { } } -// Delcares $ops trait with methods and implements the trait for `i32` and `i64`. -macro_rules! declare_and_impl_atomic_methods { - ($(#[$attr:meta])* $pub:vis trait $ops:ident { - $( - $(#[doc=$doc:expr])* - fn $func:ident [$($variant:ident),*]($($arg_sig:tt)*) $( -> $ret:ty)? { - $unsafe:tt { bindings::#call($($arg:tt)*) } - } - )* - }) => { +macro_rules! declare_atomic_ops_trait { + ( + $(#[$attr:meta])* $pub:vis trait $ops:ident { + $( + $(#[doc=$doc:expr])* + fn $func:ident [$($variant:ident),*]($($arg_sig:tt)*) $( -> $ret:ty)? { + $unsafe:tt { bindings::#call($($arg:tt)*) } + } + )* + } + ) => { $(#[$attr])* $pub trait $ops: AtomicImpl { $( @@ -175,21 +195,25 @@ macro_rules! declare_and_impl_atomic_methods { ); )* } + } +} - impl $ops for i32 { +macro_rules! impl_atomic_ops_for_one { + ( + $ty:ty => $ctype:ident, + $(#[$attr:meta])* $pub:vis trait $ops:ident { $( - impl_atomic_method!( - (atomic) $func[$($variant)*]($($arg_sig)*) $(-> $ret)? { - $unsafe { call($($arg)*) } - } - ); + $(#[doc=$doc:expr])* + fn $func:ident [$($variant:ident),*]($($arg_sig:tt)*) $( -> $ret:ty)? { + $unsafe:tt { bindings::#call($($arg:tt)*) } + } )* } - - impl $ops for i64 { + ) => { + impl $ops for $ty { $( impl_atomic_method!( - (atomic64) $func[$($variant)*]($($arg_sig)*) $(-> $ret)? { + ($ctype) $func[$($variant)*]($($arg_sig)*) $(-> $ret)? { $unsafe { call($($arg)*) } } ); @@ -198,7 +222,47 @@ macro_rules! declare_and_impl_atomic_methods { } } +// Declares $ops trait with methods and implements the trait. +macro_rules! declare_and_impl_atomic_methods { + ( + [ $($map:tt)* ] + $(#[$attr:meta])* $pub:vis trait $ops:ident { $($body:tt)* } + ) => { + declare_and_impl_atomic_methods!( + @with_ops_def + [ $($map)* ] + ( $(#[$attr])* $pub trait $ops { $($body)* } ) + ); + }; + + (@with_ops_def [ $($map:tt)* ] ( $($ops_def:tt)* )) => { + declare_atomic_ops_trait!( $($ops_def)* ); + + declare_and_impl_atomic_methods!( + @munch + [ $($map)* ] + ( $($ops_def)* ) + ); + }; + + (@munch [] ( $($ops_def:tt)* )) => {}; + + (@munch [ $ty:ty => $ctype:ident $(, $($rest:tt)*)? ] ( $($ops_def:tt)* )) => { + impl_atomic_ops_for_one!( + $ty => $ctype, + $($ops_def)* + ); + + declare_and_impl_atomic_methods!( + @munch + [ $($($rest)*)? ] + ( $($ops_def)* ) + ); + }; +} + declare_and_impl_atomic_methods!( + [ i8 => atomic_i8, i16 => atomic_i16, i32 => atomic, i64 => atomic64 ] /// Basic atomic operations pub trait AtomicBasicOps { /// Atomic read (load). @@ -216,6 +280,7 @@ declare_and_impl_atomic_methods!( ); declare_and_impl_atomic_methods!( + [ i8 => atomic_i8, i16 => atomic_i16, i32 => atomic, i64 => atomic64 ] /// Exchange and compare-and-exchange atomic operations pub trait AtomicExchangeOps { /// Atomic exchange. @@ -243,6 +308,7 @@ declare_and_impl_atomic_methods!( ); declare_and_impl_atomic_methods!( + [ i32 => atomic, i64 => atomic64 ] /// Atomic arithmetic operations pub trait AtomicArithmeticOps { /// Atomic add (wrapping). diff --git a/rust/kernel/sync/atomic/predefine.rs b/rust/kernel/sync/atomic/predefine.rs index 0fca1ba3c2db..67a0406d3ea4 100644 --- a/rust/kernel/sync/atomic/predefine.rs +++ b/rust/kernel/sync/atomic/predefine.rs @@ -5,6 +5,29 @@ use crate::static_assert; use core::mem::{align_of, size_of}; +// Ensure size and alignment requirements are checked. +static_assert!(size_of::<bool>() == size_of::<i8>()); +static_assert!(align_of::<bool>() == align_of::<i8>()); + +// SAFETY: `bool` has the same size and alignment as `i8`, and Rust guarantees that `bool` has +// only two valid bit patterns: 0 (false) and 1 (true). Those are valid `i8` values, so `bool` is +// round-trip transmutable to `i8`. +unsafe impl super::AtomicType for bool { + type Repr = i8; +} + +// SAFETY: `i8` has the same size and alignment with itself, and is round-trip transmutable to +// itself. +unsafe impl super::AtomicType for i8 { + type Repr = i8; +} + +// SAFETY: `i16` has the same size and alignment with itself, and is round-trip transmutable to +// itself. +unsafe impl super::AtomicType for i16 { + type Repr = i16; +} + // SAFETY: `i32` has the same size and alignment with itself, and is round-trip transmutable to // itself. unsafe impl super::AtomicType for i32 { @@ -129,7 +152,7 @@ mod tests { #[test] fn atomic_basic_tests() { - for_each_type!(42 in [i32, i64, u32, u64, isize, usize] |v| { + for_each_type!(42 in [i8, i16, i32, i64, u32, u64, isize, usize] |v| { let x = Atomic::new(v); assert_eq!(v, x.load(Relaxed)); @@ -137,8 +160,18 @@ mod tests { } #[test] + fn atomic_acquire_release_tests() { + for_each_type!(42 in [i8, i16, i32, i64, u32, u64, isize, usize] |v| { + let x = Atomic::new(0); + + x.store(v, Release); + assert_eq!(v, x.load(Acquire)); + }); + } + + #[test] fn atomic_xchg_tests() { - for_each_type!(42 in [i32, i64, u32, u64, isize, usize] |v| { + for_each_type!(42 in [i8, i16, i32, i64, u32, u64, isize, usize] |v| { let x = Atomic::new(v); let old = v; @@ -151,7 +184,7 @@ mod tests { #[test] fn atomic_cmpxchg_tests() { - for_each_type!(42 in [i32, i64, u32, u64, isize, usize] |v| { + for_each_type!(42 in [i8, i16, i32, i64, u32, u64, isize, usize] |v| { let x = Atomic::new(v); let old = v; @@ -177,4 +210,20 @@ mod tests { assert_eq!(v + 25, x.load(Relaxed)); }); } + + #[test] + fn atomic_bool_tests() { + let x = Atomic::new(false); + + assert_eq!(false, x.load(Relaxed)); + x.store(true, Relaxed); + assert_eq!(true, x.load(Relaxed)); + + assert_eq!(true, x.xchg(false, Relaxed)); + assert_eq!(false, x.load(Relaxed)); + + assert_eq!(Err(false), x.cmpxchg(true, true, Relaxed)); + assert_eq!(false, x.load(Relaxed)); + assert_eq!(Ok(false), x.cmpxchg(false, true, Full)); + } } diff --git a/rust/kernel/sync/lock.rs b/rust/kernel/sync/lock.rs index 46a57d1fc309..10b6b5e9b024 100644 --- a/rust/kernel/sync/lock.rs +++ b/rust/kernel/sync/lock.rs @@ -156,6 +156,7 @@ impl<B: Backend> Lock<(), B> { /// the whole lifetime of `'a`. /// /// [`State`]: Backend::State + #[inline] pub unsafe fn from_raw<'a>(ptr: *mut B::State) -> &'a Self { // SAFETY: // - By the safety contract `ptr` must point to a valid initialised instance of `B::State` @@ -169,6 +170,7 @@ impl<B: Backend> Lock<(), B> { impl<T: ?Sized, B: Backend> Lock<T, B> { /// Acquires the lock and gives the caller access to the data protected by it. + #[inline] pub fn lock(&self) -> Guard<'_, T, B> { // SAFETY: The constructor of the type calls `init`, so the existence of the object proves // that `init` was called. @@ -182,6 +184,7 @@ impl<T: ?Sized, B: Backend> Lock<T, B> { /// Returns a guard that can be used to access the data protected by the lock if successful. // `Option<T>` is not `#[must_use]` even if `T` is, thus the attribute is needed here. #[must_use = "if unused, the lock will be immediately unlocked"] + #[inline] pub fn try_lock(&self) -> Option<Guard<'_, T, B>> { // SAFETY: The constructor of the type calls `init`, so the existence of the object proves // that `init` was called. @@ -275,6 +278,7 @@ impl<'a, T: ?Sized, B: Backend> Guard<'a, T, B> { impl<T: ?Sized, B: Backend> core::ops::Deref for Guard<'_, T, B> { type Target = T; + #[inline] fn deref(&self) -> &Self::Target { // SAFETY: The caller owns the lock, so it is safe to deref the protected data. unsafe { &*self.lock.data.get() } @@ -285,6 +289,7 @@ impl<T: ?Sized, B: Backend> core::ops::DerefMut for Guard<'_, T, B> where T: Unpin, { + #[inline] fn deref_mut(&mut self) -> &mut Self::Target { // SAFETY: The caller owns the lock, so it is safe to deref the protected data. unsafe { &mut *self.lock.data.get() } @@ -292,6 +297,7 @@ where } impl<T: ?Sized, B: Backend> Drop for Guard<'_, T, B> { + #[inline] fn drop(&mut self) { // SAFETY: The caller owns the lock, so it is safe to unlock it. unsafe { B::unlock(self.lock.state.get(), &self.state) }; @@ -304,6 +310,7 @@ impl<'a, T: ?Sized, B: Backend> Guard<'a, T, B> { /// # Safety /// /// The caller must ensure that it owns the lock. + #[inline] pub unsafe fn new(lock: &'a Lock<T, B>, state: B::GuardState) -> Self { // SAFETY: The caller can only hold the lock if `Backend::init` has already been called. unsafe { B::assert_is_held(lock.state.get()) }; diff --git a/rust/kernel/sync/lock/global.rs b/rust/kernel/sync/lock/global.rs index eab48108a4ae..aecbdc34738f 100644 --- a/rust/kernel/sync/lock/global.rs +++ b/rust/kernel/sync/lock/global.rs @@ -77,6 +77,7 @@ impl<B: GlobalLockBackend> GlobalLock<B> { } /// Lock this global lock. + #[inline] pub fn lock(&'static self) -> GlobalGuard<B> { GlobalGuard { inner: self.inner.lock(), @@ -84,6 +85,7 @@ impl<B: GlobalLockBackend> GlobalLock<B> { } /// Try to lock this global lock. + #[inline] pub fn try_lock(&'static self) -> Option<GlobalGuard<B>> { Some(GlobalGuard { inner: self.inner.try_lock()?, diff --git a/rust/kernel/sync/lock/mutex.rs b/rust/kernel/sync/lock/mutex.rs index 581cee7ab842..cda0203efefb 100644 --- a/rust/kernel/sync/lock/mutex.rs +++ b/rust/kernel/sync/lock/mutex.rs @@ -102,6 +102,7 @@ unsafe impl super::Backend for MutexBackend { type State = bindings::mutex; type GuardState = (); + #[inline] unsafe fn init( ptr: *mut Self::State, name: *const crate::ffi::c_char, @@ -112,18 +113,21 @@ unsafe impl super::Backend for MutexBackend { unsafe { bindings::__mutex_init(ptr, name, key) } } + #[inline] unsafe fn lock(ptr: *mut Self::State) -> Self::GuardState { // SAFETY: The safety requirements of this function ensure that `ptr` points to valid // memory, and that it has been initialised before. unsafe { bindings::mutex_lock(ptr) }; } + #[inline] unsafe fn unlock(ptr: *mut Self::State, _guard_state: &Self::GuardState) { // SAFETY: The safety requirements of this function ensure that `ptr` is valid and that the // caller is the owner of the mutex. unsafe { bindings::mutex_unlock(ptr) }; } + #[inline] unsafe fn try_lock(ptr: *mut Self::State) -> Option<Self::GuardState> { // SAFETY: The `ptr` pointer is guaranteed to be valid and initialized before use. let result = unsafe { bindings::mutex_trylock(ptr) }; @@ -135,6 +139,7 @@ unsafe impl super::Backend for MutexBackend { } } + #[inline] unsafe fn assert_is_held(ptr: *mut Self::State) { // SAFETY: The `ptr` pointer is guaranteed to be valid and initialized before use. unsafe { bindings::mutex_assert_is_held(ptr) } diff --git a/rust/kernel/sync/lock/spinlock.rs b/rust/kernel/sync/lock/spinlock.rs index d7be38ccbdc7..ef76fa07ca3a 100644 --- a/rust/kernel/sync/lock/spinlock.rs +++ b/rust/kernel/sync/lock/spinlock.rs @@ -101,6 +101,7 @@ unsafe impl super::Backend for SpinLockBackend { type State = bindings::spinlock_t; type GuardState = (); + #[inline] unsafe fn init( ptr: *mut Self::State, name: *const crate::ffi::c_char, @@ -111,18 +112,21 @@ unsafe impl super::Backend for SpinLockBackend { unsafe { bindings::__spin_lock_init(ptr, name, key) } } + #[inline] unsafe fn lock(ptr: *mut Self::State) -> Self::GuardState { // SAFETY: The safety requirements of this function ensure that `ptr` points to valid // memory, and that it has been initialised before. unsafe { bindings::spin_lock(ptr) } } + #[inline] unsafe fn unlock(ptr: *mut Self::State, _guard_state: &Self::GuardState) { // SAFETY: The safety requirements of this function ensure that `ptr` is valid and that the // caller is the owner of the spinlock. unsafe { bindings::spin_unlock(ptr) } } + #[inline] unsafe fn try_lock(ptr: *mut Self::State) -> Option<Self::GuardState> { // SAFETY: The `ptr` pointer is guaranteed to be valid and initialized before use. let result = unsafe { bindings::spin_trylock(ptr) }; @@ -134,6 +138,7 @@ unsafe impl super::Backend for SpinLockBackend { } } + #[inline] unsafe fn assert_is_held(ptr: *mut Self::State) { // SAFETY: The `ptr` pointer is guaranteed to be valid and initialized before use. unsafe { bindings::spin_assert_is_held(ptr) } diff --git a/rust/kernel/sync/set_once.rs b/rust/kernel/sync/set_once.rs index bdba601807d8..139cef05e935 100644 --- a/rust/kernel/sync/set_once.rs +++ b/rust/kernel/sync/set_once.rs @@ -123,3 +123,11 @@ impl<T> Drop for SetOnce<T> { } } } + +// SAFETY: `SetOnce` can be transferred across thread boundaries iff the data it contains can. +unsafe impl<T: Send> Send for SetOnce<T> {} + +// SAFETY: `SetOnce` synchronises access to the inner value via atomic operations, +// so shared references are safe when `T: Sync`. Since the inner `T` may be dropped +// on any thread, we also require `T: Send`. +unsafe impl<T: Send + Sync> Sync for SetOnce<T> {} diff --git a/rust/kernel/transmute.rs b/rust/kernel/transmute.rs index be5dbf3829e2..5711580c9f9b 100644 --- a/rust/kernel/transmute.rs +++ b/rust/kernel/transmute.rs @@ -170,6 +170,10 @@ macro_rules! impl_frombytes { } impl_frombytes! { + // SAFETY: Inhabited ZSTs only have one possible bit pattern, and these two have no invariant. + (), + {<T>} core::marker::PhantomData<T>, + // SAFETY: All bit patterns are acceptable values of the types below. u8, u16, u32, u64, usize, i8, i16, i32, i64, isize, @@ -230,6 +234,10 @@ macro_rules! impl_asbytes { } impl_asbytes! { + // SAFETY: Inhabited ZSTs only have one possible bit pattern, and these two have no invariant. + (), + {<T>} core::marker::PhantomData<T>, + // SAFETY: Instances of the following types have no uninitialized portions. u8, u16, u32, u64, usize, i8, i16, i32, i64, isize, diff --git a/rust/macros/concat_idents.rs b/rust/macros/concat_idents.rs index 7e4b450f3a50..47b6add378d2 100644 --- a/rust/macros/concat_idents.rs +++ b/rust/macros/concat_idents.rs @@ -1,23 +1,36 @@ // SPDX-License-Identifier: GPL-2.0 -use proc_macro::{token_stream, Ident, TokenStream, TokenTree}; +use proc_macro2::{ + Ident, + TokenStream, + TokenTree, // +}; +use syn::{ + parse::{ + Parse, + ParseStream, // + }, + Result, + Token, // +}; -use crate::helpers::expect_punct; +pub(crate) struct Input { + a: Ident, + _comma: Token![,], + b: Ident, +} -fn expect_ident(it: &mut token_stream::IntoIter) -> Ident { - if let Some(TokenTree::Ident(ident)) = it.next() { - ident - } else { - panic!("Expected Ident") +impl Parse for Input { + fn parse(input: ParseStream<'_>) -> Result<Self> { + Ok(Self { + a: input.parse()?, + _comma: input.parse()?, + b: input.parse()?, + }) } } -pub(crate) fn concat_idents(ts: TokenStream) -> TokenStream { - let mut it = ts.into_iter(); - let a = expect_ident(&mut it); - assert_eq!(expect_punct(&mut it), ','); - let b = expect_ident(&mut it); - assert!(it.next().is_none(), "only two idents can be concatenated"); +pub(crate) fn concat_idents(Input { a, b, .. }: Input) -> TokenStream { let res = Ident::new(&format!("{a}{b}"), b.span()); TokenStream::from_iter([TokenTree::Ident(res)]) } diff --git a/rust/macros/export.rs b/rust/macros/export.rs index a08f6337d5c8..6d53521f62fc 100644 --- a/rust/macros/export.rs +++ b/rust/macros/export.rs @@ -1,19 +1,16 @@ // SPDX-License-Identifier: GPL-2.0 -use crate::helpers::function_name; -use proc_macro::TokenStream; +use proc_macro2::TokenStream; +use quote::quote; /// Please see [`crate::export`] for documentation. -pub(crate) fn export(_attr: TokenStream, ts: TokenStream) -> TokenStream { - let Some(name) = function_name(ts.clone()) else { - return "::core::compile_error!(\"The #[export] attribute must be used on a function.\");" - .parse::<TokenStream>() - .unwrap(); - }; +pub(crate) fn export(f: syn::ItemFn) -> TokenStream { + let name = &f.sig.ident; - // This verifies that the function has the same signature as the declaration generated by - // bindgen. It makes use of the fact that all branches of an if/else must have the same type. - let signature_check = quote!( + quote! { + // This verifies that the function has the same signature as the declaration generated by + // bindgen. It makes use of the fact that all branches of an if/else must have the same + // type. const _: () = { if true { ::kernel::bindings::#name @@ -21,9 +18,8 @@ pub(crate) fn export(_attr: TokenStream, ts: TokenStream) -> TokenStream { #name }; }; - ); - let no_mangle = quote!(#[no_mangle]); - - TokenStream::from_iter([signature_check, no_mangle, ts]) + #[no_mangle] + #f + } } diff --git a/rust/macros/fmt.rs b/rust/macros/fmt.rs index 8354abd54502..ce6c7249305a 100644 --- a/rust/macros/fmt.rs +++ b/rust/macros/fmt.rs @@ -1,8 +1,10 @@ // SPDX-License-Identifier: GPL-2.0 -use proc_macro::{Ident, TokenStream, TokenTree}; use std::collections::BTreeSet; +use proc_macro2::{Ident, TokenStream, TokenTree}; +use quote::quote_spanned; + /// Please see [`crate::fmt`] for documentation. pub(crate) fn fmt(input: TokenStream) -> TokenStream { let mut input = input.into_iter(); diff --git a/rust/macros/helpers.rs b/rust/macros/helpers.rs index 365d7eb499c0..37ef6a6f2c85 100644 --- a/rust/macros/helpers.rs +++ b/rust/macros/helpers.rs @@ -1,101 +1,41 @@ // SPDX-License-Identifier: GPL-2.0 -use proc_macro::{token_stream, Group, Ident, TokenStream, TokenTree}; - -pub(crate) fn try_ident(it: &mut token_stream::IntoIter) -> Option<String> { - if let Some(TokenTree::Ident(ident)) = it.next() { - Some(ident.to_string()) - } else { - None - } -} - -pub(crate) fn try_sign(it: &mut token_stream::IntoIter) -> Option<char> { - let peek = it.clone().next(); - match peek { - Some(TokenTree::Punct(punct)) if punct.as_char() == '-' => { - let _ = it.next(); - Some(punct.as_char()) +use proc_macro2::TokenStream; +use quote::ToTokens; +use syn::{ + parse::{ + Parse, + ParseStream, // + }, + Attribute, + Error, + LitStr, + Result, // +}; + +/// A string literal that is required to have ASCII value only. +pub(crate) struct AsciiLitStr(LitStr); + +impl Parse for AsciiLitStr { + fn parse(input: ParseStream<'_>) -> Result<Self> { + let s: LitStr = input.parse()?; + if !s.value().is_ascii() { + return Err(Error::new_spanned(s, "expected ASCII-only string literal")); } - _ => None, - } -} - -pub(crate) fn try_literal(it: &mut token_stream::IntoIter) -> Option<String> { - if let Some(TokenTree::Literal(literal)) = it.next() { - Some(literal.to_string()) - } else { - None + Ok(Self(s)) } } -pub(crate) fn try_string(it: &mut token_stream::IntoIter) -> Option<String> { - try_literal(it).and_then(|string| { - if string.starts_with('\"') && string.ends_with('\"') { - let content = &string[1..string.len() - 1]; - if content.contains('\\') { - panic!("Escape sequences in string literals not yet handled"); - } - Some(content.to_string()) - } else if string.starts_with("r\"") { - panic!("Raw string literals are not yet handled"); - } else { - None - } - }) -} - -pub(crate) fn expect_ident(it: &mut token_stream::IntoIter) -> String { - try_ident(it).expect("Expected Ident") -} - -pub(crate) fn expect_punct(it: &mut token_stream::IntoIter) -> char { - if let TokenTree::Punct(punct) = it.next().expect("Reached end of token stream for Punct") { - punct.as_char() - } else { - panic!("Expected Punct"); +impl ToTokens for AsciiLitStr { + fn to_tokens(&self, ts: &mut TokenStream) { + self.0.to_tokens(ts); } } -pub(crate) fn expect_string(it: &mut token_stream::IntoIter) -> String { - try_string(it).expect("Expected string") -} - -pub(crate) fn expect_string_ascii(it: &mut token_stream::IntoIter) -> String { - let string = try_string(it).expect("Expected string"); - assert!(string.is_ascii(), "Expected ASCII string"); - string -} - -pub(crate) fn expect_group(it: &mut token_stream::IntoIter) -> Group { - if let TokenTree::Group(group) = it.next().expect("Reached end of token stream for Group") { - group - } else { - panic!("Expected Group"); - } -} - -pub(crate) fn expect_end(it: &mut token_stream::IntoIter) { - if it.next().is_some() { - panic!("Expected end"); - } -} - -/// Given a function declaration, finds the name of the function. -pub(crate) fn function_name(input: TokenStream) -> Option<Ident> { - let mut input = input.into_iter(); - while let Some(token) = input.next() { - match token { - TokenTree::Ident(i) if i.to_string() == "fn" => { - if let Some(TokenTree::Ident(i)) = input.next() { - return Some(i); - } - return None; - } - _ => continue, - } +impl AsciiLitStr { + pub(crate) fn value(&self) -> String { + self.0.value() } - None } pub(crate) fn file() -> String { @@ -115,16 +55,7 @@ pub(crate) fn file() -> String { } } -/// Parse a token stream of the form `expected_name: "value",` and return the -/// string in the position of "value". -/// -/// # Panics -/// -/// - On parse error. -pub(crate) fn expect_string_field(it: &mut token_stream::IntoIter, expected_name: &str) -> String { - assert_eq!(expect_ident(it), expected_name); - assert_eq!(expect_punct(it), ':'); - let string = expect_string(it); - assert_eq!(expect_punct(it), ','); - string +/// Obtain all `#[cfg]` attributes. +pub(crate) fn gather_cfg_attrs(attr: &[Attribute]) -> impl Iterator<Item = &Attribute> + '_ { + attr.iter().filter(|a| a.path().is_ident("cfg")) } diff --git a/rust/macros/kunit.rs b/rust/macros/kunit.rs index b395bb053695..6be880d634e2 100644 --- a/rust/macros/kunit.rs +++ b/rust/macros/kunit.rs @@ -4,80 +4,50 @@ //! //! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com> -use proc_macro::{Delimiter, Group, TokenStream, TokenTree}; -use std::collections::HashMap; -use std::fmt::Write; - -pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { - let attr = attr.to_string(); - - if attr.is_empty() { - panic!("Missing test name in `#[kunit_tests(test_name)]` macro") - } - - if attr.len() > 255 { - panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes") +use std::ffi::CString; + +use proc_macro2::TokenStream; +use quote::{ + format_ident, + quote, + ToTokens, // +}; +use syn::{ + parse_quote, + Error, + Ident, + Item, + ItemMod, + LitCStr, + Result, // +}; + +pub(crate) fn kunit_tests(test_suite: Ident, mut module: ItemMod) -> Result<TokenStream> { + if test_suite.to_string().len() > 255 { + return Err(Error::new_spanned( + test_suite, + "test suite names cannot exceed the maximum length of 255 bytes", + )); } - let mut tokens: Vec<_> = ts.into_iter().collect(); - - // Scan for the `mod` keyword. - tokens - .iter() - .find_map(|token| match token { - TokenTree::Ident(ident) => match ident.to_string().as_str() { - "mod" => Some(true), - _ => None, - }, - _ => None, - }) - .expect("`#[kunit_tests(test_name)]` attribute should only be applied to modules"); - - // Retrieve the main body. The main body should be the last token tree. - let body = match tokens.pop() { - Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group, - _ => panic!("Cannot locate main body of module"), + // We cannot handle modules that defer to another file (e.g. `mod foo;`). + let Some((module_brace, module_items)) = module.content.take() else { + Err(Error::new_spanned( + module, + "`#[kunit_tests(test_name)]` attribute should only be applied to inline modules", + ))? }; - // Get the functions set as tests. Search for `[test]` -> `fn`. - let mut body_it = body.stream().into_iter(); - let mut tests = Vec::new(); - let mut attributes: HashMap<String, TokenStream> = HashMap::new(); - while let Some(token) = body_it.next() { - match token { - TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() { - Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => { - if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() { - // Collect attributes because we need to find which are tests. We also - // need to copy `cfg` attributes so tests can be conditionally enabled. - attributes - .entry(name.to_string()) - .or_default() - .extend([token, TokenTree::Group(g)]); - } - continue; - } - _ => (), - }, - TokenTree::Ident(i) if i.to_string() == "fn" && attributes.contains_key("test") => { - if let Some(TokenTree::Ident(test_name)) = body_it.next() { - tests.push((test_name, attributes.remove("cfg").unwrap_or_default())) - } - } - - _ => (), - } - attributes.clear(); - } + // Make the entire module gated behind `CONFIG_KUNIT`. + module + .attrs + .insert(0, parse_quote!(#[cfg(CONFIG_KUNIT="y")])); - // Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration. - let config_kunit = "#[cfg(CONFIG_KUNIT=\"y\")]".to_owned().parse().unwrap(); - tokens.insert( - 0, - TokenTree::Group(Group::new(Delimiter::None, config_kunit)), - ); + let mut processed_items = Vec::new(); + let mut test_cases = Vec::new(); // Generate the test KUnit test suite and a test case for each `#[test]`. + // // The code generated for the following test module: // // ``` @@ -102,105 +72,100 @@ pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut ::kernel::bindings::kunit) { bar(); } // // static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [ - // ::kernel::kunit::kunit_case(::kernel::c_str!("foo"), kunit_rust_wrapper_foo), - // ::kernel::kunit::kunit_case(::kernel::c_str!("bar"), kunit_rust_wrapper_bar), - // ::kernel::kunit::kunit_case_null(), + // ::kernel::kunit::kunit_case(c"foo", kunit_rust_wrapper_foo), + // ::kernel::kunit::kunit_case(c"bar", kunit_rust_wrapper_bar), + // ::pin_init::zeroed(), // ]; // // ::kernel::kunit_unsafe_test_suite!(kunit_test_suit_name, TEST_CASES); // ``` - let mut kunit_macros = "".to_owned(); - let mut test_cases = "".to_owned(); - let mut assert_macros = "".to_owned(); - let path = crate::helpers::file(); - let num_tests = tests.len(); - for (test, cfg_attr) in tests { - let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}"); - // Append any `cfg` attributes the user might have written on their tests so we don't - // attempt to call them when they are `cfg`'d out. An extra `use` is used here to reduce - // the length of the assert message. - let kunit_wrapper = format!( - r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit) - {{ - (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED; - {cfg_attr} {{ - (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS; - use ::kernel::kunit::is_test_result_ok; - assert!(is_test_result_ok({test}())); + // + // Non-function items (e.g. imports) are preserved. + for item in module_items { + let Item::Fn(mut f) = item else { + processed_items.push(item); + continue; + }; + + // TODO: Replace below with `extract_if` when MSRV is bumped above 1.85. + let before_len = f.attrs.len(); + f.attrs.retain(|attr| !attr.path().is_ident("test")); + if f.attrs.len() == before_len { + processed_items.push(Item::Fn(f)); + continue; + } + + let test = f.sig.ident.clone(); + + // Retrieve `#[cfg]` applied on the function which needs to be present on derived items too. + let cfg_attrs: Vec<_> = f + .attrs + .iter() + .filter(|attr| attr.path().is_ident("cfg")) + .cloned() + .collect(); + + // Before the test, override usual `assert!` and `assert_eq!` macros with ones that call + // KUnit instead. + let test_str = test.to_string(); + let path = CString::new(crate::helpers::file()).expect("file path cannot contain NUL"); + processed_items.push(parse_quote! { + #[allow(unused)] + macro_rules! assert { + ($cond:expr $(,)?) => {{ + kernel::kunit_assert!(#test_str, #path, 0, $cond); + }} + } + }); + processed_items.push(parse_quote! { + #[allow(unused)] + macro_rules! assert_eq { + ($left:expr, $right:expr $(,)?) => {{ + kernel::kunit_assert_eq!(#test_str, #path, 0, $left, $right); }} - }}"#, + } + }); + + // Add back the test item. + processed_items.push(Item::Fn(f)); + + let kunit_wrapper_fn_name = format_ident!("kunit_rust_wrapper_{test}"); + let test_cstr = LitCStr::new( + &CString::new(test_str.as_str()).expect("identifier cannot contain NUL"), + test.span(), ); - writeln!(kunit_macros, "{kunit_wrapper}").unwrap(); - writeln!( - test_cases, - " ::kernel::kunit::kunit_case(::kernel::c_str!(\"{test}\"), {kunit_wrapper_fn_name})," - ) - .unwrap(); - writeln!( - assert_macros, - r#" -/// Overrides the usual [`assert!`] macro with one that calls KUnit instead. -#[allow(unused)] -macro_rules! assert {{ - ($cond:expr $(,)?) => {{{{ - kernel::kunit_assert!("{test}", "{path}", 0, $cond); - }}}} -}} - -/// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead. -#[allow(unused)] -macro_rules! assert_eq {{ - ($left:expr, $right:expr $(,)?) => {{{{ - kernel::kunit_assert_eq!("{test}", "{path}", 0, $left, $right); - }}}} -}} - "# - ) - .unwrap(); - } + processed_items.push(parse_quote! { + unsafe extern "C" fn #kunit_wrapper_fn_name(_test: *mut ::kernel::bindings::kunit) { + (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED; - writeln!(kunit_macros).unwrap(); - writeln!( - kunit_macros, - "static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases} ::kernel::kunit::kunit_case_null(),\n];", - num_tests + 1 - ) - .unwrap(); - - writeln!( - kunit_macros, - "::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);" - ) - .unwrap(); - - // Remove the `#[test]` macros. - // We do this at a token level, in order to preserve span information. - let mut new_body = vec![]; - let mut body_it = body.stream().into_iter(); - - while let Some(token) = body_it.next() { - match token { - TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() { - Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (), - Some(next) => { - new_body.extend([token, next]); - } - _ => { - new_body.push(token); + // Append any `cfg` attributes the user might have written on their tests so we + // don't attempt to call them when they are `cfg`'d out. An extra `use` is used + // here to reduce the length of the assert message. + #(#cfg_attrs)* + { + (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS; + use ::kernel::kunit::is_test_result_ok; + assert!(is_test_result_ok(#test())); } - }, - _ => { - new_body.push(token); } - } - } - - let mut final_body = TokenStream::new(); - final_body.extend::<TokenStream>(assert_macros.parse().unwrap()); - final_body.extend(new_body); - final_body.extend::<TokenStream>(kunit_macros.parse().unwrap()); + }); - tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body))); + test_cases.push(quote!( + ::kernel::kunit::kunit_case(#test_cstr, #kunit_wrapper_fn_name) + )); + } - tokens.into_iter().collect() + let num_tests_plus_1 = test_cases.len() + 1; + processed_items.push(parse_quote! { + static mut TEST_CASES: [::kernel::bindings::kunit_case; #num_tests_plus_1] = [ + #(#test_cases,)* + ::pin_init::zeroed(), + ]; + }); + processed_items.push(parse_quote! { + ::kernel::kunit_unsafe_test_suite!(#test_suite, TEST_CASES); + }); + + module.content = Some((module_brace, processed_items)); + Ok(module.to_token_stream()) } diff --git a/rust/macros/lib.rs b/rust/macros/lib.rs index 33f66e86418a..0c36194d9971 100644 --- a/rust/macros/lib.rs +++ b/rust/macros/lib.rs @@ -11,8 +11,6 @@ // to avoid depending on the full `proc_macro_span` on Rust >= 1.88.0. #![cfg_attr(not(CONFIG_RUSTC_HAS_SPAN_FILE), feature(proc_macro_span))] -#[macro_use] -mod quote; mod concat_idents; mod export; mod fmt; @@ -24,6 +22,8 @@ mod vtable; use proc_macro::TokenStream; +use syn::parse_macro_input; + /// Declares a kernel module. /// /// The `type` argument should be a type which implements the [`Module`] @@ -131,8 +131,10 @@ use proc_macro::TokenStream; /// - `firmware`: array of ASCII string literals of the firmware files of /// the kernel module. #[proc_macro] -pub fn module(ts: TokenStream) -> TokenStream { - module::module(ts) +pub fn module(input: TokenStream) -> TokenStream { + module::module(parse_macro_input!(input)) + .unwrap_or_else(|e| e.into_compile_error()) + .into() } /// Declares or implements a vtable trait. @@ -154,7 +156,7 @@ pub fn module(ts: TokenStream) -> TokenStream { /// case the default implementation will never be executed. The reason for this /// is that the functions will be called through function pointers installed in /// C side vtables. When an optional method is not implemented on a `#[vtable]` -/// trait, a NULL entry is installed in the vtable. Thus the default +/// trait, a `NULL` entry is installed in the vtable. Thus the default /// implementation is never called. Since these traits are not designed to be /// used on the Rust side, it should not be possible to call the default /// implementation. This is done to ensure that we call the vtable methods @@ -206,8 +208,11 @@ pub fn module(ts: TokenStream) -> TokenStream { /// /// [`kernel::error::VTABLE_DEFAULT_ERROR`]: ../kernel/error/constant.VTABLE_DEFAULT_ERROR.html #[proc_macro_attribute] -pub fn vtable(attr: TokenStream, ts: TokenStream) -> TokenStream { - vtable::vtable(attr, ts) +pub fn vtable(attr: TokenStream, input: TokenStream) -> TokenStream { + parse_macro_input!(attr as syn::parse::Nothing); + vtable::vtable(parse_macro_input!(input)) + .unwrap_or_else(|e| e.into_compile_error()) + .into() } /// Export a function so that C code can call it via a header file. @@ -229,8 +234,9 @@ pub fn vtable(attr: TokenStream, ts: TokenStream) -> TokenStream { /// This macro is *not* the same as the C macros `EXPORT_SYMBOL_*`. All Rust symbols are currently /// automatically exported with `EXPORT_SYMBOL_GPL`. #[proc_macro_attribute] -pub fn export(attr: TokenStream, ts: TokenStream) -> TokenStream { - export::export(attr, ts) +pub fn export(attr: TokenStream, input: TokenStream) -> TokenStream { + parse_macro_input!(attr as syn::parse::Nothing); + export::export(parse_macro_input!(input)).into() } /// Like [`core::format_args!`], but automatically wraps arguments in [`kernel::fmt::Adapter`]. @@ -248,7 +254,7 @@ pub fn export(attr: TokenStream, ts: TokenStream) -> TokenStream { /// [`pr_info!`]: ../kernel/macro.pr_info.html #[proc_macro] pub fn fmt(input: TokenStream) -> TokenStream { - fmt::fmt(input) + fmt::fmt(input.into()).into() } /// Concatenate two identifiers. @@ -305,8 +311,8 @@ pub fn fmt(input: TokenStream) -> TokenStream { /// assert_eq!(BR_OK, binder_driver_return_protocol_BR_OK); /// ``` #[proc_macro] -pub fn concat_idents(ts: TokenStream) -> TokenStream { - concat_idents::concat_idents(ts) +pub fn concat_idents(input: TokenStream) -> TokenStream { + concat_idents::concat_idents(parse_macro_input!(input)).into() } /// Paste identifiers together. @@ -444,9 +450,12 @@ pub fn concat_idents(ts: TokenStream) -> TokenStream { /// [`paste`]: https://docs.rs/paste/ #[proc_macro] pub fn paste(input: TokenStream) -> TokenStream { - let mut tokens = input.into_iter().collect(); + let mut tokens = proc_macro2::TokenStream::from(input).into_iter().collect(); paste::expand(&mut tokens); - tokens.into_iter().collect() + tokens + .into_iter() + .collect::<proc_macro2::TokenStream>() + .into() } /// Registers a KUnit test suite and its test cases using a user-space like syntax. @@ -472,6 +481,8 @@ pub fn paste(input: TokenStream) -> TokenStream { /// } /// ``` #[proc_macro_attribute] -pub fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream { - kunit::kunit_tests(attr, ts) +pub fn kunit_tests(attr: TokenStream, input: TokenStream) -> TokenStream { + kunit::kunit_tests(parse_macro_input!(attr), parse_macro_input!(input)) + .unwrap_or_else(|e| e.into_compile_error()) + .into() } diff --git a/rust/macros/module.rs b/rust/macros/module.rs index 80cb9b16f5aa..e16298e520c7 100644 --- a/rust/macros/module.rs +++ b/rust/macros/module.rs @@ -1,32 +1,42 @@ // SPDX-License-Identifier: GPL-2.0 +use std::ffi::CString; + +use proc_macro2::{ + Literal, + TokenStream, // +}; +use quote::{ + format_ident, + quote, // +}; +use syn::{ + braced, + bracketed, + ext::IdentExt, + parse::{ + Parse, + ParseStream, // + }, + parse_quote, + punctuated::Punctuated, + Error, + Expr, + Ident, + LitStr, + Path, + Result, + Token, + Type, // +}; + use crate::helpers::*; -use proc_macro::{token_stream, Delimiter, Literal, TokenStream, TokenTree}; -use std::fmt::Write; - -fn expect_string_array(it: &mut token_stream::IntoIter) -> Vec<String> { - let group = expect_group(it); - assert_eq!(group.delimiter(), Delimiter::Bracket); - let mut values = Vec::new(); - let mut it = group.stream().into_iter(); - - while let Some(val) = try_string(&mut it) { - assert!(val.is_ascii(), "Expected ASCII string"); - values.push(val); - match it.next() { - Some(TokenTree::Punct(punct)) => assert_eq!(punct.as_char(), ','), - None => break, - _ => panic!("Expected ',' or end of array"), - } - } - values -} struct ModInfoBuilder<'a> { module: &'a str, counter: usize, - buffer: String, - param_buffer: String, + ts: TokenStream, + param_ts: TokenStream, } impl<'a> ModInfoBuilder<'a> { @@ -34,8 +44,8 @@ impl<'a> ModInfoBuilder<'a> { ModInfoBuilder { module, counter: 0, - buffer: String::new(), - param_buffer: String::new(), + ts: TokenStream::new(), + param_ts: TokenStream::new(), } } @@ -52,33 +62,31 @@ impl<'a> ModInfoBuilder<'a> { // Loadable modules' modinfo strings go as-is. format!("{field}={content}\0") }; - - let buffer = if param { - &mut self.param_buffer + let length = string.len(); + let string = Literal::byte_string(string.as_bytes()); + let cfg = if builtin { + quote!(#[cfg(not(MODULE))]) } else { - &mut self.buffer + quote!(#[cfg(MODULE)]) }; - write!( - buffer, - " - {cfg} - #[doc(hidden)] - #[cfg_attr(not(target_os = \"macos\"), link_section = \".modinfo\")] - #[used(compiler)] - pub static __{module}_{counter}: [u8; {length}] = *{string}; - ", - cfg = if builtin { - "#[cfg(not(MODULE))]" - } else { - "#[cfg(MODULE)]" - }, + let counter = format_ident!( + "__{module}_{counter}", module = self.module.to_uppercase(), - counter = self.counter, - length = string.len(), - string = Literal::byte_string(string.as_bytes()), - ) - .unwrap(); + counter = self.counter + ); + let item = quote! { + #cfg + #[cfg_attr(not(target_os = "macos"), link_section = ".modinfo")] + #[used(compiler)] + pub static #counter: [u8; #length] = *#string; + }; + + if param { + self.param_ts.extend(item); + } else { + self.ts.extend(item); + } self.counter += 1; } @@ -111,201 +119,160 @@ impl<'a> ModInfoBuilder<'a> { }; for param in params { - let ops = param_ops_path(¶m.ptype); + let param_name_str = param.name.to_string(); + let param_type_str = param.ptype.to_string(); + + let ops = param_ops_path(¶m_type_str); // Note: The spelling of these fields is dictated by the user space // tool `modinfo`. - self.emit_param("parmtype", ¶m.name, ¶m.ptype); - self.emit_param("parm", ¶m.name, ¶m.description); - - write!( - self.param_buffer, - " - pub(crate) static {param_name}: - ::kernel::module_param::ModuleParamAccess<{param_type}> = - ::kernel::module_param::ModuleParamAccess::new({param_default}); - - const _: () = {{ - #[link_section = \"__param\"] - #[used] - static __{module_name}_{param_name}_struct: + self.emit_param("parmtype", ¶m_name_str, ¶m_type_str); + self.emit_param("parm", ¶m_name_str, ¶m.description.value()); + + let static_name = format_ident!("__{}_{}_struct", self.module, param.name); + let param_name_cstr = + CString::new(param_name_str).expect("name contains NUL-terminator"); + let param_name_cstr_with_module = + CString::new(format!("{}.{}", self.module, param.name)) + .expect("name contains NUL-terminator"); + + let param_name = ¶m.name; + let param_type = ¶m.ptype; + let param_default = ¶m.default; + + self.param_ts.extend(quote! { + #[allow(non_upper_case_globals)] + pub(crate) static #param_name: + ::kernel::module_param::ModuleParamAccess<#param_type> = + ::kernel::module_param::ModuleParamAccess::new(#param_default); + + const _: () = { + #[allow(non_upper_case_globals)] + #[link_section = "__param"] + #[used(compiler)] + static #static_name: ::kernel::module_param::KernelParam = ::kernel::module_param::KernelParam::new( - ::kernel::bindings::kernel_param {{ - name: if ::core::cfg!(MODULE) {{ - ::kernel::c_str!(\"{param_name}\").to_bytes_with_nul() - }} else {{ - ::kernel::c_str!(\"{module_name}.{param_name}\") - .to_bytes_with_nul() - }}.as_ptr(), + ::kernel::bindings::kernel_param { + name: kernel::str::as_char_ptr_in_const_context( + if ::core::cfg!(MODULE) { + #param_name_cstr + } else { + #param_name_cstr_with_module + } + ), // SAFETY: `__this_module` is constructed by the kernel at load // time and will not be freed until the module is unloaded. #[cfg(MODULE)] - mod_: unsafe {{ + mod_: unsafe { core::ptr::from_ref(&::kernel::bindings::__this_module) .cast_mut() - }}, + }, #[cfg(not(MODULE))] mod_: ::core::ptr::null_mut(), - ops: core::ptr::from_ref(&{ops}), + ops: core::ptr::from_ref(&#ops), perm: 0, // Will not appear in sysfs level: -1, flags: 0, - __bindgen_anon_1: ::kernel::bindings::kernel_param__bindgen_ty_1 {{ - arg: {param_name}.as_void_ptr() - }}, - }} + __bindgen_anon_1: ::kernel::bindings::kernel_param__bindgen_ty_1 { + arg: #param_name.as_void_ptr() + }, + } ); - }}; - ", - module_name = info.name, - param_type = param.ptype, - param_default = param.default, - param_name = param.name, - ops = ops, - ) - .unwrap(); + }; + }); } } } -fn param_ops_path(param_type: &str) -> &'static str { +fn param_ops_path(param_type: &str) -> Path { match param_type { - "i8" => "::kernel::module_param::PARAM_OPS_I8", - "u8" => "::kernel::module_param::PARAM_OPS_U8", - "i16" => "::kernel::module_param::PARAM_OPS_I16", - "u16" => "::kernel::module_param::PARAM_OPS_U16", - "i32" => "::kernel::module_param::PARAM_OPS_I32", - "u32" => "::kernel::module_param::PARAM_OPS_U32", - "i64" => "::kernel::module_param::PARAM_OPS_I64", - "u64" => "::kernel::module_param::PARAM_OPS_U64", - "isize" => "::kernel::module_param::PARAM_OPS_ISIZE", - "usize" => "::kernel::module_param::PARAM_OPS_USIZE", + "i8" => parse_quote!(::kernel::module_param::PARAM_OPS_I8), + "u8" => parse_quote!(::kernel::module_param::PARAM_OPS_U8), + "i16" => parse_quote!(::kernel::module_param::PARAM_OPS_I16), + "u16" => parse_quote!(::kernel::module_param::PARAM_OPS_U16), + "i32" => parse_quote!(::kernel::module_param::PARAM_OPS_I32), + "u32" => parse_quote!(::kernel::module_param::PARAM_OPS_U32), + "i64" => parse_quote!(::kernel::module_param::PARAM_OPS_I64), + "u64" => parse_quote!(::kernel::module_param::PARAM_OPS_U64), + "isize" => parse_quote!(::kernel::module_param::PARAM_OPS_ISIZE), + "usize" => parse_quote!(::kernel::module_param::PARAM_OPS_USIZE), t => panic!("Unsupported parameter type {}", t), } } -fn expect_param_default(param_it: &mut token_stream::IntoIter) -> String { - assert_eq!(expect_ident(param_it), "default"); - assert_eq!(expect_punct(param_it), ':'); - let sign = try_sign(param_it); - let default = try_literal(param_it).expect("Expected default param value"); - assert_eq!(expect_punct(param_it), ','); - let mut value = sign.map(String::from).unwrap_or_default(); - value.push_str(&default); - value -} - -#[derive(Debug, Default)] -struct ModuleInfo { - type_: String, - license: String, - name: String, - authors: Option<Vec<String>>, - description: Option<String>, - alias: Option<Vec<String>>, - firmware: Option<Vec<String>>, - imports_ns: Option<Vec<String>>, - params: Option<Vec<Parameter>>, -} - -#[derive(Debug)] -struct Parameter { - name: String, - ptype: String, - default: String, - description: String, -} - -fn expect_params(it: &mut token_stream::IntoIter) -> Vec<Parameter> { - let params = expect_group(it); - assert_eq!(params.delimiter(), Delimiter::Brace); - let mut it = params.stream().into_iter(); - let mut parsed = Vec::new(); - - loop { - let param_name = match it.next() { - Some(TokenTree::Ident(ident)) => ident.to_string(), - Some(_) => panic!("Expected Ident or end"), - None => break, - }; - - assert_eq!(expect_punct(&mut it), ':'); - let param_type = expect_ident(&mut it); - let group = expect_group(&mut it); - assert_eq!(group.delimiter(), Delimiter::Brace); - assert_eq!(expect_punct(&mut it), ','); - - let mut param_it = group.stream().into_iter(); - let param_default = expect_param_default(&mut param_it); - let param_description = expect_string_field(&mut param_it, "description"); - expect_end(&mut param_it); - - parsed.push(Parameter { - name: param_name, - ptype: param_type, - default: param_default, - description: param_description, - }) - } - - parsed -} - -impl ModuleInfo { - fn parse(it: &mut token_stream::IntoIter) -> Self { - let mut info = ModuleInfo::default(); - - const EXPECTED_KEYS: &[&str] = &[ - "type", - "name", - "authors", - "description", - "license", - "alias", - "firmware", - "imports_ns", - "params", - ]; - const REQUIRED_KEYS: &[&str] = &["type", "name", "license"]; +/// Parse fields that are required to use a specific order. +/// +/// As fields must follow a specific order, we *could* just parse fields one by one by peeking. +/// However the error message generated when implementing that way is not very friendly. +/// +/// So instead we parse fields in an arbitrary order, but only enforce the ordering after parsing, +/// and if the wrong order is used, the proper order is communicated to the user with error message. +/// +/// Usage looks like this: +/// ```ignore +/// parse_ordered_fields! { +/// from input; +/// +/// // This will extract "foo: <field>" into a variable named "foo". +/// // The variable will have type `Option<_>`. +/// foo => <expression that parses the field>, +/// +/// // If you need the variable name to be different than the key name. +/// // This extracts "baz: <field>" into a variable named "bar". +/// // You might want this if "baz" is a keyword. +/// baz as bar => <expression that parse the field>, +/// +/// // You can mark a key as required, and the variable will no longer be `Option`. +/// // foobar will be of type `Expr` instead of `Option<Expr>`. +/// foobar [required] => input.parse::<Expr>()?, +/// } +/// ``` +macro_rules! parse_ordered_fields { + (@gen + [$input:expr] + [$([$name:ident; $key:ident; $parser:expr])*] + [$([$req_name:ident; $req_key:ident])*] + ) => { + $(let mut $name = None;)* + + const EXPECTED_KEYS: &[&str] = &[$(stringify!($key),)*]; + const REQUIRED_KEYS: &[&str] = &[$(stringify!($req_key),)*]; + + let span = $input.span(); let mut seen_keys = Vec::new(); - loop { - let key = match it.next() { - Some(TokenTree::Ident(ident)) => ident.to_string(), - Some(_) => panic!("Expected Ident or end"), - None => break, - }; + while !$input.is_empty() { + let key = $input.call(Ident::parse_any)?; if seen_keys.contains(&key) { - panic!("Duplicated key \"{key}\". Keys can only be specified once."); + Err(Error::new_spanned( + &key, + format!(r#"duplicated key "{key}". Keys can only be specified once."#), + ))? } - assert_eq!(expect_punct(it), ':'); - - match key.as_str() { - "type" => info.type_ = expect_ident(it), - "name" => info.name = expect_string_ascii(it), - "authors" => info.authors = Some(expect_string_array(it)), - "description" => info.description = Some(expect_string(it)), - "license" => info.license = expect_string_ascii(it), - "alias" => info.alias = Some(expect_string_array(it)), - "firmware" => info.firmware = Some(expect_string_array(it)), - "imports_ns" => info.imports_ns = Some(expect_string_array(it)), - "params" => info.params = Some(expect_params(it)), - _ => panic!("Unknown key \"{key}\". Valid keys are: {EXPECTED_KEYS:?}."), + $input.parse::<Token![:]>()?; + + match &*key.to_string() { + $( + stringify!($key) => $name = Some($parser), + )* + _ => { + Err(Error::new_spanned( + &key, + format!(r#"unknown key "{key}". Valid keys are: {EXPECTED_KEYS:?}."#), + ))? + } } - assert_eq!(expect_punct(it), ','); - + $input.parse::<Token![,]>()?; seen_keys.push(key); } - expect_end(it); - for key in REQUIRED_KEYS { if !seen_keys.iter().any(|e| e == key) { - panic!("Missing required key \"{key}\"."); + Err(Error::new(span, format!(r#"missing required key "{key}""#)))? } } @@ -317,43 +284,190 @@ impl ModuleInfo { } if seen_keys != ordered_keys { - panic!("Keys are not ordered as expected. Order them like: {ordered_keys:?}."); + Err(Error::new( + span, + format!(r#"keys are not ordered as expected. Order them like: {ordered_keys:?}."#), + ))? + } + + $(let $req_name = $req_name.expect("required field");)* + }; + + // Handle required fields. + (@gen + [$input:expr] [$($tok:tt)*] [$($req:tt)*] + $key:ident as $name:ident [required] => $parser:expr, + $($rest:tt)* + ) => { + parse_ordered_fields!( + @gen [$input] [$($tok)* [$name; $key; $parser]] [$($req)* [$name; $key]] $($rest)* + ) + }; + (@gen + [$input:expr] [$($tok:tt)*] [$($req:tt)*] + $name:ident [required] => $parser:expr, + $($rest:tt)* + ) => { + parse_ordered_fields!( + @gen [$input] [$($tok)* [$name; $name; $parser]] [$($req)* [$name; $name]] $($rest)* + ) + }; + + // Handle optional fields. + (@gen + [$input:expr] [$($tok:tt)*] [$($req:tt)*] + $key:ident as $name:ident => $parser:expr, + $($rest:tt)* + ) => { + parse_ordered_fields!( + @gen [$input] [$($tok)* [$name; $key; $parser]] [$($req)*] $($rest)* + ) + }; + (@gen + [$input:expr] [$($tok:tt)*] [$($req:tt)*] + $name:ident => $parser:expr, + $($rest:tt)* + ) => { + parse_ordered_fields!( + @gen [$input] [$($tok)* [$name; $name; $parser]] [$($req)*] $($rest)* + ) + }; + + (from $input:expr; $($tok:tt)*) => { + parse_ordered_fields!(@gen [$input] [] [] $($tok)*) + } +} + +struct Parameter { + name: Ident, + ptype: Ident, + default: Expr, + description: LitStr, +} + +impl Parse for Parameter { + fn parse(input: ParseStream<'_>) -> Result<Self> { + let name = input.parse()?; + input.parse::<Token![:]>()?; + let ptype = input.parse()?; + + let fields; + braced!(fields in input); + + parse_ordered_fields! { + from fields; + default [required] => fields.parse()?, + description [required] => fields.parse()?, } - info + Ok(Self { + name, + ptype, + default, + description, + }) } } -pub(crate) fn module(ts: TokenStream) -> TokenStream { - let mut it = ts.into_iter(); +pub(crate) struct ModuleInfo { + type_: Type, + license: AsciiLitStr, + name: AsciiLitStr, + authors: Option<Punctuated<AsciiLitStr, Token![,]>>, + description: Option<LitStr>, + alias: Option<Punctuated<AsciiLitStr, Token![,]>>, + firmware: Option<Punctuated<AsciiLitStr, Token![,]>>, + imports_ns: Option<Punctuated<AsciiLitStr, Token![,]>>, + params: Option<Punctuated<Parameter, Token![,]>>, +} + +impl Parse for ModuleInfo { + fn parse(input: ParseStream<'_>) -> Result<Self> { + parse_ordered_fields!( + from input; + type as type_ [required] => input.parse()?, + name [required] => input.parse()?, + authors => { + let list; + bracketed!(list in input); + Punctuated::parse_terminated(&list)? + }, + description => input.parse()?, + license [required] => input.parse()?, + alias => { + let list; + bracketed!(list in input); + Punctuated::parse_terminated(&list)? + }, + firmware => { + let list; + bracketed!(list in input); + Punctuated::parse_terminated(&list)? + }, + imports_ns => { + let list; + bracketed!(list in input); + Punctuated::parse_terminated(&list)? + }, + params => { + let list; + braced!(list in input); + Punctuated::parse_terminated(&list)? + }, + ); + + Ok(ModuleInfo { + type_, + license, + name, + authors, + description, + alias, + firmware, + imports_ns, + params, + }) + } +} - let info = ModuleInfo::parse(&mut it); +pub(crate) fn module(info: ModuleInfo) -> Result<TokenStream> { + let ModuleInfo { + type_, + license, + name, + authors, + description, + alias, + firmware, + imports_ns, + params: _, + } = &info; // Rust does not allow hyphens in identifiers, use underscore instead. - let ident = info.name.replace('-', "_"); + let ident = name.value().replace('-', "_"); let mut modinfo = ModInfoBuilder::new(ident.as_ref()); - if let Some(authors) = &info.authors { + if let Some(authors) = authors { for author in authors { - modinfo.emit("author", author); + modinfo.emit("author", &author.value()); } } - if let Some(description) = &info.description { - modinfo.emit("description", description); + if let Some(description) = description { + modinfo.emit("description", &description.value()); } - modinfo.emit("license", &info.license); - if let Some(aliases) = &info.alias { + modinfo.emit("license", &license.value()); + if let Some(aliases) = alias { for alias in aliases { - modinfo.emit("alias", alias); + modinfo.emit("alias", &alias.value()); } } - if let Some(firmware) = &info.firmware { + if let Some(firmware) = firmware { for fw in firmware { - modinfo.emit("firmware", fw); + modinfo.emit("firmware", &fw.value()); } } - if let Some(imports) = &info.imports_ns { + if let Some(imports) = imports_ns { for ns in imports { - modinfo.emit("import_ns", ns); + modinfo.emit("import_ns", &ns.value()); } } @@ -364,182 +478,181 @@ pub(crate) fn module(ts: TokenStream) -> TokenStream { modinfo.emit_params(&info); - format!( - " - /// The module name. - /// - /// Used by the printing macros, e.g. [`info!`]. - const __LOG_PREFIX: &[u8] = b\"{name}\\0\"; - - // SAFETY: `__this_module` is constructed by the kernel at load time and will not be - // freed until the module is unloaded. - #[cfg(MODULE)] - static THIS_MODULE: ::kernel::ThisModule = unsafe {{ - extern \"C\" {{ - static __this_module: ::kernel::types::Opaque<::kernel::bindings::module>; - }} - - ::kernel::ThisModule::from_ptr(__this_module.get()) - }}; - #[cfg(not(MODULE))] - static THIS_MODULE: ::kernel::ThisModule = unsafe {{ - ::kernel::ThisModule::from_ptr(::core::ptr::null_mut()) - }}; - - /// The `LocalModule` type is the type of the module created by `module!`, - /// `module_pci_driver!`, `module_platform_driver!`, etc. - type LocalModule = {type_}; - - impl ::kernel::ModuleMetadata for {type_} {{ - const NAME: &'static ::kernel::str::CStr = c\"{name}\"; - }} - - // Double nested modules, since then nobody can access the public items inside. - mod __module_init {{ - mod __module_init {{ - use super::super::{type_}; - use pin_init::PinInit; - - /// The \"Rust loadable module\" mark. - // - // This may be best done another way later on, e.g. as a new modinfo - // key or a new section. For the moment, keep it simple. - #[cfg(MODULE)] - #[doc(hidden)] - #[used(compiler)] - static __IS_RUST_MODULE: () = (); - - static mut __MOD: ::core::mem::MaybeUninit<{type_}> = - ::core::mem::MaybeUninit::uninit(); - - // Loadable modules need to export the `{{init,cleanup}}_module` identifiers. - /// # Safety - /// - /// This function must not be called after module initialization, because it may be - /// freed after that completes. - #[cfg(MODULE)] - #[doc(hidden)] - #[no_mangle] - #[link_section = \".init.text\"] - pub unsafe extern \"C\" fn init_module() -> ::kernel::ffi::c_int {{ - // SAFETY: This function is inaccessible to the outside due to the double - // module wrapping it. It is called exactly once by the C side via its - // unique name. - unsafe {{ __init() }} - }} - - #[cfg(MODULE)] - #[doc(hidden)] - #[used(compiler)] - #[link_section = \".init.data\"] - static __UNIQUE_ID___addressable_init_module: unsafe extern \"C\" fn() -> i32 = init_module; - - #[cfg(MODULE)] - #[doc(hidden)] - #[no_mangle] - #[link_section = \".exit.text\"] - pub extern \"C\" fn cleanup_module() {{ - // SAFETY: - // - This function is inaccessible to the outside due to the double - // module wrapping it. It is called exactly once by the C side via its - // unique name, - // - furthermore it is only called after `init_module` has returned `0` - // (which delegates to `__init`). - unsafe {{ __exit() }} - }} - - #[cfg(MODULE)] - #[doc(hidden)] - #[used(compiler)] - #[link_section = \".exit.data\"] - static __UNIQUE_ID___addressable_cleanup_module: extern \"C\" fn() = cleanup_module; - - // Built-in modules are initialized through an initcall pointer - // and the identifiers need to be unique. - #[cfg(not(MODULE))] - #[cfg(not(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS))] - #[doc(hidden)] - #[link_section = \"{initcall_section}\"] - #[used(compiler)] - pub static __{ident}_initcall: extern \"C\" fn() -> - ::kernel::ffi::c_int = __{ident}_init; - - #[cfg(not(MODULE))] - #[cfg(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS)] - ::core::arch::global_asm!( - r#\".section \"{initcall_section}\", \"a\" - __{ident}_initcall: - .long __{ident}_init - . - .previous - \"# + let modinfo_ts = modinfo.ts; + let params_ts = modinfo.param_ts; + + let ident_init = format_ident!("__{ident}_init"); + let ident_exit = format_ident!("__{ident}_exit"); + let ident_initcall = format_ident!("__{ident}_initcall"); + let initcall_section = ".initcall6.init"; + + let global_asm = format!( + r#".section "{initcall_section}", "a" + __{ident}_initcall: + .long __{ident}_init - . + .previous + "# + ); + + let name_cstr = CString::new(name.value()).expect("name contains NUL-terminator"); + + Ok(quote! { + /// The module name. + /// + /// Used by the printing macros, e.g. [`info!`]. + const __LOG_PREFIX: &[u8] = #name_cstr.to_bytes_with_nul(); + + // SAFETY: `__this_module` is constructed by the kernel at load time and will not be + // freed until the module is unloaded. + #[cfg(MODULE)] + static THIS_MODULE: ::kernel::ThisModule = unsafe { + extern "C" { + static __this_module: ::kernel::types::Opaque<::kernel::bindings::module>; + }; + + ::kernel::ThisModule::from_ptr(__this_module.get()) + }; + + #[cfg(not(MODULE))] + static THIS_MODULE: ::kernel::ThisModule = unsafe { + ::kernel::ThisModule::from_ptr(::core::ptr::null_mut()) + }; + + /// The `LocalModule` type is the type of the module created by `module!`, + /// `module_pci_driver!`, `module_platform_driver!`, etc. + type LocalModule = #type_; + + impl ::kernel::ModuleMetadata for #type_ { + const NAME: &'static ::kernel::str::CStr = #name_cstr; + } + + // Double nested modules, since then nobody can access the public items inside. + #[doc(hidden)] + mod __module_init { + mod __module_init { + use pin_init::PinInit; + + /// The "Rust loadable module" mark. + // + // This may be best done another way later on, e.g. as a new modinfo + // key or a new section. For the moment, keep it simple. + #[cfg(MODULE)] + #[used(compiler)] + static __IS_RUST_MODULE: () = (); + + static mut __MOD: ::core::mem::MaybeUninit<super::super::LocalModule> = + ::core::mem::MaybeUninit::uninit(); + + // Loadable modules need to export the `{init,cleanup}_module` identifiers. + /// # Safety + /// + /// This function must not be called after module initialization, because it may be + /// freed after that completes. + #[cfg(MODULE)] + #[no_mangle] + #[link_section = ".init.text"] + pub unsafe extern "C" fn init_module() -> ::kernel::ffi::c_int { + // SAFETY: This function is inaccessible to the outside due to the double + // module wrapping it. It is called exactly once by the C side via its + // unique name. + unsafe { __init() } + } + + #[cfg(MODULE)] + #[used(compiler)] + #[link_section = ".init.data"] + static __UNIQUE_ID___addressable_init_module: unsafe extern "C" fn() -> i32 = + init_module; + + #[cfg(MODULE)] + #[no_mangle] + #[link_section = ".exit.text"] + pub extern "C" fn cleanup_module() { + // SAFETY: + // - This function is inaccessible to the outside due to the double + // module wrapping it. It is called exactly once by the C side via its + // unique name, + // - furthermore it is only called after `init_module` has returned `0` + // (which delegates to `__init`). + unsafe { __exit() } + } + + #[cfg(MODULE)] + #[used(compiler)] + #[link_section = ".exit.data"] + static __UNIQUE_ID___addressable_cleanup_module: extern "C" fn() = cleanup_module; + + // Built-in modules are initialized through an initcall pointer + // and the identifiers need to be unique. + #[cfg(not(MODULE))] + #[cfg(not(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS))] + #[link_section = #initcall_section] + #[used(compiler)] + pub static #ident_initcall: extern "C" fn() -> + ::kernel::ffi::c_int = #ident_init; + + #[cfg(not(MODULE))] + #[cfg(CONFIG_HAVE_ARCH_PREL32_RELOCATIONS)] + ::core::arch::global_asm!(#global_asm); + + #[cfg(not(MODULE))] + #[no_mangle] + pub extern "C" fn #ident_init() -> ::kernel::ffi::c_int { + // SAFETY: This function is inaccessible to the outside due to the double + // module wrapping it. It is called exactly once by the C side via its + // placement above in the initcall section. + unsafe { __init() } + } + + #[cfg(not(MODULE))] + #[no_mangle] + pub extern "C" fn #ident_exit() { + // SAFETY: + // - This function is inaccessible to the outside due to the double + // module wrapping it. It is called exactly once by the C side via its + // unique name, + // - furthermore it is only called after `#ident_init` has + // returned `0` (which delegates to `__init`). + unsafe { __exit() } + } + + /// # Safety + /// + /// This function must only be called once. + unsafe fn __init() -> ::kernel::ffi::c_int { + let initer = <super::super::LocalModule as ::kernel::InPlaceModule>::init( + &super::super::THIS_MODULE ); + // SAFETY: No data race, since `__MOD` can only be accessed by this module + // and there only `__init` and `__exit` access it. These functions are only + // called once and `__exit` cannot be called before or during `__init`. + match unsafe { initer.__pinned_init(__MOD.as_mut_ptr()) } { + Ok(m) => 0, + Err(e) => e.to_errno(), + } + } + + /// # Safety + /// + /// This function must + /// - only be called once, + /// - be called after `__init` has been called and returned `0`. + unsafe fn __exit() { + // SAFETY: No data race, since `__MOD` can only be accessed by this module + // and there only `__init` and `__exit` access it. These functions are only + // called once and `__init` was already called. + unsafe { + // Invokes `drop()` on `__MOD`, which should be used for cleanup. + __MOD.assume_init_drop(); + } + } + + #modinfo_ts + } + } - #[cfg(not(MODULE))] - #[doc(hidden)] - #[no_mangle] - pub extern \"C\" fn __{ident}_init() -> ::kernel::ffi::c_int {{ - // SAFETY: This function is inaccessible to the outside due to the double - // module wrapping it. It is called exactly once by the C side via its - // placement above in the initcall section. - unsafe {{ __init() }} - }} - - #[cfg(not(MODULE))] - #[doc(hidden)] - #[no_mangle] - pub extern \"C\" fn __{ident}_exit() {{ - // SAFETY: - // - This function is inaccessible to the outside due to the double - // module wrapping it. It is called exactly once by the C side via its - // unique name, - // - furthermore it is only called after `__{ident}_init` has - // returned `0` (which delegates to `__init`). - unsafe {{ __exit() }} - }} - - /// # Safety - /// - /// This function must only be called once. - unsafe fn __init() -> ::kernel::ffi::c_int {{ - let initer = - <{type_} as ::kernel::InPlaceModule>::init(&super::super::THIS_MODULE); - // SAFETY: No data race, since `__MOD` can only be accessed by this module - // and there only `__init` and `__exit` access it. These functions are only - // called once and `__exit` cannot be called before or during `__init`. - match unsafe {{ initer.__pinned_init(__MOD.as_mut_ptr()) }} {{ - Ok(m) => 0, - Err(e) => e.to_errno(), - }} - }} - - /// # Safety - /// - /// This function must - /// - only be called once, - /// - be called after `__init` has been called and returned `0`. - unsafe fn __exit() {{ - // SAFETY: No data race, since `__MOD` can only be accessed by this module - // and there only `__init` and `__exit` access it. These functions are only - // called once and `__init` was already called. - unsafe {{ - // Invokes `drop()` on `__MOD`, which should be used for cleanup. - __MOD.assume_init_drop(); - }} - }} - {modinfo} - }} - }} - mod module_parameters {{ - {params} - }} - ", - type_ = info.type_, - name = info.name, - ident = ident, - modinfo = modinfo.buffer, - params = modinfo.param_buffer, - initcall_section = ".initcall6.init" - ) - .parse() - .expect("Error parsing formatted string into token stream.") + mod module_parameters { + #params_ts + } + }) } diff --git a/rust/macros/paste.rs b/rust/macros/paste.rs index cce712d19855..2181e312a7d3 100644 --- a/rust/macros/paste.rs +++ b/rust/macros/paste.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: GPL-2.0 -use proc_macro::{Delimiter, Group, Ident, Spacing, Span, TokenTree}; +use proc_macro2::{Delimiter, Group, Ident, Spacing, Span, TokenTree}; fn concat_helper(tokens: &[TokenTree]) -> Vec<(String, Span)> { let mut tokens = tokens.iter(); diff --git a/rust/macros/quote.rs b/rust/macros/quote.rs deleted file mode 100644 index ddfc21577539..000000000000 --- a/rust/macros/quote.rs +++ /dev/null @@ -1,182 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 OR MIT - -use proc_macro::{TokenStream, TokenTree}; - -pub(crate) trait ToTokens { - fn to_tokens(&self, tokens: &mut TokenStream); -} - -impl<T: ToTokens> ToTokens for Option<T> { - fn to_tokens(&self, tokens: &mut TokenStream) { - if let Some(v) = self { - v.to_tokens(tokens); - } - } -} - -impl ToTokens for proc_macro::Group { - fn to_tokens(&self, tokens: &mut TokenStream) { - tokens.extend([TokenTree::from(self.clone())]); - } -} - -impl ToTokens for proc_macro::Ident { - fn to_tokens(&self, tokens: &mut TokenStream) { - tokens.extend([TokenTree::from(self.clone())]); - } -} - -impl ToTokens for TokenTree { - fn to_tokens(&self, tokens: &mut TokenStream) { - tokens.extend([self.clone()]); - } -} - -impl ToTokens for TokenStream { - fn to_tokens(&self, tokens: &mut TokenStream) { - tokens.extend(self.clone()); - } -} - -/// Converts tokens into [`proc_macro::TokenStream`] and performs variable interpolations with -/// the given span. -/// -/// This is a similar to the -/// [`quote_spanned!`](https://docs.rs/quote/latest/quote/macro.quote_spanned.html) macro from the -/// `quote` crate but provides only just enough functionality needed by the current `macros` crate. -macro_rules! quote_spanned { - ($span:expr => $($tt:tt)*) => {{ - let mut tokens = ::proc_macro::TokenStream::new(); - { - #[allow(unused_variables)] - let span = $span; - quote_spanned!(@proc tokens span $($tt)*); - } - tokens - }}; - (@proc $v:ident $span:ident) => {}; - (@proc $v:ident $span:ident #$id:ident $($tt:tt)*) => { - $crate::quote::ToTokens::to_tokens(&$id, &mut $v); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident #(#$id:ident)* $($tt:tt)*) => { - for token in $id { - $crate::quote::ToTokens::to_tokens(&token, &mut $v); - } - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident ( $($inner:tt)* ) $($tt:tt)*) => { - #[allow(unused_mut)] - let mut tokens = ::proc_macro::TokenStream::new(); - quote_spanned!(@proc tokens $span $($inner)*); - $v.extend([::proc_macro::TokenTree::Group(::proc_macro::Group::new( - ::proc_macro::Delimiter::Parenthesis, - tokens, - ))]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident [ $($inner:tt)* ] $($tt:tt)*) => { - let mut tokens = ::proc_macro::TokenStream::new(); - quote_spanned!(@proc tokens $span $($inner)*); - $v.extend([::proc_macro::TokenTree::Group(::proc_macro::Group::new( - ::proc_macro::Delimiter::Bracket, - tokens, - ))]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident { $($inner:tt)* } $($tt:tt)*) => { - let mut tokens = ::proc_macro::TokenStream::new(); - quote_spanned!(@proc tokens $span $($inner)*); - $v.extend([::proc_macro::TokenTree::Group(::proc_macro::Group::new( - ::proc_macro::Delimiter::Brace, - tokens, - ))]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident :: $($tt:tt)*) => { - $v.extend([::proc_macro::Spacing::Joint, ::proc_macro::Spacing::Alone].map(|spacing| { - ::proc_macro::TokenTree::Punct(::proc_macro::Punct::new(':', spacing)) - })); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident : $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new(':', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident , $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new(',', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident @ $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new('@', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident ! $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new('!', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident ; $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new(';', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident + $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new('+', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident = $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new('=', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident # $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new('#', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident & $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Punct( - ::proc_macro::Punct::new('&', ::proc_macro::Spacing::Alone), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident _ $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Ident( - ::proc_macro::Ident::new("_", $span), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; - (@proc $v:ident $span:ident $id:ident $($tt:tt)*) => { - $v.extend([::proc_macro::TokenTree::Ident( - ::proc_macro::Ident::new(stringify!($id), $span), - )]); - quote_spanned!(@proc $v $span $($tt)*); - }; -} - -/// Converts tokens into [`proc_macro::TokenStream`] and performs variable interpolations with -/// mixed site span ([`Span::mixed_site()`]). -/// -/// This is a similar to the [`quote!`](https://docs.rs/quote/latest/quote/macro.quote.html) macro -/// from the `quote` crate but provides only just enough functionality needed by the current -/// `macros` crate. -/// -/// [`Span::mixed_site()`]: https://doc.rust-lang.org/proc_macro/struct.Span.html#method.mixed_site -macro_rules! quote { - ($($tt:tt)*) => { - quote_spanned!(::proc_macro::Span::mixed_site() => $($tt)*) - } -} diff --git a/rust/macros/vtable.rs b/rust/macros/vtable.rs index ee06044fcd4f..c6510b0c4ea1 100644 --- a/rust/macros/vtable.rs +++ b/rust/macros/vtable.rs @@ -1,96 +1,105 @@ // SPDX-License-Identifier: GPL-2.0 -use proc_macro::{Delimiter, Group, TokenStream, TokenTree}; -use std::collections::HashSet; -use std::fmt::Write; +use std::{ + collections::HashSet, + iter::Extend, // +}; -pub(crate) fn vtable(_attr: TokenStream, ts: TokenStream) -> TokenStream { - let mut tokens: Vec<_> = ts.into_iter().collect(); +use proc_macro2::{ + Ident, + TokenStream, // +}; +use quote::ToTokens; +use syn::{ + parse_quote, + Error, + ImplItem, + Item, + ItemImpl, + ItemTrait, + Result, + TraitItem, // +}; - // Scan for the `trait` or `impl` keyword. - let is_trait = tokens - .iter() - .find_map(|token| match token { - TokenTree::Ident(ident) => match ident.to_string().as_str() { - "trait" => Some(true), - "impl" => Some(false), - _ => None, - }, - _ => None, - }) - .expect("#[vtable] attribute should only be applied to trait or impl block"); +fn handle_trait(mut item: ItemTrait) -> Result<ItemTrait> { + let mut gen_items = Vec::new(); - // Retrieve the main body. The main body should be the last token tree. - let body = match tokens.pop() { - Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group, - _ => panic!("cannot locate main body of trait or impl block"), - }; + gen_items.push(parse_quote! { + /// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable) + /// attribute when implementing this trait. + const USE_VTABLE_ATTR: (); + }); - let mut body_it = body.stream().into_iter(); - let mut functions = Vec::new(); - let mut consts = HashSet::new(); - while let Some(token) = body_it.next() { - match token { - TokenTree::Ident(ident) if ident.to_string() == "fn" => { - let fn_name = match body_it.next() { - Some(TokenTree::Ident(ident)) => ident.to_string(), - // Possibly we've encountered a fn pointer type instead. - _ => continue, - }; - functions.push(fn_name); - } - TokenTree::Ident(ident) if ident.to_string() == "const" => { - let const_name = match body_it.next() { - Some(TokenTree::Ident(ident)) => ident.to_string(), - // Possibly we've encountered an inline const block instead. - _ => continue, - }; - consts.insert(const_name); - } - _ => (), + for item in &item.items { + if let TraitItem::Fn(fn_item) = item { + let name = &fn_item.sig.ident; + let gen_const_name = Ident::new( + &format!("HAS_{}", name.to_string().to_uppercase()), + name.span(), + ); + + // We don't know on the implementation-site whether a method is required or provided + // so we have to generate a const for all methods. + let cfg_attrs = crate::helpers::gather_cfg_attrs(&fn_item.attrs); + let comment = + format!("Indicates if the `{name}` method is overridden by the implementor."); + gen_items.push(parse_quote! { + #(#cfg_attrs)* + #[doc = #comment] + const #gen_const_name: bool = false; + }); } } - let mut const_items; - if is_trait { - const_items = " - /// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable) - /// attribute when implementing this trait. - const USE_VTABLE_ATTR: (); - " - .to_owned(); + item.items.extend(gen_items); + Ok(item) +} - for f in functions { - let gen_const_name = format!("HAS_{}", f.to_uppercase()); - // Skip if it's declared already -- this allows user override. - if consts.contains(&gen_const_name) { - continue; - } - // We don't know on the implementation-site whether a method is required or provided - // so we have to generate a const for all methods. - write!( - const_items, - "/// Indicates if the `{f}` method is overridden by the implementor. - const {gen_const_name}: bool = false;", - ) - .unwrap(); - consts.insert(gen_const_name); +fn handle_impl(mut item: ItemImpl) -> Result<ItemImpl> { + let mut gen_items = Vec::new(); + let mut defined_consts = HashSet::new(); + + // Iterate over all user-defined constants to gather any possible explicit overrides. + for item in &item.items { + if let ImplItem::Const(const_item) = item { + defined_consts.insert(const_item.ident.clone()); } - } else { - const_items = "const USE_VTABLE_ATTR: () = ();".to_owned(); + } + + gen_items.push(parse_quote! { + const USE_VTABLE_ATTR: () = (); + }); - for f in functions { - let gen_const_name = format!("HAS_{}", f.to_uppercase()); - if consts.contains(&gen_const_name) { + for item in &item.items { + if let ImplItem::Fn(fn_item) = item { + let name = &fn_item.sig.ident; + let gen_const_name = Ident::new( + &format!("HAS_{}", name.to_string().to_uppercase()), + name.span(), + ); + // Skip if it's declared already -- this allows user override. + if defined_consts.contains(&gen_const_name) { continue; } - write!(const_items, "const {gen_const_name}: bool = true;").unwrap(); + let cfg_attrs = crate::helpers::gather_cfg_attrs(&fn_item.attrs); + gen_items.push(parse_quote! { + #(#cfg_attrs)* + const #gen_const_name: bool = true; + }); } } - let new_body = vec![const_items.parse().unwrap(), body.stream()] - .into_iter() - .collect(); - tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, new_body))); - tokens.into_iter().collect() + item.items.extend(gen_items); + Ok(item) +} + +pub(crate) fn vtable(input: Item) -> Result<TokenStream> { + match input { + Item::Trait(item) => Ok(handle_trait(item)?.into_token_stream()), + Item::Impl(item) => Ok(handle_impl(item)?.into_token_stream()), + _ => Err(Error::new_spanned( + input, + "`#[vtable]` attribute should only be applied to trait or impl block", + ))?, + } } diff --git a/rust/pin-init/README.md b/rust/pin-init/README.md index 74bbb4e0a2f7..6cee6ab1eb57 100644 --- a/rust/pin-init/README.md +++ b/rust/pin-init/README.md @@ -135,7 +135,7 @@ struct DriverData { impl DriverData { fn new() -> impl PinInit<Self, Error> { - try_pin_init!(Self { + pin_init!(Self { status <- CMutex::new(0), buffer: Box::init(pin_init::init_zeroed())?, }? Error) diff --git a/rust/pin-init/examples/linked_list.rs b/rust/pin-init/examples/linked_list.rs index f9e117c7dfe0..8445a5890cb7 100644 --- a/rust/pin-init/examples/linked_list.rs +++ b/rust/pin-init/examples/linked_list.rs @@ -6,7 +6,6 @@ use core::{ cell::Cell, - convert::Infallible, marker::PhantomPinned, pin::Pin, ptr::{self, NonNull}, @@ -31,31 +30,31 @@ pub struct ListHead { impl ListHead { #[inline] - pub fn new() -> impl PinInit<Self, Infallible> { - try_pin_init!(&this in Self { + pub fn new() -> impl PinInit<Self> { + pin_init!(&this in Self { next: unsafe { Link::new_unchecked(this) }, prev: unsafe { Link::new_unchecked(this) }, pin: PhantomPinned, - }? Infallible) + }) } #[inline] #[allow(dead_code)] - pub fn insert_next(list: &ListHead) -> impl PinInit<Self, Infallible> + '_ { - try_pin_init!(&this in Self { + pub fn insert_next(list: &ListHead) -> impl PinInit<Self> + '_ { + pin_init!(&this in Self { prev: list.next.prev().replace(unsafe { Link::new_unchecked(this)}), next: list.next.replace(unsafe { Link::new_unchecked(this)}), pin: PhantomPinned, - }? Infallible) + }) } #[inline] - pub fn insert_prev(list: &ListHead) -> impl PinInit<Self, Infallible> + '_ { - try_pin_init!(&this in Self { + pub fn insert_prev(list: &ListHead) -> impl PinInit<Self> + '_ { + pin_init!(&this in Self { next: list.prev.next().replace(unsafe { Link::new_unchecked(this)}), prev: list.prev.replace(unsafe { Link::new_unchecked(this)}), pin: PhantomPinned, - }? Infallible) + }) } #[inline] diff --git a/rust/pin-init/examples/pthread_mutex.rs b/rust/pin-init/examples/pthread_mutex.rs index 49b004c8c137..4e082ec7d5de 100644 --- a/rust/pin-init/examples/pthread_mutex.rs +++ b/rust/pin-init/examples/pthread_mutex.rs @@ -98,11 +98,11 @@ mod pthread_mtx { // SAFETY: mutex has been initialized unsafe { pin_init_from_closure(init) } } - try_pin_init!(Self { - data: UnsafeCell::new(data), - raw <- init_raw(), - pin: PhantomPinned, - }? Error) + pin_init!(Self { + data: UnsafeCell::new(data), + raw <- init_raw(), + pin: PhantomPinned, + }? Error) } #[allow(dead_code)] diff --git a/rust/pin-init/internal/src/diagnostics.rs b/rust/pin-init/internal/src/diagnostics.rs new file mode 100644 index 000000000000..3bdb477c2f2b --- /dev/null +++ b/rust/pin-init/internal/src/diagnostics.rs @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +use std::fmt::Display; + +use proc_macro2::TokenStream; +use syn::{spanned::Spanned, Error}; + +pub(crate) struct DiagCtxt(TokenStream); +pub(crate) struct ErrorGuaranteed(()); + +impl DiagCtxt { + pub(crate) fn error(&mut self, span: impl Spanned, msg: impl Display) -> ErrorGuaranteed { + let error = Error::new(span.span(), msg); + self.0.extend(error.into_compile_error()); + ErrorGuaranteed(()) + } + + pub(crate) fn with( + fun: impl FnOnce(&mut DiagCtxt) -> Result<TokenStream, ErrorGuaranteed>, + ) -> TokenStream { + let mut dcx = Self(TokenStream::new()); + match fun(&mut dcx) { + Ok(mut stream) => { + stream.extend(dcx.0); + stream + } + Err(ErrorGuaranteed(())) => dcx.0, + } + } +} diff --git a/rust/pin-init/internal/src/helpers.rs b/rust/pin-init/internal/src/helpers.rs deleted file mode 100644 index 236f989a50f2..000000000000 --- a/rust/pin-init/internal/src/helpers.rs +++ /dev/null @@ -1,152 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 OR MIT - -#[cfg(not(kernel))] -use proc_macro2 as proc_macro; - -use proc_macro::{TokenStream, TokenTree}; - -/// Parsed generics. -/// -/// See the field documentation for an explanation what each of the fields represents. -/// -/// # Examples -/// -/// ```rust,ignore -/// # let input = todo!(); -/// let (Generics { decl_generics, impl_generics, ty_generics }, rest) = parse_generics(input); -/// quote! { -/// struct Foo<$($decl_generics)*> { -/// // ... -/// } -/// -/// impl<$impl_generics> Foo<$ty_generics> { -/// fn foo() { -/// // ... -/// } -/// } -/// } -/// ``` -pub(crate) struct Generics { - /// The generics with bounds and default values (e.g. `T: Clone, const N: usize = 0`). - /// - /// Use this on type definitions e.g. `struct Foo<$decl_generics> ...` (or `union`/`enum`). - pub(crate) decl_generics: Vec<TokenTree>, - /// The generics with bounds (e.g. `T: Clone, const N: usize`). - /// - /// Use this on `impl` blocks e.g. `impl<$impl_generics> Trait for ...`. - pub(crate) impl_generics: Vec<TokenTree>, - /// The generics without bounds and without default values (e.g. `T, N`). - /// - /// Use this when you use the type that is declared with these generics e.g. - /// `Foo<$ty_generics>`. - pub(crate) ty_generics: Vec<TokenTree>, -} - -/// Parses the given `TokenStream` into `Generics` and the rest. -/// -/// The generics are not present in the rest, but a where clause might remain. -pub(crate) fn parse_generics(input: TokenStream) -> (Generics, Vec<TokenTree>) { - // The generics with bounds and default values. - let mut decl_generics = vec![]; - // `impl_generics`, the declared generics with their bounds. - let mut impl_generics = vec![]; - // Only the names of the generics, without any bounds. - let mut ty_generics = vec![]; - // Tokens not related to the generics e.g. the `where` token and definition. - let mut rest = vec![]; - // The current level of `<`. - let mut nesting = 0; - let mut toks = input.into_iter(); - // If we are at the beginning of a generic parameter. - let mut at_start = true; - let mut skip_until_comma = false; - while let Some(tt) = toks.next() { - if nesting == 1 && matches!(&tt, TokenTree::Punct(p) if p.as_char() == '>') { - // Found the end of the generics. - break; - } else if nesting >= 1 { - decl_generics.push(tt.clone()); - } - match tt.clone() { - TokenTree::Punct(p) if p.as_char() == '<' => { - if nesting >= 1 && !skip_until_comma { - // This is inside of the generics and part of some bound. - impl_generics.push(tt); - } - nesting += 1; - } - TokenTree::Punct(p) if p.as_char() == '>' => { - // This is a parsing error, so we just end it here. - if nesting == 0 { - break; - } else { - nesting -= 1; - if nesting >= 1 && !skip_until_comma { - // We are still inside of the generics and part of some bound. - impl_generics.push(tt); - } - } - } - TokenTree::Punct(p) if skip_until_comma && p.as_char() == ',' => { - if nesting == 1 { - impl_generics.push(tt.clone()); - impl_generics.push(tt); - skip_until_comma = false; - } - } - _ if !skip_until_comma => { - match nesting { - // If we haven't entered the generics yet, we still want to keep these tokens. - 0 => rest.push(tt), - 1 => { - // Here depending on the token, it might be a generic variable name. - match tt.clone() { - TokenTree::Ident(i) if at_start && i.to_string() == "const" => { - let Some(name) = toks.next() else { - // Parsing error. - break; - }; - impl_generics.push(tt); - impl_generics.push(name.clone()); - ty_generics.push(name.clone()); - decl_generics.push(name); - at_start = false; - } - TokenTree::Ident(_) if at_start => { - impl_generics.push(tt.clone()); - ty_generics.push(tt); - at_start = false; - } - TokenTree::Punct(p) if p.as_char() == ',' => { - impl_generics.push(tt.clone()); - ty_generics.push(tt); - at_start = true; - } - // Lifetimes begin with `'`. - TokenTree::Punct(p) if p.as_char() == '\'' && at_start => { - impl_generics.push(tt.clone()); - ty_generics.push(tt); - } - // Generics can have default values, we skip these. - TokenTree::Punct(p) if p.as_char() == '=' => { - skip_until_comma = true; - } - _ => impl_generics.push(tt), - } - } - _ => impl_generics.push(tt), - } - } - _ => {} - } - } - rest.extend(toks); - ( - Generics { - impl_generics, - decl_generics, - ty_generics, - }, - rest, - ) -} diff --git a/rust/pin-init/internal/src/init.rs b/rust/pin-init/internal/src/init.rs new file mode 100644 index 000000000000..42936f915a07 --- /dev/null +++ b/rust/pin-init/internal/src/init.rs @@ -0,0 +1,548 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote, quote_spanned}; +use syn::{ + braced, + parse::{End, Parse}, + parse_quote, + punctuated::Punctuated, + spanned::Spanned, + token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type, +}; + +use crate::diagnostics::{DiagCtxt, ErrorGuaranteed}; + +pub(crate) struct Initializer { + attrs: Vec<InitializerAttribute>, + this: Option<This>, + path: Path, + brace_token: token::Brace, + fields: Punctuated<InitializerField, Token![,]>, + rest: Option<(Token![..], Expr)>, + error: Option<(Token![?], Type)>, +} + +struct This { + _and_token: Token![&], + ident: Ident, + _in_token: Token![in], +} + +struct InitializerField { + attrs: Vec<Attribute>, + kind: InitializerKind, +} + +enum InitializerKind { + Value { + ident: Ident, + value: Option<(Token![:], Expr)>, + }, + Init { + ident: Ident, + _left_arrow_token: Token![<-], + value: Expr, + }, + Code { + _underscore_token: Token![_], + _colon_token: Token![:], + block: Block, + }, +} + +impl InitializerKind { + fn ident(&self) -> Option<&Ident> { + match self { + Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident), + Self::Code { .. } => None, + } + } +} + +enum InitializerAttribute { + DefaultError(DefaultErrorAttribute), + DisableInitializedFieldAccess, +} + +struct DefaultErrorAttribute { + ty: Box<Type>, +} + +pub(crate) fn expand( + Initializer { + attrs, + this, + path, + brace_token, + fields, + rest, + error, + }: Initializer, + default_error: Option<&'static str>, + pinned: bool, + dcx: &mut DiagCtxt, +) -> Result<TokenStream, ErrorGuaranteed> { + let error = error.map_or_else( + || { + if let Some(default_error) = attrs.iter().fold(None, |acc, attr| { + if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr { + Some(ty.clone()) + } else { + acc + } + }) { + default_error + } else if let Some(default_error) = default_error { + syn::parse_str(default_error).unwrap() + } else { + dcx.error(brace_token.span.close(), "expected `? <type>` after `}`"); + parse_quote!(::core::convert::Infallible) + } + }, + |(_, err)| Box::new(err), + ); + let slot = format_ident!("slot"); + let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned { + ( + format_ident!("HasPinData"), + format_ident!("PinData"), + format_ident!("__pin_data"), + format_ident!("pin_init_from_closure"), + ) + } else { + ( + format_ident!("HasInitData"), + format_ident!("InitData"), + format_ident!("__init_data"), + format_ident!("init_from_closure"), + ) + }; + let init_kind = get_init_kind(rest, dcx); + let zeroable_check = match init_kind { + InitKind::Normal => quote!(), + InitKind::Zeroing => quote! { + // The user specified `..Zeroable::zeroed()` at the end of the list of fields. + // Therefore we check if the struct implements `Zeroable` and then zero the memory. + // This allows us to also remove the check that all fields are present (since we + // already set the memory to zero and that is a valid bit pattern). + fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T) + where T: ::pin_init::Zeroable + {} + // Ensure that the struct is indeed `Zeroable`. + assert_zeroable(#slot); + // SAFETY: The type implements `Zeroable` by the check above. + unsafe { ::core::ptr::write_bytes(#slot, 0, 1) }; + }, + }; + let this = match this { + None => quote!(), + Some(This { ident, .. }) => quote! { + // Create the `this` so it can be referenced by the user inside of the + // expressions creating the individual fields. + let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) }; + }, + }; + // `mixed_site` ensures that the data is not accessible to the user-controlled code. + let data = Ident::new("__data", Span::mixed_site()); + let init_fields = init_fields( + &fields, + pinned, + !attrs + .iter() + .any(|attr| matches!(attr, InitializerAttribute::DisableInitializedFieldAccess)), + &data, + &slot, + ); + let field_check = make_field_check(&fields, init_kind, &path); + Ok(quote! {{ + // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return + // type and shadow it later when we insert the arbitrary user code. That way there will be + // no possibility of returning without `unsafe`. + struct __InitOk; + + // Get the data about fields from the supplied type. + // SAFETY: TODO + let #data = unsafe { + use ::pin_init::__internal::#has_data_trait; + // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit + // generics (which need to be present with that syntax). + #path::#get_data() + }; + // Ensure that `#data` really is of type `#data` and help with type inference: + let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>( + #data, + move |slot| { + { + // Shadow the structure so it cannot be used to return early. + struct __InitOk; + #zeroable_check + #this + #init_fields + #field_check + } + Ok(__InitOk) + } + ); + let init = move |slot| -> ::core::result::Result<(), #error> { + init(slot).map(|__InitOk| ()) + }; + // SAFETY: TODO + let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) }; + init + }}) +} + +enum InitKind { + Normal, + Zeroing, +} + +fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind { + let Some((dotdot, expr)) = rest else { + return InitKind::Normal; + }; + match &expr { + Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func { + Expr::Path(ExprPath { + attrs, + qself: None, + path: + Path { + leading_colon: None, + segments, + }, + }) if attrs.is_empty() + && segments.len() == 2 + && segments[0].ident == "Zeroable" + && segments[0].arguments.is_none() + && segments[1].ident == "init_zeroed" + && segments[1].arguments.is_none() => + { + return InitKind::Zeroing; + } + _ => {} + }, + _ => {} + } + dcx.error( + dotdot.span().join(expr.span()).unwrap_or(expr.span()), + "expected nothing or `..Zeroable::init_zeroed()`.", + ); + InitKind::Normal +} + +/// Generate the code that initializes the fields of the struct using the initializers in `field`. +fn init_fields( + fields: &Punctuated<InitializerField, Token![,]>, + pinned: bool, + generate_initialized_accessors: bool, + data: &Ident, + slot: &Ident, +) -> TokenStream { + let mut guards = vec![]; + let mut guard_attrs = vec![]; + let mut res = TokenStream::new(); + for InitializerField { attrs, kind } in fields { + let cfgs = { + let mut cfgs = attrs.clone(); + cfgs.retain(|attr| attr.path().is_ident("cfg")); + cfgs + }; + let init = match kind { + InitializerKind::Value { ident, value } => { + let mut value_ident = ident.clone(); + let value_prep = value.as_ref().map(|value| &value.1).map(|value| { + // Setting the span of `value_ident` to `value`'s span improves error messages + // when the type of `value` is wrong. + value_ident.set_span(value.span()); + quote!(let #value_ident = #value;) + }); + // Again span for better diagnostics + let write = quote_spanned!(ident.span()=> ::core::ptr::write); + let accessor = if pinned { + let project_ident = format_ident!("__project_{ident}"); + quote! { + // SAFETY: TODO + unsafe { #data.#project_ident(&mut (*#slot).#ident) } + } + } else { + quote! { + // SAFETY: TODO + unsafe { &mut (*#slot).#ident } + } + }; + let accessor = generate_initialized_accessors.then(|| { + quote! { + #(#cfgs)* + #[allow(unused_variables)] + let #ident = #accessor; + } + }); + quote! { + #(#attrs)* + { + #value_prep + // SAFETY: TODO + unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) }; + } + #accessor + } + } + InitializerKind::Init { ident, value, .. } => { + // Again span for better diagnostics + let init = format_ident!("init", span = value.span()); + let (value_init, accessor) = if pinned { + let project_ident = format_ident!("__project_{ident}"); + ( + quote! { + // SAFETY: + // - `slot` is valid, because we are inside of an initializer closure, we + // return when an error/panic occurs. + // - We also use `#data` to require the correct trait (`Init` or `PinInit`) + // for `#ident`. + unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? }; + }, + quote! { + // SAFETY: TODO + unsafe { #data.#project_ident(&mut (*#slot).#ident) } + }, + ) + } else { + ( + quote! { + // SAFETY: `slot` is valid, because we are inside of an initializer + // closure, we return when an error/panic occurs. + unsafe { + ::pin_init::Init::__init( + #init, + ::core::ptr::addr_of_mut!((*#slot).#ident), + )? + }; + }, + quote! { + // SAFETY: TODO + unsafe { &mut (*#slot).#ident } + }, + ) + }; + let accessor = generate_initialized_accessors.then(|| { + quote! { + #(#cfgs)* + #[allow(unused_variables)] + let #ident = #accessor; + } + }); + quote! { + #(#attrs)* + { + let #init = #value; + #value_init + } + #accessor + } + } + InitializerKind::Code { block: value, .. } => quote! { + #(#attrs)* + #[allow(unused_braces)] + #value + }, + }; + res.extend(init); + if let Some(ident) = kind.ident() { + // `mixed_site` ensures that the guard is not accessible to the user-controlled code. + let guard = format_ident!("__{ident}_guard", span = Span::mixed_site()); + res.extend(quote! { + #(#cfgs)* + // Create the drop guard: + // + // We rely on macro hygiene to make it impossible for users to access this local + // variable. + // SAFETY: We forget the guard later when initialization has succeeded. + let #guard = unsafe { + ::pin_init::__internal::DropGuard::new( + ::core::ptr::addr_of_mut!((*slot).#ident) + ) + }; + }); + guards.push(guard); + guard_attrs.push(cfgs); + } + } + quote! { + #res + // If execution reaches this point, all fields have been initialized. Therefore we can now + // dismiss the guards by forgetting them. + #( + #(#guard_attrs)* + ::core::mem::forget(#guards); + )* + } +} + +/// Generate the check for ensuring that every field has been initialized. +fn make_field_check( + fields: &Punctuated<InitializerField, Token![,]>, + init_kind: InitKind, + path: &Path, +) -> TokenStream { + let field_attrs = fields + .iter() + .filter_map(|f| f.kind.ident().map(|_| &f.attrs)); + let field_name = fields.iter().filter_map(|f| f.kind.ident()); + match init_kind { + InitKind::Normal => quote! { + // We use unreachable code to ensure that all fields have been mentioned exactly once, + // this struct initializer will still be type-checked and complain with a very natural + // error message if a field is forgotten/mentioned more than once. + #[allow(unreachable_code, clippy::diverging_sub_expression)] + // SAFETY: this code is never executed. + let _ = || unsafe { + ::core::ptr::write(slot, #path { + #( + #(#field_attrs)* + #field_name: ::core::panic!(), + )* + }) + }; + }, + InitKind::Zeroing => quote! { + // We use unreachable code to ensure that all fields have been mentioned at most once. + // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will + // be zeroed. This struct initializer will still be type-checked and complain with a + // very natural error message if a field is mentioned more than once, or doesn't exist. + #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)] + // SAFETY: this code is never executed. + let _ = || unsafe { + ::core::ptr::write(slot, #path { + #( + #(#field_attrs)* + #field_name: ::core::panic!(), + )* + ..::core::mem::zeroed() + }) + }; + }, + } +} + +impl Parse for Initializer { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { + let attrs = input.call(Attribute::parse_outer)?; + let this = input.peek(Token![&]).then(|| input.parse()).transpose()?; + let path = input.parse()?; + let content; + let brace_token = braced!(content in input); + let mut fields = Punctuated::new(); + loop { + let lh = content.lookahead1(); + if lh.peek(End) || lh.peek(Token![..]) { + break; + } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) { + fields.push_value(content.parse()?); + let lh = content.lookahead1(); + if lh.peek(End) { + break; + } else if lh.peek(Token![,]) { + fields.push_punct(content.parse()?); + } else { + return Err(lh.error()); + } + } else { + return Err(lh.error()); + } + } + let rest = content + .peek(Token![..]) + .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?))) + .transpose()?; + let error = input + .peek(Token![?]) + .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?))) + .transpose()?; + let attrs = attrs + .into_iter() + .map(|a| { + if a.path().is_ident("default_error") { + a.parse_args::<DefaultErrorAttribute>() + .map(InitializerAttribute::DefaultError) + } else if a.path().is_ident("disable_initialized_field_access") { + a.meta + .require_path_only() + .map(|_| InitializerAttribute::DisableInitializedFieldAccess) + } else { + Err(syn::Error::new_spanned(a, "unknown initializer attribute")) + } + }) + .collect::<Result<Vec<_>, _>>()?; + Ok(Self { + attrs, + this, + path, + brace_token, + fields, + rest, + error, + }) + } +} + +impl Parse for DefaultErrorAttribute { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { + Ok(Self { ty: input.parse()? }) + } +} + +impl Parse for This { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { + Ok(Self { + _and_token: input.parse()?, + ident: input.parse()?, + _in_token: input.parse()?, + }) + } +} + +impl Parse for InitializerField { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { + let attrs = input.call(Attribute::parse_outer)?; + Ok(Self { + attrs, + kind: input.parse()?, + }) + } +} + +impl Parse for InitializerKind { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { + let lh = input.lookahead1(); + if lh.peek(Token![_]) { + Ok(Self::Code { + _underscore_token: input.parse()?, + _colon_token: input.parse()?, + block: input.parse()?, + }) + } else if lh.peek(Ident) { + let ident = input.parse()?; + let lh = input.lookahead1(); + if lh.peek(Token![<-]) { + Ok(Self::Init { + ident, + _left_arrow_token: input.parse()?, + value: input.parse()?, + }) + } else if lh.peek(Token![:]) { + Ok(Self::Value { + ident, + value: Some((input.parse()?, input.parse()?)), + }) + } else if lh.peek(Token![,]) || lh.peek(End) { + Ok(Self::Value { ident, value: None }) + } else { + Err(lh.error()) + } + } else { + Err(lh.error()) + } + } +} diff --git a/rust/pin-init/internal/src/lib.rs b/rust/pin-init/internal/src/lib.rs index 297b0129a5bf..08372c8f65f0 100644 --- a/rust/pin-init/internal/src/lib.rs +++ b/rust/pin-init/internal/src/lib.rs @@ -7,48 +7,54 @@ //! `pin-init` proc macros. #![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))] -// Allow `.into()` to convert -// - `proc_macro2::TokenStream` into `proc_macro::TokenStream` in the user-space version. -// - `proc_macro::TokenStream` into `proc_macro::TokenStream` in the kernel version. -// Clippy warns on this conversion, but it's required by the user-space version. -// -// Remove once we have `proc_macro2` in the kernel. -#![allow(clippy::useless_conversion)] // Documentation is done in the pin-init crate instead. #![allow(missing_docs)] use proc_macro::TokenStream; +use syn::parse_macro_input; -#[cfg(kernel)] -#[path = "../../../macros/quote.rs"] -#[macro_use] -#[cfg_attr(not(kernel), rustfmt::skip)] -mod quote; -#[cfg(not(kernel))] -#[macro_use] -extern crate quote; +use crate::diagnostics::DiagCtxt; -mod helpers; +mod diagnostics; +mod init; mod pin_data; mod pinned_drop; mod zeroable; #[proc_macro_attribute] -pub fn pin_data(inner: TokenStream, item: TokenStream) -> TokenStream { - pin_data::pin_data(inner.into(), item.into()).into() +pub fn pin_data(args: TokenStream, input: TokenStream) -> TokenStream { + let args = parse_macro_input!(args); + let input = parse_macro_input!(input); + DiagCtxt::with(|dcx| pin_data::pin_data(args, input, dcx)).into() } #[proc_macro_attribute] pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream { - pinned_drop::pinned_drop(args.into(), input.into()).into() + let args = parse_macro_input!(args); + let input = parse_macro_input!(input); + DiagCtxt::with(|dcx| pinned_drop::pinned_drop(args, input, dcx)).into() } #[proc_macro_derive(Zeroable)] pub fn derive_zeroable(input: TokenStream) -> TokenStream { - zeroable::derive(input.into()).into() + let input = parse_macro_input!(input); + DiagCtxt::with(|dcx| zeroable::derive(input, dcx)).into() } #[proc_macro_derive(MaybeZeroable)] pub fn maybe_derive_zeroable(input: TokenStream) -> TokenStream { - zeroable::maybe_derive(input.into()).into() + let input = parse_macro_input!(input); + DiagCtxt::with(|dcx| zeroable::maybe_derive(input, dcx)).into() +} +#[proc_macro] +pub fn init(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input); + DiagCtxt::with(|dcx| init::expand(input, Some("::core::convert::Infallible"), false, dcx)) + .into() +} + +#[proc_macro] +pub fn pin_init(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input); + DiagCtxt::with(|dcx| init::expand(input, Some("::core::convert::Infallible"), true, dcx)).into() } diff --git a/rust/pin-init/internal/src/pin_data.rs b/rust/pin-init/internal/src/pin_data.rs index 87d4a7eb1d35..7d871236b49c 100644 --- a/rust/pin-init/internal/src/pin_data.rs +++ b/rust/pin-init/internal/src/pin_data.rs @@ -1,132 +1,513 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT -#[cfg(not(kernel))] -use proc_macro2 as proc_macro; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{ + parse::{End, Nothing, Parse}, + parse_quote, parse_quote_spanned, + spanned::Spanned, + visit_mut::VisitMut, + Field, Generics, Ident, Item, PathSegment, Type, TypePath, Visibility, WhereClause, +}; -use crate::helpers::{parse_generics, Generics}; -use proc_macro::{Group, Punct, Spacing, TokenStream, TokenTree}; +use crate::diagnostics::{DiagCtxt, ErrorGuaranteed}; -pub(crate) fn pin_data(args: TokenStream, input: TokenStream) -> TokenStream { - // This proc-macro only does some pre-parsing and then delegates the actual parsing to - // `pin_init::__pin_data!`. +pub(crate) mod kw { + syn::custom_keyword!(PinnedDrop); +} + +pub(crate) enum Args { + Nothing(Nothing), + #[allow(dead_code)] + PinnedDrop(kw::PinnedDrop), +} + +impl Parse for Args { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { + let lh = input.lookahead1(); + if lh.peek(End) { + input.parse().map(Self::Nothing) + } else if lh.peek(kw::PinnedDrop) { + input.parse().map(Self::PinnedDrop) + } else { + Err(lh.error()) + } + } +} + +pub(crate) fn pin_data( + args: Args, + input: Item, + dcx: &mut DiagCtxt, +) -> Result<TokenStream, ErrorGuaranteed> { + let mut struct_ = match input { + Item::Struct(struct_) => struct_, + Item::Enum(enum_) => { + return Err(dcx.error( + enum_.enum_token, + "`#[pin_data]` only supports structs for now", + )); + } + Item::Union(union) => { + return Err(dcx.error( + union.union_token, + "`#[pin_data]` only supports structs for now", + )); + } + rest => { + return Err(dcx.error( + rest, + "`#[pin_data]` can only be applied to struct, enum and union definitions", + )); + } + }; + + // The generics might contain the `Self` type. Since this macro will define a new type with the + // same generics and bounds, this poses a problem: `Self` will refer to the new type as opposed + // to this struct definition. Therefore we have to replace `Self` with the concrete name. + let mut replacer = { + let name = &struct_.ident; + let (_, ty_generics, _) = struct_.generics.split_for_impl(); + SelfReplacer(parse_quote!(#name #ty_generics)) + }; + replacer.visit_generics_mut(&mut struct_.generics); + replacer.visit_fields_mut(&mut struct_.fields); + + let fields: Vec<(bool, &Field)> = struct_ + .fields + .iter_mut() + .map(|field| { + let len = field.attrs.len(); + field.attrs.retain(|a| !a.path().is_ident("pin")); + (len != field.attrs.len(), &*field) + }) + .collect(); + + for (pinned, field) in &fields { + if !pinned && is_phantom_pinned(&field.ty) { + dcx.error( + field, + format!( + "The field `{}` of type `PhantomPinned` only has an effect \ + if it has the `#[pin]` attribute", + field.ident.as_ref().unwrap(), + ), + ); + } + } + + let unpin_impl = generate_unpin_impl(&struct_.ident, &struct_.generics, &fields); + let drop_impl = generate_drop_impl(&struct_.ident, &struct_.generics, args); + let projections = + generate_projections(&struct_.vis, &struct_.ident, &struct_.generics, &fields); + let the_pin_data = + generate_the_pin_data(&struct_.vis, &struct_.ident, &struct_.generics, &fields); + + Ok(quote! { + #struct_ + #projections + // We put the rest into this const item, because it then will not be accessible to anything + // outside. + const _: () = { + #the_pin_data + #unpin_impl + #drop_impl + }; + }) +} + +fn is_phantom_pinned(ty: &Type) -> bool { + match ty { + Type::Path(TypePath { qself: None, path }) => { + // Cannot possibly refer to `PhantomPinned` (except alias, but that's on the user). + if path.segments.len() > 3 { + return false; + } + // If there is a `::`, then the path needs to be `::core::marker::PhantomPinned` or + // `::std::marker::PhantomPinned`. + if path.leading_colon.is_some() && path.segments.len() != 3 { + return false; + } + let expected: Vec<&[&str]> = vec![&["PhantomPinned"], &["marker"], &["core", "std"]]; + for (actual, expected) in path.segments.iter().rev().zip(expected) { + if !actual.arguments.is_empty() || expected.iter().all(|e| actual.ident != e) { + return false; + } + } + true + } + _ => false, + } +} +fn generate_unpin_impl( + ident: &Ident, + generics: &Generics, + fields: &[(bool, &Field)], +) -> TokenStream { + let (_, ty_generics, _) = generics.split_for_impl(); + let mut generics_with_pin_lt = generics.clone(); + generics_with_pin_lt.params.insert(0, parse_quote!('__pin)); + generics_with_pin_lt.make_where_clause(); let ( - Generics { - impl_generics, - decl_generics, - ty_generics, - }, - rest, - ) = parse_generics(input); - // The struct definition might contain the `Self` type. Since `__pin_data!` will define a new - // type with the same generics and bounds, this poses a problem, since `Self` will refer to the - // new type as opposed to this struct definition. Therefore we have to replace `Self` with the - // concrete name. - - // Errors that occur when replacing `Self` with `struct_name`. - let mut errs = TokenStream::new(); - // The name of the struct with ty_generics. - let struct_name = rest - .iter() - .skip_while(|tt| !matches!(tt, TokenTree::Ident(i) if i.to_string() == "struct")) - .nth(1) - .and_then(|tt| match tt { - TokenTree::Ident(_) => { - let tt = tt.clone(); - let mut res = vec![tt]; - if !ty_generics.is_empty() { - // We add this, so it is maximally compatible with e.g. `Self::CONST` which - // will be replaced by `StructName::<$generics>::CONST`. - res.push(TokenTree::Punct(Punct::new(':', Spacing::Joint))); - res.push(TokenTree::Punct(Punct::new(':', Spacing::Alone))); - res.push(TokenTree::Punct(Punct::new('<', Spacing::Alone))); - res.extend(ty_generics.iter().cloned()); - res.push(TokenTree::Punct(Punct::new('>', Spacing::Alone))); + impl_generics_with_pin_lt, + ty_generics_with_pin_lt, + Some(WhereClause { + where_token, + predicates, + }), + ) = generics_with_pin_lt.split_for_impl() + else { + unreachable!() + }; + let pinned_fields = fields.iter().filter_map(|(b, f)| b.then_some(f)); + quote! { + // This struct will be used for the unpin analysis. It is needed, because only structurally + // pinned fields are relevant whether the struct should implement `Unpin`. + #[allow(dead_code)] // The fields below are never used. + struct __Unpin #generics_with_pin_lt + #where_token + #predicates + { + __phantom_pin: ::core::marker::PhantomData<fn(&'__pin ()) -> &'__pin ()>, + __phantom: ::core::marker::PhantomData< + fn(#ident #ty_generics) -> #ident #ty_generics + >, + #(#pinned_fields),* + } + + #[doc(hidden)] + impl #impl_generics_with_pin_lt ::core::marker::Unpin for #ident #ty_generics + #where_token + __Unpin #ty_generics_with_pin_lt: ::core::marker::Unpin, + #predicates + {} + } +} + +fn generate_drop_impl(ident: &Ident, generics: &Generics, args: Args) -> TokenStream { + let (impl_generics, ty_generics, whr) = generics.split_for_impl(); + let has_pinned_drop = matches!(args, Args::PinnedDrop(_)); + // We need to disallow normal `Drop` implementation, the exact behavior depends on whether + // `PinnedDrop` was specified in `args`. + if has_pinned_drop { + // When `PinnedDrop` was specified we just implement `Drop` and delegate. + quote! { + impl #impl_generics ::core::ops::Drop for #ident #ty_generics + #whr + { + fn drop(&mut self) { + // SAFETY: Since this is a destructor, `self` will not move after this function + // terminates, since it is inaccessible. + let pinned = unsafe { ::core::pin::Pin::new_unchecked(self) }; + // SAFETY: Since this is a drop function, we can create this token to call the + // pinned destructor of this type. + let token = unsafe { ::pin_init::__internal::OnlyCallFromDrop::new() }; + ::pin_init::PinnedDrop::drop(pinned, token); } - Some(res) } - _ => None, - }) - .unwrap_or_else(|| { - // If we did not find the name of the struct then we will use `Self` as the replacement - // and add a compile error to ensure it does not compile. - errs.extend( - "::core::compile_error!(\"Could not locate type name.\");" - .parse::<TokenStream>() - .unwrap(), - ); - "Self".parse::<TokenStream>().unwrap().into_iter().collect() - }); - let impl_generics = impl_generics - .into_iter() - .flat_map(|tt| replace_self_and_deny_type_defs(&struct_name, tt, &mut errs)) - .collect::<Vec<_>>(); - let mut rest = rest - .into_iter() - .flat_map(|tt| { - // We ignore top level `struct` tokens, since they would emit a compile error. - if matches!(&tt, TokenTree::Ident(i) if i.to_string() == "struct") { - vec![tt] + } + } else { + // When no `PinnedDrop` was specified, then we have to prevent implementing drop. + quote! { + // We prevent this by creating a trait that will be implemented for all types implementing + // `Drop`. Additionally we will implement this trait for the struct leading to a conflict, + // if it also implements `Drop` + trait MustNotImplDrop {} + #[expect(drop_bounds)] + impl<T: ::core::ops::Drop + ?::core::marker::Sized> MustNotImplDrop for T {} + impl #impl_generics MustNotImplDrop for #ident #ty_generics + #whr + {} + // We also take care to prevent users from writing a useless `PinnedDrop` implementation. + // They might implement `PinnedDrop` correctly for the struct, but forget to give + // `PinnedDrop` as the parameter to `#[pin_data]`. + #[expect(non_camel_case_types)] + trait UselessPinnedDropImpl_you_need_to_specify_PinnedDrop {} + impl<T: ::pin_init::PinnedDrop + ?::core::marker::Sized> + UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for T {} + impl #impl_generics + UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for #ident #ty_generics + #whr + {} + } + } +} + +fn generate_projections( + vis: &Visibility, + ident: &Ident, + generics: &Generics, + fields: &[(bool, &Field)], +) -> TokenStream { + let (impl_generics, ty_generics, _) = generics.split_for_impl(); + let mut generics_with_pin_lt = generics.clone(); + generics_with_pin_lt.params.insert(0, parse_quote!('__pin)); + let (_, ty_generics_with_pin_lt, whr) = generics_with_pin_lt.split_for_impl(); + let projection = format_ident!("{ident}Projection"); + let this = format_ident!("this"); + + let (fields_decl, fields_proj) = collect_tuple(fields.iter().map( + |( + pinned, + Field { + vis, + ident, + ty, + attrs, + .. + }, + )| { + let mut attrs = attrs.clone(); + attrs.retain(|a| !a.path().is_ident("pin")); + let mut no_doc_attrs = attrs.clone(); + no_doc_attrs.retain(|a| !a.path().is_ident("doc")); + let ident = ident + .as_ref() + .expect("only structs with named fields are supported"); + if *pinned { + ( + quote!( + #(#attrs)* + #vis #ident: ::core::pin::Pin<&'__pin mut #ty>, + ), + quote!( + #(#no_doc_attrs)* + // SAFETY: this field is structurally pinned. + #ident: unsafe { ::core::pin::Pin::new_unchecked(&mut #this.#ident) }, + ), + ) } else { - replace_self_and_deny_type_defs(&struct_name, tt, &mut errs) + ( + quote!( + #(#attrs)* + #vis #ident: &'__pin mut #ty, + ), + quote!( + #(#no_doc_attrs)* + #ident: &mut #this.#ident, + ), + ) } - }) - .collect::<Vec<_>>(); - // This should be the body of the struct `{...}`. - let last = rest.pop(); - let mut quoted = quote!(::pin_init::__pin_data! { - parse_input: - @args(#args), - @sig(#(#rest)*), - @impl_generics(#(#impl_generics)*), - @ty_generics(#(#ty_generics)*), - @decl_generics(#(#decl_generics)*), - @body(#last), - }); - quoted.extend(errs); - quoted + }, + )); + let structurally_pinned_fields_docs = fields + .iter() + .filter_map(|(pinned, field)| pinned.then_some(field)) + .map(|Field { ident, .. }| format!(" - `{}`", ident.as_ref().unwrap())); + let not_structurally_pinned_fields_docs = fields + .iter() + .filter_map(|(pinned, field)| (!pinned).then_some(field)) + .map(|Field { ident, .. }| format!(" - `{}`", ident.as_ref().unwrap())); + let docs = format!(" Pin-projections of [`{ident}`]"); + quote! { + #[doc = #docs] + #[allow(dead_code)] + #[doc(hidden)] + #vis struct #projection #generics_with_pin_lt { + #(#fields_decl)* + ___pin_phantom_data: ::core::marker::PhantomData<&'__pin mut ()>, + } + + impl #impl_generics #ident #ty_generics + #whr + { + /// Pin-projects all fields of `Self`. + /// + /// These fields are structurally pinned: + #(#[doc = #structurally_pinned_fields_docs])* + /// + /// These fields are **not** structurally pinned: + #(#[doc = #not_structurally_pinned_fields_docs])* + #[inline] + #vis fn project<'__pin>( + self: ::core::pin::Pin<&'__pin mut Self>, + ) -> #projection #ty_generics_with_pin_lt { + // SAFETY: we only give access to `&mut` for fields not structurally pinned. + let #this = unsafe { ::core::pin::Pin::get_unchecked_mut(self) }; + #projection { + #(#fields_proj)* + ___pin_phantom_data: ::core::marker::PhantomData, + } + } + } + } } -/// Replaces `Self` with `struct_name` and errors on `enum`, `trait`, `struct` `union` and `impl` -/// keywords. -/// -/// The error is appended to `errs` to allow normal parsing to continue. -fn replace_self_and_deny_type_defs( - struct_name: &Vec<TokenTree>, - tt: TokenTree, - errs: &mut TokenStream, -) -> Vec<TokenTree> { - match tt { - TokenTree::Ident(ref i) - if i.to_string() == "enum" - || i.to_string() == "trait" - || i.to_string() == "struct" - || i.to_string() == "union" - || i.to_string() == "impl" => +fn generate_the_pin_data( + vis: &Visibility, + ident: &Ident, + generics: &Generics, + fields: &[(bool, &Field)], +) -> TokenStream { + let (impl_generics, ty_generics, whr) = generics.split_for_impl(); + + // For every field, we create an initializing projection function according to its projection + // type. If a field is structurally pinned, then it must be initialized via `PinInit`, if it is + // not structurally pinned, then it can be initialized via `Init`. + // + // The functions are `unsafe` to prevent accidentally calling them. + fn handle_field( + Field { + vis, + ident, + ty, + attrs, + .. + }: &Field, + struct_ident: &Ident, + pinned: bool, + ) -> TokenStream { + let mut attrs = attrs.clone(); + attrs.retain(|a| !a.path().is_ident("pin")); + let ident = ident + .as_ref() + .expect("only structs with named fields are supported"); + let project_ident = format_ident!("__project_{ident}"); + let (init_ty, init_fn, project_ty, project_body, pin_safety) = if pinned { + ( + quote!(PinInit), + quote!(__pinned_init), + quote!(::core::pin::Pin<&'__slot mut #ty>), + // SAFETY: this field is structurally pinned. + quote!(unsafe { ::core::pin::Pin::new_unchecked(slot) }), + quote!( + /// - `slot` will not move until it is dropped, i.e. it will be pinned. + ), + ) + } else { + ( + quote!(Init), + quote!(__init), + quote!(&'__slot mut #ty), + quote!(slot), + quote!(), + ) + }; + let slot_safety = format!( + " `slot` points at the field `{ident}` inside of `{struct_ident}`, which is pinned.", + ); + quote! { + /// # Safety + /// + /// - `slot` is a valid pointer to uninitialized memory. + /// - the caller does not touch `slot` when `Err` is returned, they are only permitted + /// to deallocate. + #pin_safety + #(#attrs)* + #vis unsafe fn #ident<E>( + self, + slot: *mut #ty, + init: impl ::pin_init::#init_ty<#ty, E>, + ) -> ::core::result::Result<(), E> { + // SAFETY: this function has the same safety requirements as the __init function + // called below. + unsafe { ::pin_init::#init_ty::#init_fn(init, slot) } + } + + /// # Safety + /// + #[doc = #slot_safety] + #(#attrs)* + #vis unsafe fn #project_ident<'__slot>( + self, + slot: &'__slot mut #ty, + ) -> #project_ty { + #project_body + } + } + } + + let field_accessors = fields + .iter() + .map(|(pinned, field)| handle_field(field, ident, *pinned)) + .collect::<TokenStream>(); + quote! { + // We declare this struct which will host all of the projection function for our type. It + // will be invariant over all generic parameters which are inherited from the struct. + #[doc(hidden)] + #vis struct __ThePinData #generics + #whr { - errs.extend( - format!( - "::core::compile_error!(\"Cannot use `{i}` inside of struct definition with \ - `#[pin_data]`.\");" - ) - .parse::<TokenStream>() - .unwrap() - .into_iter() - .map(|mut tok| { - tok.set_span(tt.span()); - tok - }), - ); - vec![tt] - } - TokenTree::Ident(i) if i.to_string() == "Self" => struct_name.clone(), - TokenTree::Literal(_) | TokenTree::Punct(_) | TokenTree::Ident(_) => vec![tt], - TokenTree::Group(g) => vec![TokenTree::Group(Group::new( - g.delimiter(), - g.stream() - .into_iter() - .flat_map(|tt| replace_self_and_deny_type_defs(struct_name, tt, errs)) - .collect(), - ))], + __phantom: ::core::marker::PhantomData< + fn(#ident #ty_generics) -> #ident #ty_generics + >, + } + + impl #impl_generics ::core::clone::Clone for __ThePinData #ty_generics + #whr + { + fn clone(&self) -> Self { *self } + } + + impl #impl_generics ::core::marker::Copy for __ThePinData #ty_generics + #whr + {} + + #[allow(dead_code)] // Some functions might never be used and private. + #[expect(clippy::missing_safety_doc)] + impl #impl_generics __ThePinData #ty_generics + #whr + { + #field_accessors + } + + // SAFETY: We have added the correct projection functions above to `__ThePinData` and + // we also use the least restrictive generics possible. + unsafe impl #impl_generics ::pin_init::__internal::HasPinData for #ident #ty_generics + #whr + { + type PinData = __ThePinData #ty_generics; + + unsafe fn __pin_data() -> Self::PinData { + __ThePinData { __phantom: ::core::marker::PhantomData } + } + } + + // SAFETY: TODO + unsafe impl #impl_generics ::pin_init::__internal::PinData for __ThePinData #ty_generics + #whr + { + type Datee = #ident #ty_generics; + } + } +} + +struct SelfReplacer(PathSegment); + +impl VisitMut for SelfReplacer { + fn visit_path_mut(&mut self, i: &mut syn::Path) { + if i.is_ident("Self") { + let span = i.span(); + let seg = &self.0; + *i = parse_quote_spanned!(span=> #seg); + } else { + syn::visit_mut::visit_path_mut(self, i); + } + } + + fn visit_path_segment_mut(&mut self, seg: &mut PathSegment) { + if seg.ident == "Self" { + let span = seg.span(); + let this = &self.0; + *seg = parse_quote_spanned!(span=> #this); + } else { + syn::visit_mut::visit_path_segment_mut(self, seg); + } + } + + fn visit_item_mut(&mut self, _: &mut Item) { + // Do not descend into items, since items reset/change what `Self` refers to. + } +} + +// replace with `.collect()` once MSRV is above 1.79 +fn collect_tuple<A, B>(iter: impl Iterator<Item = (A, B)>) -> (Vec<A>, Vec<B>) { + let mut res_a = vec![]; + let mut res_b = vec![]; + for (a, b) in iter { + res_a.push(a); + res_b.push(b); } + (res_a, res_b) } diff --git a/rust/pin-init/internal/src/pinned_drop.rs b/rust/pin-init/internal/src/pinned_drop.rs index c4ca7a70b726..a20ac314ca82 100644 --- a/rust/pin-init/internal/src/pinned_drop.rs +++ b/rust/pin-init/internal/src/pinned_drop.rs @@ -1,51 +1,61 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT -#[cfg(not(kernel))] -use proc_macro2 as proc_macro; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{parse::Nothing, parse_quote, spanned::Spanned, ImplItem, ItemImpl, Token}; -use proc_macro::{TokenStream, TokenTree}; +use crate::diagnostics::{DiagCtxt, ErrorGuaranteed}; -pub(crate) fn pinned_drop(_args: TokenStream, input: TokenStream) -> TokenStream { - let mut toks = input.into_iter().collect::<Vec<_>>(); - assert!(!toks.is_empty()); - // Ensure that we have an `impl` item. - assert!(matches!(&toks[0], TokenTree::Ident(i) if i.to_string() == "impl")); - // Ensure that we are implementing `PinnedDrop`. - let mut nesting: usize = 0; - let mut pinned_drop_idx = None; - for (i, tt) in toks.iter().enumerate() { - match tt { - TokenTree::Punct(p) if p.as_char() == '<' => { - nesting += 1; +pub(crate) fn pinned_drop( + _args: Nothing, + mut input: ItemImpl, + dcx: &mut DiagCtxt, +) -> Result<TokenStream, ErrorGuaranteed> { + if let Some(unsafety) = input.unsafety { + dcx.error(unsafety, "implementing `PinnedDrop` is safe"); + } + input.unsafety = Some(Token); + match &mut input.trait_ { + Some((not, path, _for)) => { + if let Some(not) = not { + dcx.error(not, "cannot implement `!PinnedDrop`"); } - TokenTree::Punct(p) if p.as_char() == '>' => { - nesting = nesting.checked_sub(1).unwrap(); - continue; + for (seg, expected) in path + .segments + .iter() + .rev() + .zip(["PinnedDrop", "pin_init", ""]) + { + if expected.is_empty() || seg.ident != expected { + dcx.error(seg, "bad import path for `PinnedDrop`"); + } + if !seg.arguments.is_none() { + dcx.error(&seg.arguments, "unexpected arguments for `PinnedDrop` path"); + } } - _ => {} + *path = parse_quote!(::pin_init::PinnedDrop); } - if i >= 1 && nesting == 0 { - // Found the end of the generics, this should be `PinnedDrop`. - assert!( - matches!(tt, TokenTree::Ident(i) if i.to_string() == "PinnedDrop"), - "expected 'PinnedDrop', found: '{tt:?}'" + None => { + let span = input + .impl_token + .span + .join(input.self_ty.span()) + .unwrap_or(input.impl_token.span); + dcx.error( + span, + "expected `impl ... PinnedDrop for ...`, got inherent impl", ); - pinned_drop_idx = Some(i); - break; } } - let idx = pinned_drop_idx - .unwrap_or_else(|| panic!("Expected an `impl` block implementing `PinnedDrop`.")); - // Fully qualify the `PinnedDrop`, as to avoid any tampering. - toks.splice(idx..idx, quote!(::pin_init::)); - // Take the `{}` body and call the declarative macro. - if let Some(TokenTree::Group(last)) = toks.pop() { - let last = last.stream(); - quote!(::pin_init::__pinned_drop! { - @impl_sig(#(#toks)*), - @impl_body(#last), - }) - } else { - TokenStream::from_iter(toks) + for item in &mut input.items { + if let ImplItem::Fn(fn_item) = item { + if fn_item.sig.ident == "drop" { + fn_item + .sig + .inputs + .push(parse_quote!(_: ::pin_init::__internal::OnlyCallFromDrop)); + } + } } + Ok(quote!(#input)) } diff --git a/rust/pin-init/internal/src/zeroable.rs b/rust/pin-init/internal/src/zeroable.rs index e0ed3998445c..05683319b0f7 100644 --- a/rust/pin-init/internal/src/zeroable.rs +++ b/rust/pin-init/internal/src/zeroable.rs @@ -1,101 +1,78 @@ // SPDX-License-Identifier: GPL-2.0 -#[cfg(not(kernel))] -use proc_macro2 as proc_macro; +use proc_macro2::TokenStream; +use quote::quote; +use syn::{parse_quote, Data, DeriveInput, Field, Fields}; -use crate::helpers::{parse_generics, Generics}; -use proc_macro::{TokenStream, TokenTree}; +use crate::{diagnostics::ErrorGuaranteed, DiagCtxt}; -pub(crate) fn parse_zeroable_derive_input( - input: TokenStream, -) -> ( - Vec<TokenTree>, - Vec<TokenTree>, - Vec<TokenTree>, - Option<TokenTree>, -) { - let ( - Generics { - impl_generics, - decl_generics: _, - ty_generics, - }, - mut rest, - ) = parse_generics(input); - // This should be the body of the struct `{...}`. - let last = rest.pop(); - // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`. - let mut new_impl_generics = Vec::with_capacity(impl_generics.len()); - // Are we inside of a generic where we want to add `Zeroable`? - let mut in_generic = !impl_generics.is_empty(); - // Have we already inserted `Zeroable`? - let mut inserted = false; - // Level of `<>` nestings. - let mut nested = 0; - for tt in impl_generics { - match &tt { - // If we find a `,`, then we have finished a generic/constant/lifetime parameter. - TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => { - if in_generic && !inserted { - new_impl_generics.extend(quote! { : ::pin_init::Zeroable }); - } - in_generic = true; - inserted = false; - new_impl_generics.push(tt); - } - // If we find `'`, then we are entering a lifetime. - TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => { - in_generic = false; - new_impl_generics.push(tt); - } - TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => { - new_impl_generics.push(tt); - if in_generic { - new_impl_generics.extend(quote! { ::pin_init::Zeroable + }); - inserted = true; - } - } - TokenTree::Punct(p) if p.as_char() == '<' => { - nested += 1; - new_impl_generics.push(tt); - } - TokenTree::Punct(p) if p.as_char() == '>' => { - assert!(nested > 0); - nested -= 1; - new_impl_generics.push(tt); - } - _ => new_impl_generics.push(tt), +pub(crate) fn derive( + input: DeriveInput, + dcx: &mut DiagCtxt, +) -> Result<TokenStream, ErrorGuaranteed> { + let fields = match input.data { + Data::Struct(data_struct) => data_struct.fields, + Data::Union(data_union) => Fields::Named(data_union.fields), + Data::Enum(data_enum) => { + return Err(dcx.error(data_enum.enum_token, "cannot derive `Zeroable` for an enum")); } + }; + let name = input.ident; + let mut generics = input.generics; + for param in generics.type_params_mut() { + param.bounds.insert(0, parse_quote!(::pin_init::Zeroable)); } - assert_eq!(nested, 0); - if in_generic && !inserted { - new_impl_generics.extend(quote! { : ::pin_init::Zeroable }); - } - (rest, new_impl_generics, ty_generics, last) + let (impl_gen, ty_gen, whr) = generics.split_for_impl(); + let field_type = fields.iter().map(|field| &field.ty); + Ok(quote! { + // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero. + #[automatically_derived] + unsafe impl #impl_gen ::pin_init::Zeroable for #name #ty_gen + #whr + {} + const _: () = { + fn assert_zeroable<T: ?::core::marker::Sized + ::pin_init::Zeroable>() {} + fn ensure_zeroable #impl_gen () + #whr + { + #( + assert_zeroable::<#field_type>(); + )* + } + }; + }) } -pub(crate) fn derive(input: TokenStream) -> TokenStream { - let (rest, new_impl_generics, ty_generics, last) = parse_zeroable_derive_input(input); - quote! { - ::pin_init::__derive_zeroable!( - parse_input: - @sig(#(#rest)*), - @impl_generics(#(#new_impl_generics)*), - @ty_generics(#(#ty_generics)*), - @body(#last), - ); +pub(crate) fn maybe_derive( + input: DeriveInput, + dcx: &mut DiagCtxt, +) -> Result<TokenStream, ErrorGuaranteed> { + let fields = match input.data { + Data::Struct(data_struct) => data_struct.fields, + Data::Union(data_union) => Fields::Named(data_union.fields), + Data::Enum(data_enum) => { + return Err(dcx.error(data_enum.enum_token, "cannot derive `Zeroable` for an enum")); + } + }; + let name = input.ident; + let mut generics = input.generics; + for param in generics.type_params_mut() { + param.bounds.insert(0, parse_quote!(::pin_init::Zeroable)); } -} - -pub(crate) fn maybe_derive(input: TokenStream) -> TokenStream { - let (rest, new_impl_generics, ty_generics, last) = parse_zeroable_derive_input(input); - quote! { - ::pin_init::__maybe_derive_zeroable!( - parse_input: - @sig(#(#rest)*), - @impl_generics(#(#new_impl_generics)*), - @ty_generics(#(#ty_generics)*), - @body(#last), - ); + for Field { ty, .. } in fields { + generics + .make_where_clause() + .predicates + // the `for<'__dummy>` HRTB makes this not error without the `trivial_bounds` + // feature <https://github.com/rust-lang/rust/issues/48214#issuecomment-2557829956>. + .push(parse_quote!(#ty: for<'__dummy> ::pin_init::Zeroable)); } + let (impl_gen, ty_gen, whr) = generics.split_for_impl(); + Ok(quote! { + // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero. + #[automatically_derived] + unsafe impl #impl_gen ::pin_init::Zeroable for #name #ty_gen + #whr + {} + }) } diff --git a/rust/pin-init/src/lib.rs b/rust/pin-init/src/lib.rs index 8dc9dd5ac6fd..49945fc07f25 100644 --- a/rust/pin-init/src/lib.rs +++ b/rust/pin-init/src/lib.rs @@ -146,7 +146,7 @@ //! //! impl DriverData { //! fn new() -> impl PinInit<Self, Error> { -//! try_pin_init!(Self { +//! pin_init!(Self { //! status <- CMutex::new(0), //! buffer: Box::init(pin_init::init_zeroed())?, //! }? Error) @@ -290,10 +290,13 @@ use core::{ ptr::{self, NonNull}, }; +// This is used by doc-tests -- the proc-macros expand to `::pin_init::...` and without this the +// doc-tests wouldn't have an extern crate named `pin_init`. +#[allow(unused_extern_crates)] +extern crate self as pin_init; + #[doc(hidden)] pub mod __internal; -#[doc(hidden)] -pub mod macros; #[cfg(any(feature = "std", feature = "alloc"))] mod alloc; @@ -528,7 +531,7 @@ macro_rules! stack_pin_init { /// x: u32, /// } /// -/// stack_try_pin_init!(let foo: Foo = try_pin_init!(Foo { +/// stack_try_pin_init!(let foo: Foo = pin_init!(Foo { /// a <- CMutex::new(42), /// b: Box::try_new(Bar { /// x: 64, @@ -555,7 +558,7 @@ macro_rules! stack_pin_init { /// x: u32, /// } /// -/// stack_try_pin_init!(let foo: Foo =? try_pin_init!(Foo { +/// stack_try_pin_init!(let foo: Foo =? pin_init!(Foo { /// a <- CMutex::new(42), /// b: Box::try_new(Bar { /// x: 64, @@ -584,10 +587,10 @@ macro_rules! stack_try_pin_init { }; } -/// Construct an in-place, pinned initializer for `struct`s. +/// Construct an in-place, fallible pinned initializer for `struct`s. /// -/// This macro defaults the error to [`Infallible`]. If you need a different error, then use -/// [`try_pin_init!`]. +/// The error type defaults to [`Infallible`]; if you need a different one, write `? Error` at the +/// end, after the struct initializer. /// /// The syntax is almost identical to that of a normal `struct` initializer: /// @@ -776,81 +779,12 @@ macro_rules! stack_try_pin_init { /// ``` /// /// [`NonNull<Self>`]: core::ptr::NonNull -// For a detailed example of how this macro works, see the module documentation of the hidden -// module `macros` inside of `macros.rs`. -#[macro_export] -macro_rules! pin_init { - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }) => { - $crate::try_pin_init!($(&$this in)? $t $(::<$($generics),*>)? { - $($fields)* - }? ::core::convert::Infallible) - }; -} - -/// Construct an in-place, fallible pinned initializer for `struct`s. -/// -/// If the initialization can complete without error (or [`Infallible`]), then use [`pin_init!`]. -/// -/// You can use the `?` operator or use `return Err(err)` inside the initializer to stop -/// initialization and return the error. -/// -/// IMPORTANT: if you have `unsafe` code inside of the initializer you have to ensure that when -/// initialization fails, the memory can be safely deallocated without any further modifications. -/// -/// The syntax is identical to [`pin_init!`] with the following exception: you must append `? $type` -/// after the `struct` initializer to specify the error type you want to use. -/// -/// # Examples -/// -/// ```rust -/// # #![feature(allocator_api)] -/// # #[path = "../examples/error.rs"] mod error; use error::Error; -/// use pin_init::{pin_data, try_pin_init, PinInit, InPlaceInit, init_zeroed}; -/// -/// #[pin_data] -/// struct BigBuf { -/// big: Box<[u8; 1024 * 1024 * 1024]>, -/// small: [u8; 1024 * 1024], -/// ptr: *mut u8, -/// } -/// -/// impl BigBuf { -/// fn new() -> impl PinInit<Self, Error> { -/// try_pin_init!(Self { -/// big: Box::init(init_zeroed())?, -/// small: [0; 1024 * 1024], -/// ptr: core::ptr::null_mut(), -/// }? Error) -/// } -/// } -/// # let _ = Box::pin_init(BigBuf::new()); -/// ``` -// For a detailed example of how this macro works, see the module documentation of the hidden -// module `macros` inside of `macros.rs`. -#[macro_export] -macro_rules! try_pin_init { - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }? $err:ty) => { - $crate::__init_internal!( - @this($($this)?), - @typ($t $(::<$($generics),*>)? ), - @fields($($fields)*), - @error($err), - @data(PinData, use_data), - @has_data(HasPinData, __pin_data), - @construct_closure(pin_init_from_closure), - @munch_fields($($fields)*), - ) - } -} +pub use pin_init_internal::pin_init; -/// Construct an in-place initializer for `struct`s. +/// Construct an in-place, fallible initializer for `struct`s. /// -/// This macro defaults the error to [`Infallible`]. If you need a different error, then use -/// [`try_init!`]. +/// This macro defaults the error to [`Infallible`]; if you need a different one, write `? Error` +/// at the end, after the struct initializer. /// /// The syntax is identical to [`pin_init!`] and its safety caveats also apply: /// - `unsafe` code must guarantee either full initialization or return an error and allow @@ -883,74 +817,7 @@ macro_rules! try_pin_init { /// } /// # let _ = Box::init(BigBuf::new()); /// ``` -// For a detailed example of how this macro works, see the module documentation of the hidden -// module `macros` inside of `macros.rs`. -#[macro_export] -macro_rules! init { - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }) => { - $crate::try_init!($(&$this in)? $t $(::<$($generics),*>)? { - $($fields)* - }? ::core::convert::Infallible) - } -} - -/// Construct an in-place fallible initializer for `struct`s. -/// -/// If the initialization can complete without error (or [`Infallible`]), then use -/// [`init!`]. -/// -/// The syntax is identical to [`try_pin_init!`]. You need to specify a custom error -/// via `? $type` after the `struct` initializer. -/// The safety caveats from [`try_pin_init!`] also apply: -/// - `unsafe` code must guarantee either full initialization or return an error and allow -/// deallocation of the memory. -/// - the fields are initialized in the order given in the initializer. -/// - no references to fields are allowed to be created inside of the initializer. -/// -/// # Examples -/// -/// ```rust -/// # #![feature(allocator_api)] -/// # use core::alloc::AllocError; -/// # use pin_init::InPlaceInit; -/// use pin_init::{try_init, Init, init_zeroed}; -/// -/// struct BigBuf { -/// big: Box<[u8; 1024 * 1024 * 1024]>, -/// small: [u8; 1024 * 1024], -/// } -/// -/// impl BigBuf { -/// fn new() -> impl Init<Self, AllocError> { -/// try_init!(Self { -/// big: Box::init(init_zeroed())?, -/// small: [0; 1024 * 1024], -/// }? AllocError) -/// } -/// } -/// # let _ = Box::init(BigBuf::new()); -/// ``` -// For a detailed example of how this macro works, see the module documentation of the hidden -// module `macros` inside of `macros.rs`. -#[macro_export] -macro_rules! try_init { - ($(&$this:ident in)? $t:ident $(::<$($generics:ty),* $(,)?>)? { - $($fields:tt)* - }? $err:ty) => { - $crate::__init_internal!( - @this($($this)?), - @typ($t $(::<$($generics),*>)?), - @fields($($fields)*), - @error($err), - @data(InitData, /*no use_data*/), - @has_data(HasInitData, __init_data), - @construct_closure(init_from_closure), - @munch_fields($($fields)*), - ) - } -} +pub use pin_init_internal::init; /// Asserts that a field on a struct using `#[pin_data]` is marked with `#[pin]` ie. that it is /// structurally pinned. @@ -1410,14 +1277,14 @@ where /// fn init_foo() -> impl PinInit<Foo, Error> { /// pin_init_scope(|| { /// let bar = lookup_bar()?; -/// Ok(try_pin_init!(Foo { a: bar.a.into(), b: bar.b }? Error)) +/// Ok(pin_init!(Foo { a: bar.a.into(), b: bar.b }? Error)) /// }) /// } /// ``` /// /// This initializer will first execute `lookup_bar()`, match on it, if it returned an error, the /// initializer itself will fail with that error. If it returned `Ok`, then it will run the -/// initializer returned by the [`try_pin_init!`] invocation. +/// initializer returned by the [`pin_init!`] invocation. pub fn pin_init_scope<T, E, F, I>(make_init: F) -> impl PinInit<T, E> where F: FnOnce() -> Result<I, E>, @@ -1453,14 +1320,14 @@ where /// fn init_foo() -> impl Init<Foo, Error> { /// init_scope(|| { /// let bar = lookup_bar()?; -/// Ok(try_init!(Foo { a: bar.a.into(), b: bar.b }? Error)) +/// Ok(init!(Foo { a: bar.a.into(), b: bar.b }? Error)) /// }) /// } /// ``` /// /// This initializer will first execute `lookup_bar()`, match on it, if it returned an error, the /// initializer itself will fail with that error. If it returned `Ok`, then it will run the -/// initializer returned by the [`try_init!`] invocation. +/// initializer returned by the [`init!`] invocation. pub fn init_scope<T, E, F, I>(make_init: F) -> impl Init<T, E> where F: FnOnce() -> Result<I, E>, @@ -1536,6 +1403,33 @@ pub trait InPlaceWrite<T> { fn write_pin_init<E>(self, init: impl PinInit<T, E>) -> Result<Pin<Self::Initialized>, E>; } +impl<T> InPlaceWrite<T> for &'static mut MaybeUninit<T> { + type Initialized = &'static mut T; + + fn write_init<E>(self, init: impl Init<T, E>) -> Result<Self::Initialized, E> { + let slot = self.as_mut_ptr(); + + // SAFETY: `slot` is a valid pointer to uninitialized memory. + unsafe { init.__init(slot)? }; + + // SAFETY: The above call initialized the memory. + unsafe { Ok(self.assume_init_mut()) } + } + + fn write_pin_init<E>(self, init: impl PinInit<T, E>) -> Result<Pin<Self::Initialized>, E> { + let slot = self.as_mut_ptr(); + + // SAFETY: `slot` is a valid pointer to uninitialized memory. + // + // The `'static` borrow guarantees the data will not be + // moved/invalidated until it gets dropped (which is never). + unsafe { init.__pinned_init(slot)? }; + + // SAFETY: The above call initialized the memory. + Ok(Pin::static_mut(unsafe { self.assume_init_mut() })) + } +} + /// Trait facilitating pinned destruction. /// /// Use [`pinned_drop`] to implement this trait safely: diff --git a/rust/pin-init/src/macros.rs b/rust/pin-init/src/macros.rs deleted file mode 100644 index 682c61a587a0..000000000000 --- a/rust/pin-init/src/macros.rs +++ /dev/null @@ -1,1677 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 OR MIT - -//! This module provides the macros that actually implement the proc-macros `pin_data` and -//! `pinned_drop`. It also contains `__init_internal`, the implementation of the -//! `{try_}{pin_}init!` macros. -//! -//! These macros should never be called directly, since they expect their input to be -//! in a certain format which is internal. If used incorrectly, these macros can lead to UB even in -//! safe code! Use the public facing macros instead. -//! -//! This architecture has been chosen because the kernel does not yet have access to `syn` which -//! would make matters a lot easier for implementing these as proc-macros. -//! -//! Since this library and the kernel implementation should diverge as little as possible, the same -//! approach has been taken here. -//! -//! # Macro expansion example -//! -//! This section is intended for readers trying to understand the macros in this module and the -//! `[try_][pin_]init!` macros from `lib.rs`. -//! -//! We will look at the following example: -//! -//! ```rust,ignore -//! #[pin_data] -//! #[repr(C)] -//! struct Bar<T> { -//! #[pin] -//! t: T, -//! pub x: usize, -//! } -//! -//! impl<T> Bar<T> { -//! fn new(t: T) -> impl PinInit<Self> { -//! pin_init!(Self { t, x: 0 }) -//! } -//! } -//! -//! #[pin_data(PinnedDrop)] -//! struct Foo { -//! a: usize, -//! #[pin] -//! b: Bar<u32>, -//! } -//! -//! #[pinned_drop] -//! impl PinnedDrop for Foo { -//! fn drop(self: Pin<&mut Self>) { -//! println!("{self:p} is getting dropped."); -//! } -//! } -//! -//! let a = 42; -//! let initializer = pin_init!(Foo { -//! a, -//! b <- Bar::new(36), -//! }); -//! ``` -//! -//! This example includes the most common and important features of the pin-init API. -//! -//! Below you can find individual section about the different macro invocations. Here are some -//! general things we need to take into account when designing macros: -//! - use global paths, similarly to file paths, these start with the separator: `::core::panic!()` -//! this ensures that the correct item is used, since users could define their own `mod core {}` -//! and then their own `panic!` inside to execute arbitrary code inside of our macro. -//! - macro `unsafe` hygiene: we need to ensure that we do not expand arbitrary, user-supplied -//! expressions inside of an `unsafe` block in the macro, because this would allow users to do -//! `unsafe` operations without an associated `unsafe` block. -//! -//! ## `#[pin_data]` on `Bar` -//! -//! This macro is used to specify which fields are structurally pinned and which fields are not. It -//! is placed on the struct definition and allows `#[pin]` to be placed on the fields. -//! -//! Here is the definition of `Bar` from our example: -//! -//! ```rust,ignore -//! #[pin_data] -//! #[repr(C)] -//! struct Bar<T> { -//! #[pin] -//! t: T, -//! pub x: usize, -//! } -//! ``` -//! -//! This expands to the following code: -//! -//! ```rust,ignore -//! // Firstly the normal definition of the struct, attributes are preserved: -//! #[repr(C)] -//! struct Bar<T> { -//! t: T, -//! pub x: usize, -//! } -//! // Then an anonymous constant is defined, this is because we do not want any code to access the -//! // types that we define inside: -//! const _: () = { -//! // We define the pin-data carrying struct, it is a ZST and needs to have the same generics, -//! // since we need to implement access functions for each field and thus need to know its -//! // type. -//! struct __ThePinData<T> { -//! __phantom: ::core::marker::PhantomData<fn(Bar<T>) -> Bar<T>>, -//! } -//! // We implement `Copy` for the pin-data struct, since all functions it defines will take -//! // `self` by value. -//! impl<T> ::core::clone::Clone for __ThePinData<T> { -//! fn clone(&self) -> Self { -//! *self -//! } -//! } -//! impl<T> ::core::marker::Copy for __ThePinData<T> {} -//! // For every field of `Bar`, the pin-data struct will define a function with the same name -//! // and accessor (`pub` or `pub(crate)` etc.). This function will take a pointer to the -//! // field (`slot`) and a `PinInit` or `Init` depending on the projection kind of the field -//! // (if pinning is structural for the field, then `PinInit` otherwise `Init`). -//! #[allow(dead_code)] -//! impl<T> __ThePinData<T> { -//! unsafe fn t<E>( -//! self, -//! slot: *mut T, -//! // Since `t` is `#[pin]`, this is `PinInit`. -//! init: impl ::pin_init::PinInit<T, E>, -//! ) -> ::core::result::Result<(), E> { -//! unsafe { ::pin_init::PinInit::__pinned_init(init, slot) } -//! } -//! pub unsafe fn x<E>( -//! self, -//! slot: *mut usize, -//! // Since `x` is not `#[pin]`, this is `Init`. -//! init: impl ::pin_init::Init<usize, E>, -//! ) -> ::core::result::Result<(), E> { -//! unsafe { ::pin_init::Init::__init(init, slot) } -//! } -//! } -//! // Implement the internal `HasPinData` trait that associates `Bar` with the pin-data struct -//! // that we constructed above. -//! unsafe impl<T> ::pin_init::__internal::HasPinData for Bar<T> { -//! type PinData = __ThePinData<T>; -//! unsafe fn __pin_data() -> Self::PinData { -//! __ThePinData { -//! __phantom: ::core::marker::PhantomData, -//! } -//! } -//! } -//! // Implement the internal `PinData` trait that marks the pin-data struct as a pin-data -//! // struct. This is important to ensure that no user can implement a rogue `__pin_data` -//! // function without using `unsafe`. -//! unsafe impl<T> ::pin_init::__internal::PinData for __ThePinData<T> { -//! type Datee = Bar<T>; -//! } -//! // Now we only want to implement `Unpin` for `Bar` when every structurally pinned field is -//! // `Unpin`. In other words, whether `Bar` is `Unpin` only depends on structurally pinned -//! // fields (those marked with `#[pin]`). These fields will be listed in this struct, in our -//! // case no such fields exist, hence this is almost empty. The two phantomdata fields exist -//! // for two reasons: -//! // - `__phantom`: every generic must be used, since we cannot really know which generics -//! // are used, we declare all and then use everything here once. -//! // - `__phantom_pin`: uses the `'__pin` lifetime and ensures that this struct is invariant -//! // over it. The lifetime is needed to work around the limitation that trait bounds must -//! // not be trivial, e.g. the user has a `#[pin] PhantomPinned` field -- this is -//! // unconditionally `!Unpin` and results in an error. The lifetime tricks the compiler -//! // into accepting these bounds regardless. -//! #[allow(dead_code)] -//! struct __Unpin<'__pin, T> { -//! __phantom_pin: ::core::marker::PhantomData<fn(&'__pin ()) -> &'__pin ()>, -//! __phantom: ::core::marker::PhantomData<fn(Bar<T>) -> Bar<T>>, -//! // Our only `#[pin]` field is `t`. -//! t: T, -//! } -//! #[doc(hidden)] -//! impl<'__pin, T> ::core::marker::Unpin for Bar<T> -//! where -//! __Unpin<'__pin, T>: ::core::marker::Unpin, -//! {} -//! // Now we need to ensure that `Bar` does not implement `Drop`, since that would give users -//! // access to `&mut self` inside of `drop` even if the struct was pinned. This could lead to -//! // UB with only safe code, so we disallow this by giving a trait implementation error using -//! // a direct impl and a blanket implementation. -//! trait MustNotImplDrop {} -//! // Normally `Drop` bounds do not have the correct semantics, but for this purpose they do -//! // (normally people want to know if a type has any kind of drop glue at all, here we want -//! // to know if it has any kind of custom drop glue, which is exactly what this bound does). -//! #[expect(drop_bounds)] -//! impl<T: ::core::ops::Drop> MustNotImplDrop for T {} -//! impl<T> MustNotImplDrop for Bar<T> {} -//! // Here comes a convenience check, if one implemented `PinnedDrop`, but forgot to add it to -//! // `#[pin_data]`, then this will error with the same mechanic as above, this is not needed -//! // for safety, but a good sanity check, since no normal code calls `PinnedDrop::drop`. -//! #[expect(non_camel_case_types)] -//! trait UselessPinnedDropImpl_you_need_to_specify_PinnedDrop {} -//! impl< -//! T: ::pin_init::PinnedDrop, -//! > UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for T {} -//! impl<T> UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for Bar<T> {} -//! }; -//! ``` -//! -//! ## `pin_init!` in `impl Bar` -//! -//! This macro creates an pin-initializer for the given struct. It requires that the struct is -//! annotated by `#[pin_data]`. -//! -//! Here is the impl on `Bar` defining the new function: -//! -//! ```rust,ignore -//! impl<T> Bar<T> { -//! fn new(t: T) -> impl PinInit<Self> { -//! pin_init!(Self { t, x: 0 }) -//! } -//! } -//! ``` -//! -//! This expands to the following code: -//! -//! ```rust,ignore -//! impl<T> Bar<T> { -//! fn new(t: T) -> impl PinInit<Self> { -//! { -//! // We do not want to allow arbitrary returns, so we declare this type as the `Ok` -//! // return type and shadow it later when we insert the arbitrary user code. That way -//! // there will be no possibility of returning without `unsafe`. -//! struct __InitOk; -//! // Get the data about fields from the supplied type. -//! // - the function is unsafe, hence the unsafe block -//! // - we `use` the `HasPinData` trait in the block, it is only available in that -//! // scope. -//! let data = unsafe { -//! use ::pin_init::__internal::HasPinData; -//! Self::__pin_data() -//! }; -//! // Ensure that `data` really is of type `PinData` and help with type inference: -//! let init = ::pin_init::__internal::PinData::make_closure::< -//! _, -//! __InitOk, -//! ::core::convert::Infallible, -//! >(data, move |slot| { -//! { -//! // Shadow the structure so it cannot be used to return early. If a user -//! // tries to write `return Ok(__InitOk)`, then they get a type error, -//! // since that will refer to this struct instead of the one defined -//! // above. -//! struct __InitOk; -//! // This is the expansion of `t,`, which is syntactic sugar for `t: t,`. -//! { -//! unsafe { ::core::ptr::write(::core::addr_of_mut!((*slot).t), t) }; -//! } -//! // Since initialization could fail later (not in this case, since the -//! // error type is `Infallible`) we will need to drop this field if there -//! // is an error later. This `DropGuard` will drop the field when it gets -//! // dropped and has not yet been forgotten. -//! let __t_guard = unsafe { -//! ::pin_init::__internal::DropGuard::new(::core::addr_of_mut!((*slot).t)) -//! }; -//! // Expansion of `x: 0,`: -//! // Since this can be an arbitrary expression we cannot place it inside -//! // of the `unsafe` block, so we bind it here. -//! { -//! let x = 0; -//! unsafe { ::core::ptr::write(::core::addr_of_mut!((*slot).x), x) }; -//! } -//! // We again create a `DropGuard`. -//! let __x_guard = unsafe { -//! ::pin_init::__internal::DropGuard::new(::core::addr_of_mut!((*slot).x)) -//! }; -//! // Since initialization has successfully completed, we can now forget -//! // the guards. This is not `mem::forget`, since we only have -//! // `&DropGuard`. -//! ::core::mem::forget(__x_guard); -//! ::core::mem::forget(__t_guard); -//! // Here we use the type checker to ensure that every field has been -//! // initialized exactly once, since this is `if false` it will never get -//! // executed, but still type-checked. -//! // Additionally we abuse `slot` to automatically infer the correct type -//! // for the struct. This is also another check that every field is -//! // accessible from this scope. -//! #[allow(unreachable_code, clippy::diverging_sub_expression)] -//! let _ = || { -//! unsafe { -//! ::core::ptr::write( -//! slot, -//! Self { -//! // We only care about typecheck finding every field -//! // here, the expression does not matter, just conjure -//! // one using `panic!()`: -//! t: ::core::panic!(), -//! x: ::core::panic!(), -//! }, -//! ); -//! }; -//! }; -//! } -//! // We leave the scope above and gain access to the previously shadowed -//! // `__InitOk` that we need to return. -//! Ok(__InitOk) -//! }); -//! // Change the return type from `__InitOk` to `()`. -//! let init = move | -//! slot, -//! | -> ::core::result::Result<(), ::core::convert::Infallible> { -//! init(slot).map(|__InitOk| ()) -//! }; -//! // Construct the initializer. -//! let init = unsafe { -//! ::pin_init::pin_init_from_closure::< -//! _, -//! ::core::convert::Infallible, -//! >(init) -//! }; -//! init -//! } -//! } -//! } -//! ``` -//! -//! ## `#[pin_data]` on `Foo` -//! -//! Since we already took a look at `#[pin_data]` on `Bar`, this section will only explain the -//! differences/new things in the expansion of the `Foo` definition: -//! -//! ```rust,ignore -//! #[pin_data(PinnedDrop)] -//! struct Foo { -//! a: usize, -//! #[pin] -//! b: Bar<u32>, -//! } -//! ``` -//! -//! This expands to the following code: -//! -//! ```rust,ignore -//! struct Foo { -//! a: usize, -//! b: Bar<u32>, -//! } -//! const _: () = { -//! struct __ThePinData { -//! __phantom: ::core::marker::PhantomData<fn(Foo) -> Foo>, -//! } -//! impl ::core::clone::Clone for __ThePinData { -//! fn clone(&self) -> Self { -//! *self -//! } -//! } -//! impl ::core::marker::Copy for __ThePinData {} -//! #[allow(dead_code)] -//! impl __ThePinData { -//! unsafe fn b<E>( -//! self, -//! slot: *mut Bar<u32>, -//! init: impl ::pin_init::PinInit<Bar<u32>, E>, -//! ) -> ::core::result::Result<(), E> { -//! unsafe { ::pin_init::PinInit::__pinned_init(init, slot) } -//! } -//! unsafe fn a<E>( -//! self, -//! slot: *mut usize, -//! init: impl ::pin_init::Init<usize, E>, -//! ) -> ::core::result::Result<(), E> { -//! unsafe { ::pin_init::Init::__init(init, slot) } -//! } -//! } -//! unsafe impl ::pin_init::__internal::HasPinData for Foo { -//! type PinData = __ThePinData; -//! unsafe fn __pin_data() -> Self::PinData { -//! __ThePinData { -//! __phantom: ::core::marker::PhantomData, -//! } -//! } -//! } -//! unsafe impl ::pin_init::__internal::PinData for __ThePinData { -//! type Datee = Foo; -//! } -//! #[allow(dead_code)] -//! struct __Unpin<'__pin> { -//! __phantom_pin: ::core::marker::PhantomData<fn(&'__pin ()) -> &'__pin ()>, -//! __phantom: ::core::marker::PhantomData<fn(Foo) -> Foo>, -//! b: Bar<u32>, -//! } -//! #[doc(hidden)] -//! impl<'__pin> ::core::marker::Unpin for Foo -//! where -//! __Unpin<'__pin>: ::core::marker::Unpin, -//! {} -//! // Since we specified `PinnedDrop` as the argument to `#[pin_data]`, we expect `Foo` to -//! // implement `PinnedDrop`. Thus we do not need to prevent `Drop` implementations like -//! // before, instead we implement `Drop` here and delegate to `PinnedDrop`. -//! impl ::core::ops::Drop for Foo { -//! fn drop(&mut self) { -//! // Since we are getting dropped, no one else has a reference to `self` and thus we -//! // can assume that we never move. -//! let pinned = unsafe { ::core::pin::Pin::new_unchecked(self) }; -//! // Create the unsafe token that proves that we are inside of a destructor, this -//! // type is only allowed to be created in a destructor. -//! let token = unsafe { ::pin_init::__internal::OnlyCallFromDrop::new() }; -//! ::pin_init::PinnedDrop::drop(pinned, token); -//! } -//! } -//! }; -//! ``` -//! -//! ## `#[pinned_drop]` on `impl PinnedDrop for Foo` -//! -//! This macro is used to implement the `PinnedDrop` trait, since that trait is `unsafe` and has an -//! extra parameter that should not be used at all. The macro hides that parameter. -//! -//! Here is the `PinnedDrop` impl for `Foo`: -//! -//! ```rust,ignore -//! #[pinned_drop] -//! impl PinnedDrop for Foo { -//! fn drop(self: Pin<&mut Self>) { -//! println!("{self:p} is getting dropped."); -//! } -//! } -//! ``` -//! -//! This expands to the following code: -//! -//! ```rust,ignore -//! // `unsafe`, full path and the token parameter are added, everything else stays the same. -//! unsafe impl ::pin_init::PinnedDrop for Foo { -//! fn drop(self: Pin<&mut Self>, _: ::pin_init::__internal::OnlyCallFromDrop) { -//! println!("{self:p} is getting dropped."); -//! } -//! } -//! ``` -//! -//! ## `pin_init!` on `Foo` -//! -//! Since we already took a look at `pin_init!` on `Bar`, this section will only show the expansion -//! of `pin_init!` on `Foo`: -//! -//! ```rust,ignore -//! let a = 42; -//! let initializer = pin_init!(Foo { -//! a, -//! b <- Bar::new(36), -//! }); -//! ``` -//! -//! This expands to the following code: -//! -//! ```rust,ignore -//! let a = 42; -//! let initializer = { -//! struct __InitOk; -//! let data = unsafe { -//! use ::pin_init::__internal::HasPinData; -//! Foo::__pin_data() -//! }; -//! let init = ::pin_init::__internal::PinData::make_closure::< -//! _, -//! __InitOk, -//! ::core::convert::Infallible, -//! >(data, move |slot| { -//! { -//! struct __InitOk; -//! { -//! unsafe { ::core::ptr::write(::core::addr_of_mut!((*slot).a), a) }; -//! } -//! let __a_guard = unsafe { -//! ::pin_init::__internal::DropGuard::new(::core::addr_of_mut!((*slot).a)) -//! }; -//! let init = Bar::new(36); -//! unsafe { data.b(::core::addr_of_mut!((*slot).b), b)? }; -//! let __b_guard = unsafe { -//! ::pin_init::__internal::DropGuard::new(::core::addr_of_mut!((*slot).b)) -//! }; -//! ::core::mem::forget(__b_guard); -//! ::core::mem::forget(__a_guard); -//! #[allow(unreachable_code, clippy::diverging_sub_expression)] -//! let _ = || { -//! unsafe { -//! ::core::ptr::write( -//! slot, -//! Foo { -//! a: ::core::panic!(), -//! b: ::core::panic!(), -//! }, -//! ); -//! }; -//! }; -//! } -//! Ok(__InitOk) -//! }); -//! let init = move | -//! slot, -//! | -> ::core::result::Result<(), ::core::convert::Infallible> { -//! init(slot).map(|__InitOk| ()) -//! }; -//! let init = unsafe { -//! ::pin_init::pin_init_from_closure::<_, ::core::convert::Infallible>(init) -//! }; -//! init -//! }; -//! ``` - -#[cfg(kernel)] -pub use ::macros::paste; -#[cfg(not(kernel))] -pub use ::paste::paste; - -/// Creates a `unsafe impl<...> PinnedDrop for $type` block. -/// -/// See [`PinnedDrop`] for more information. -/// -/// [`PinnedDrop`]: crate::PinnedDrop -#[doc(hidden)] -#[macro_export] -macro_rules! __pinned_drop { - ( - @impl_sig($($impl_sig:tt)*), - @impl_body( - $(#[$($attr:tt)*])* - fn drop($($sig:tt)*) { - $($inner:tt)* - } - ), - ) => { - // SAFETY: TODO. - unsafe $($impl_sig)* { - // Inherit all attributes and the type/ident tokens for the signature. - $(#[$($attr)*])* - fn drop($($sig)*, _: $crate::__internal::OnlyCallFromDrop) { - $($inner)* - } - } - } -} - -/// This macro first parses the struct definition such that it separates pinned and not pinned -/// fields. Afterwards it declares the struct and implement the `PinData` trait safely. -#[doc(hidden)] -#[macro_export] -macro_rules! __pin_data { - // Proc-macro entry point, this is supplied by the proc-macro pre-parsing. - (parse_input: - @args($($pinned_drop:ident)?), - @sig( - $(#[$($struct_attr:tt)*])* - $vis:vis struct $name:ident - $(where $($whr:tt)*)? - ), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @decl_generics($($decl_generics:tt)*), - @body({ $($fields:tt)* }), - ) => { - // We now use token munching to iterate through all of the fields. While doing this we - // identify fields marked with `#[pin]`, these fields are the 'pinned fields'. The user - // wants these to be structurally pinned. The rest of the fields are the - // 'not pinned fields'. Additionally we collect all fields, since we need them in the right - // order to declare the struct. - // - // In this call we also put some explaining comments for the parameters. - $crate::__pin_data!(find_pinned_fields: - // Attributes on the struct itself, these will just be propagated to be put onto the - // struct definition. - @struct_attrs($(#[$($struct_attr)*])*), - // The visibility of the struct. - @vis($vis), - // The name of the struct. - @name($name), - // The 'impl generics', the generics that will need to be specified on the struct inside - // of an `impl<$ty_generics>` block. - @impl_generics($($impl_generics)*), - // The 'ty generics', the generics that will need to be specified on the impl blocks. - @ty_generics($($ty_generics)*), - // The 'decl generics', the generics that need to be specified on the struct - // definition. - @decl_generics($($decl_generics)*), - // The where clause of any impl block and the declaration. - @where($($($whr)*)?), - // The remaining fields tokens that need to be processed. - // We add a `,` at the end to ensure correct parsing. - @fields_munch($($fields)* ,), - // The pinned fields. - @pinned(), - // The not pinned fields. - @not_pinned(), - // All fields. - @fields(), - // The accumulator containing all attributes already parsed. - @accum(), - // Contains `yes` or `` to indicate if `#[pin]` was found on the current field. - @is_pinned(), - // The proc-macro argument, this should be `PinnedDrop` or ``. - @pinned_drop($($pinned_drop)?), - ); - }; - (find_pinned_fields: - @struct_attrs($($struct_attrs:tt)*), - @vis($vis:vis), - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @decl_generics($($decl_generics:tt)*), - @where($($whr:tt)*), - // We found a PhantomPinned field, this should generally be pinned! - @fields_munch($field:ident : $($($(::)?core::)?marker::)?PhantomPinned, $($rest:tt)*), - @pinned($($pinned:tt)*), - @not_pinned($($not_pinned:tt)*), - @fields($($fields:tt)*), - @accum($($accum:tt)*), - // This field is not pinned. - @is_pinned(), - @pinned_drop($($pinned_drop:ident)?), - ) => { - ::core::compile_error!(concat!( - "The field `", - stringify!($field), - "` of type `PhantomPinned` only has an effect, if it has the `#[pin]` attribute.", - )); - $crate::__pin_data!(find_pinned_fields: - @struct_attrs($($struct_attrs)*), - @vis($vis), - @name($name), - @impl_generics($($impl_generics)*), - @ty_generics($($ty_generics)*), - @decl_generics($($decl_generics)*), - @where($($whr)*), - @fields_munch($($rest)*), - @pinned($($pinned)* $($accum)* $field: ::core::marker::PhantomPinned,), - @not_pinned($($not_pinned)*), - @fields($($fields)* $($accum)* $field: ::core::marker::PhantomPinned,), - @accum(), - @is_pinned(), - @pinned_drop($($pinned_drop)?), - ); - }; - (find_pinned_fields: - @struct_attrs($($struct_attrs:tt)*), - @vis($vis:vis), - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @decl_generics($($decl_generics:tt)*), - @where($($whr:tt)*), - // We reached the field declaration. - @fields_munch($field:ident : $type:ty, $($rest:tt)*), - @pinned($($pinned:tt)*), - @not_pinned($($not_pinned:tt)*), - @fields($($fields:tt)*), - @accum($($accum:tt)*), - // This field is pinned. - @is_pinned(yes), - @pinned_drop($($pinned_drop:ident)?), - ) => { - $crate::__pin_data!(find_pinned_fields: - @struct_attrs($($struct_attrs)*), - @vis($vis), - @name($name), - @impl_generics($($impl_generics)*), - @ty_generics($($ty_generics)*), - @decl_generics($($decl_generics)*), - @where($($whr)*), - @fields_munch($($rest)*), - @pinned($($pinned)* $($accum)* $field: $type,), - @not_pinned($($not_pinned)*), - @fields($($fields)* $($accum)* $field: $type,), - @accum(), - @is_pinned(), - @pinned_drop($($pinned_drop)?), - ); - }; - (find_pinned_fields: - @struct_attrs($($struct_attrs:tt)*), - @vis($vis:vis), - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @decl_generics($($decl_generics:tt)*), - @where($($whr:tt)*), - // We reached the field declaration. - @fields_munch($field:ident : $type:ty, $($rest:tt)*), - @pinned($($pinned:tt)*), - @not_pinned($($not_pinned:tt)*), - @fields($($fields:tt)*), - @accum($($accum:tt)*), - // This field is not pinned. - @is_pinned(), - @pinned_drop($($pinned_drop:ident)?), - ) => { - $crate::__pin_data!(find_pinned_fields: - @struct_attrs($($struct_attrs)*), - @vis($vis), - @name($name), - @impl_generics($($impl_generics)*), - @ty_generics($($ty_generics)*), - @decl_generics($($decl_generics)*), - @where($($whr)*), - @fields_munch($($rest)*), - @pinned($($pinned)*), - @not_pinned($($not_pinned)* $($accum)* $field: $type,), - @fields($($fields)* $($accum)* $field: $type,), - @accum(), - @is_pinned(), - @pinned_drop($($pinned_drop)?), - ); - }; - (find_pinned_fields: - @struct_attrs($($struct_attrs:tt)*), - @vis($vis:vis), - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @decl_generics($($decl_generics:tt)*), - @where($($whr:tt)*), - // We found the `#[pin]` attr. - @fields_munch(#[pin] $($rest:tt)*), - @pinned($($pinned:tt)*), - @not_pinned($($not_pinned:tt)*), - @fields($($fields:tt)*), - @accum($($accum:tt)*), - @is_pinned($($is_pinned:ident)?), - @pinned_drop($($pinned_drop:ident)?), - ) => { - $crate::__pin_data!(find_pinned_fields: - @struct_attrs($($struct_attrs)*), - @vis($vis), - @name($name), - @impl_generics($($impl_generics)*), - @ty_generics($($ty_generics)*), - @decl_generics($($decl_generics)*), - @where($($whr)*), - @fields_munch($($rest)*), - // We do not include `#[pin]` in the list of attributes, since it is not actually an - // attribute that is defined somewhere. - @pinned($($pinned)*), - @not_pinned($($not_pinned)*), - @fields($($fields)*), - @accum($($accum)*), - // Set this to `yes`. - @is_pinned(yes), - @pinned_drop($($pinned_drop)?), - ); - }; - (find_pinned_fields: - @struct_attrs($($struct_attrs:tt)*), - @vis($vis:vis), - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @decl_generics($($decl_generics:tt)*), - @where($($whr:tt)*), - // We reached the field declaration with visibility, for simplicity we only munch the - // visibility and put it into `$accum`. - @fields_munch($fvis:vis $field:ident $($rest:tt)*), - @pinned($($pinned:tt)*), - @not_pinned($($not_pinned:tt)*), - @fields($($fields:tt)*), - @accum($($accum:tt)*), - @is_pinned($($is_pinned:ident)?), - @pinned_drop($($pinned_drop:ident)?), - ) => { - $crate::__pin_data!(find_pinned_fields: - @struct_attrs($($struct_attrs)*), - @vis($vis), - @name($name), - @impl_generics($($impl_generics)*), - @ty_generics($($ty_generics)*), - @decl_generics($($decl_generics)*), - @where($($whr)*), - @fields_munch($field $($rest)*), - @pinned($($pinned)*), - @not_pinned($($not_pinned)*), - @fields($($fields)*), - @accum($($accum)* $fvis), - @is_pinned($($is_pinned)?), - @pinned_drop($($pinned_drop)?), - ); - }; - (find_pinned_fields: - @struct_attrs($($struct_attrs:tt)*), - @vis($vis:vis), - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @decl_generics($($decl_generics:tt)*), - @where($($whr:tt)*), - // Some other attribute, just put it into `$accum`. - @fields_munch(#[$($attr:tt)*] $($rest:tt)*), - @pinned($($pinned:tt)*), - @not_pinned($($not_pinned:tt)*), - @fields($($fields:tt)*), - @accum($($accum:tt)*), - @is_pinned($($is_pinned:ident)?), - @pinned_drop($($pinned_drop:ident)?), - ) => { - $crate::__pin_data!(find_pinned_fields: - @struct_attrs($($struct_attrs)*), - @vis($vis), - @name($name), - @impl_generics($($impl_generics)*), - @ty_generics($($ty_generics)*), - @decl_generics($($decl_generics)*), - @where($($whr)*), - @fields_munch($($rest)*), - @pinned($($pinned)*), - @not_pinned($($not_pinned)*), - @fields($($fields)*), - @accum($($accum)* #[$($attr)*]), - @is_pinned($($is_pinned)?), - @pinned_drop($($pinned_drop)?), - ); - }; - (find_pinned_fields: - @struct_attrs($($struct_attrs:tt)*), - @vis($vis:vis), - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @decl_generics($($decl_generics:tt)*), - @where($($whr:tt)*), - // We reached the end of the fields, plus an optional additional comma, since we added one - // before and the user is also allowed to put a trailing comma. - @fields_munch($(,)?), - @pinned($($pinned:tt)*), - @not_pinned($($not_pinned:tt)*), - @fields($($fields:tt)*), - @accum(), - @is_pinned(), - @pinned_drop($($pinned_drop:ident)?), - ) => { - // Declare the struct with all fields in the correct order. - $($struct_attrs)* - $vis struct $name <$($decl_generics)*> - where $($whr)* - { - $($fields)* - } - - $crate::__pin_data!(make_pin_projections: - @vis($vis), - @name($name), - @impl_generics($($impl_generics)*), - @ty_generics($($ty_generics)*), - @decl_generics($($decl_generics)*), - @where($($whr)*), - @pinned($($pinned)*), - @not_pinned($($not_pinned)*), - ); - - // We put the rest into this const item, because it then will not be accessible to anything - // outside. - const _: () = { - // We declare this struct which will host all of the projection function for our type. - // it will be invariant over all generic parameters which are inherited from the - // struct. - $vis struct __ThePinData<$($impl_generics)*> - where $($whr)* - { - __phantom: ::core::marker::PhantomData< - fn($name<$($ty_generics)*>) -> $name<$($ty_generics)*> - >, - } - - impl<$($impl_generics)*> ::core::clone::Clone for __ThePinData<$($ty_generics)*> - where $($whr)* - { - fn clone(&self) -> Self { *self } - } - - impl<$($impl_generics)*> ::core::marker::Copy for __ThePinData<$($ty_generics)*> - where $($whr)* - {} - - // Make all projection functions. - $crate::__pin_data!(make_pin_data: - @pin_data(__ThePinData), - @impl_generics($($impl_generics)*), - @ty_generics($($ty_generics)*), - @where($($whr)*), - @pinned($($pinned)*), - @not_pinned($($not_pinned)*), - ); - - // SAFETY: We have added the correct projection functions above to `__ThePinData` and - // we also use the least restrictive generics possible. - unsafe impl<$($impl_generics)*> - $crate::__internal::HasPinData for $name<$($ty_generics)*> - where $($whr)* - { - type PinData = __ThePinData<$($ty_generics)*>; - - unsafe fn __pin_data() -> Self::PinData { - __ThePinData { __phantom: ::core::marker::PhantomData } - } - } - - // SAFETY: TODO. - unsafe impl<$($impl_generics)*> - $crate::__internal::PinData for __ThePinData<$($ty_generics)*> - where $($whr)* - { - type Datee = $name<$($ty_generics)*>; - } - - // This struct will be used for the unpin analysis. Since only structurally pinned - // fields are relevant whether the struct should implement `Unpin`. - #[allow(dead_code)] - struct __Unpin <'__pin, $($impl_generics)*> - where $($whr)* - { - __phantom_pin: ::core::marker::PhantomData<fn(&'__pin ()) -> &'__pin ()>, - __phantom: ::core::marker::PhantomData< - fn($name<$($ty_generics)*>) -> $name<$($ty_generics)*> - >, - // Only the pinned fields. - $($pinned)* - } - - #[doc(hidden)] - impl<'__pin, $($impl_generics)*> ::core::marker::Unpin for $name<$($ty_generics)*> - where - __Unpin<'__pin, $($ty_generics)*>: ::core::marker::Unpin, - $($whr)* - {} - - // We need to disallow normal `Drop` implementation, the exact behavior depends on - // whether `PinnedDrop` was specified as the parameter. - $crate::__pin_data!(drop_prevention: - @name($name), - @impl_generics($($impl_generics)*), - @ty_generics($($ty_generics)*), - @where($($whr)*), - @pinned_drop($($pinned_drop)?), - ); - }; - }; - // When no `PinnedDrop` was specified, then we have to prevent implementing drop. - (drop_prevention: - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @where($($whr:tt)*), - @pinned_drop(), - ) => { - // We prevent this by creating a trait that will be implemented for all types implementing - // `Drop`. Additionally we will implement this trait for the struct leading to a conflict, - // if it also implements `Drop` - trait MustNotImplDrop {} - #[expect(drop_bounds)] - impl<T: ::core::ops::Drop> MustNotImplDrop for T {} - impl<$($impl_generics)*> MustNotImplDrop for $name<$($ty_generics)*> - where $($whr)* {} - // We also take care to prevent users from writing a useless `PinnedDrop` implementation. - // They might implement `PinnedDrop` correctly for the struct, but forget to give - // `PinnedDrop` as the parameter to `#[pin_data]`. - #[expect(non_camel_case_types)] - trait UselessPinnedDropImpl_you_need_to_specify_PinnedDrop {} - impl<T: $crate::PinnedDrop> - UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for T {} - impl<$($impl_generics)*> - UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for $name<$($ty_generics)*> - where $($whr)* {} - }; - // When `PinnedDrop` was specified we just implement `Drop` and delegate. - (drop_prevention: - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @where($($whr:tt)*), - @pinned_drop(PinnedDrop), - ) => { - impl<$($impl_generics)*> ::core::ops::Drop for $name<$($ty_generics)*> - where $($whr)* - { - fn drop(&mut self) { - // SAFETY: Since this is a destructor, `self` will not move after this function - // terminates, since it is inaccessible. - let pinned = unsafe { ::core::pin::Pin::new_unchecked(self) }; - // SAFETY: Since this is a drop function, we can create this token to call the - // pinned destructor of this type. - let token = unsafe { $crate::__internal::OnlyCallFromDrop::new() }; - $crate::PinnedDrop::drop(pinned, token); - } - } - }; - // If some other parameter was specified, we emit a readable error. - (drop_prevention: - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @where($($whr:tt)*), - @pinned_drop($($rest:tt)*), - ) => { - compile_error!( - "Wrong parameters to `#[pin_data]`, expected nothing or `PinnedDrop`, got '{}'.", - stringify!($($rest)*), - ); - }; - (make_pin_projections: - @vis($vis:vis), - @name($name:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @decl_generics($($decl_generics:tt)*), - @where($($whr:tt)*), - @pinned($($(#[$($p_attr:tt)*])* $pvis:vis $p_field:ident : $p_type:ty),* $(,)?), - @not_pinned($($(#[$($attr:tt)*])* $fvis:vis $field:ident : $type:ty),* $(,)?), - ) => { - $crate::macros::paste! { - #[doc(hidden)] - $vis struct [< $name Projection >] <'__pin, $($decl_generics)*> { - $($(#[$($p_attr)*])* $pvis $p_field : ::core::pin::Pin<&'__pin mut $p_type>,)* - $($(#[$($attr)*])* $fvis $field : &'__pin mut $type,)* - ___pin_phantom_data: ::core::marker::PhantomData<&'__pin mut ()>, - } - - impl<$($impl_generics)*> $name<$($ty_generics)*> - where $($whr)* - { - /// Pin-projects all fields of `Self`. - /// - /// These fields are structurally pinned: - $(#[doc = ::core::concat!(" - `", ::core::stringify!($p_field), "`")])* - /// - /// These fields are **not** structurally pinned: - $(#[doc = ::core::concat!(" - `", ::core::stringify!($field), "`")])* - #[inline] - $vis fn project<'__pin>( - self: ::core::pin::Pin<&'__pin mut Self>, - ) -> [< $name Projection >] <'__pin, $($ty_generics)*> { - // SAFETY: we only give access to `&mut` for fields not structurally pinned. - let this = unsafe { ::core::pin::Pin::get_unchecked_mut(self) }; - [< $name Projection >] { - $( - // SAFETY: `$p_field` is structurally pinned. - $(#[$($p_attr)*])* - $p_field : unsafe { ::core::pin::Pin::new_unchecked(&mut this.$p_field) }, - )* - $( - $(#[$($attr)*])* - $field : &mut this.$field, - )* - ___pin_phantom_data: ::core::marker::PhantomData, - } - } - } - } - }; - (make_pin_data: - @pin_data($pin_data:ident), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @where($($whr:tt)*), - @pinned($($(#[$($p_attr:tt)*])* $pvis:vis $p_field:ident : $p_type:ty),* $(,)?), - @not_pinned($($(#[$($attr:tt)*])* $fvis:vis $field:ident : $type:ty),* $(,)?), - ) => { - $crate::macros::paste! { - // For every field, we create a projection function according to its projection type. If a - // field is structurally pinned, then it must be initialized via `PinInit`, if it is not - // structurally pinned, then it can be initialized via `Init`. - // - // The functions are `unsafe` to prevent accidentally calling them. - #[allow(dead_code)] - #[expect(clippy::missing_safety_doc)] - impl<$($impl_generics)*> $pin_data<$($ty_generics)*> - where $($whr)* - { - $( - $(#[$($p_attr)*])* - $pvis unsafe fn $p_field<E>( - self, - slot: *mut $p_type, - init: impl $crate::PinInit<$p_type, E>, - ) -> ::core::result::Result<(), E> { - // SAFETY: TODO. - unsafe { $crate::PinInit::__pinned_init(init, slot) } - } - - $(#[$($p_attr)*])* - $pvis unsafe fn [<__project_ $p_field>]<'__slot>( - self, - slot: &'__slot mut $p_type, - ) -> ::core::pin::Pin<&'__slot mut $p_type> { - ::core::pin::Pin::new_unchecked(slot) - } - )* - $( - $(#[$($attr)*])* - $fvis unsafe fn $field<E>( - self, - slot: *mut $type, - init: impl $crate::Init<$type, E>, - ) -> ::core::result::Result<(), E> { - // SAFETY: TODO. - unsafe { $crate::Init::__init(init, slot) } - } - - $(#[$($attr)*])* - $fvis unsafe fn [<__project_ $field>]<'__slot>( - self, - slot: &'__slot mut $type, - ) -> &'__slot mut $type { - slot - } - )* - } - } - }; -} - -/// The internal init macro. Do not call manually! -/// -/// This is called by the `{try_}{pin_}init!` macros with various inputs. -/// -/// This macro has multiple internal call configurations, these are always the very first ident: -/// - nothing: this is the base case and called by the `{try_}{pin_}init!` macros. -/// - `with_update_parsed`: when the `..Zeroable::init_zeroed()` syntax has been handled. -/// - `init_slot`: recursively creates the code that initializes all fields in `slot`. -/// - `make_initializer`: recursively create the struct initializer that guarantees that every -/// field has been initialized exactly once. -#[doc(hidden)] -#[macro_export] -macro_rules! __init_internal { - ( - @this($($this:ident)?), - @typ($t:path), - @fields($($fields:tt)*), - @error($err:ty), - // Either `PinData` or `InitData`, `$use_data` should only be present in the `PinData` - // case. - @data($data:ident, $($use_data:ident)?), - // `HasPinData` or `HasInitData`. - @has_data($has_data:ident, $get_data:ident), - // `pin_init_from_closure` or `init_from_closure`. - @construct_closure($construct_closure:ident), - @munch_fields(), - ) => { - $crate::__init_internal!(with_update_parsed: - @this($($this)?), - @typ($t), - @fields($($fields)*), - @error($err), - @data($data, $($use_data)?), - @has_data($has_data, $get_data), - @construct_closure($construct_closure), - @init_zeroed(), // Nothing means default behavior. - ) - }; - ( - @this($($this:ident)?), - @typ($t:path), - @fields($($fields:tt)*), - @error($err:ty), - // Either `PinData` or `InitData`, `$use_data` should only be present in the `PinData` - // case. - @data($data:ident, $($use_data:ident)?), - // `HasPinData` or `HasInitData`. - @has_data($has_data:ident, $get_data:ident), - // `pin_init_from_closure` or `init_from_closure`. - @construct_closure($construct_closure:ident), - @munch_fields(..Zeroable::init_zeroed()), - ) => { - $crate::__init_internal!(with_update_parsed: - @this($($this)?), - @typ($t), - @fields($($fields)*), - @error($err), - @data($data, $($use_data)?), - @has_data($has_data, $get_data), - @construct_closure($construct_closure), - @init_zeroed(()), // `()` means zero all fields not mentioned. - ) - }; - ( - @this($($this:ident)?), - @typ($t:path), - @fields($($fields:tt)*), - @error($err:ty), - // Either `PinData` or `InitData`, `$use_data` should only be present in the `PinData` - // case. - @data($data:ident, $($use_data:ident)?), - // `HasPinData` or `HasInitData`. - @has_data($has_data:ident, $get_data:ident), - // `pin_init_from_closure` or `init_from_closure`. - @construct_closure($construct_closure:ident), - @munch_fields($ignore:tt $($rest:tt)*), - ) => { - $crate::__init_internal!( - @this($($this)?), - @typ($t), - @fields($($fields)*), - @error($err), - @data($data, $($use_data)?), - @has_data($has_data, $get_data), - @construct_closure($construct_closure), - @munch_fields($($rest)*), - ) - }; - (with_update_parsed: - @this($($this:ident)?), - @typ($t:path), - @fields($($fields:tt)*), - @error($err:ty), - // Either `PinData` or `InitData`, `$use_data` should only be present in the `PinData` - // case. - @data($data:ident, $($use_data:ident)?), - // `HasPinData` or `HasInitData`. - @has_data($has_data:ident, $get_data:ident), - // `pin_init_from_closure` or `init_from_closure`. - @construct_closure($construct_closure:ident), - @init_zeroed($($init_zeroed:expr)?), - ) => {{ - // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return - // type and shadow it later when we insert the arbitrary user code. That way there will be - // no possibility of returning without `unsafe`. - struct __InitOk; - // Get the data about fields from the supplied type. - // - // SAFETY: TODO. - let data = unsafe { - use $crate::__internal::$has_data; - // Here we abuse `paste!` to retokenize `$t`. Declarative macros have some internal - // information that is associated to already parsed fragments, so a path fragment - // cannot be used in this position. Doing the retokenization results in valid rust - // code. - $crate::macros::paste!($t::$get_data()) - }; - // Ensure that `data` really is of type `$data` and help with type inference: - let init = $crate::__internal::$data::make_closure::<_, __InitOk, $err>( - data, - move |slot| { - { - // Shadow the structure so it cannot be used to return early. - struct __InitOk; - // If `$init_zeroed` is present we should zero the slot now and not emit an - // error when fields are missing (since they will be zeroed). We also have to - // check that the type actually implements `Zeroable`. - $({ - fn assert_zeroable<T: $crate::Zeroable>(_: *mut T) {} - // Ensure that the struct is indeed `Zeroable`. - assert_zeroable(slot); - // SAFETY: The type implements `Zeroable` by the check above. - unsafe { ::core::ptr::write_bytes(slot, 0, 1) }; - $init_zeroed // This will be `()` if set. - })? - // Create the `this` so it can be referenced by the user inside of the - // expressions creating the individual fields. - $(let $this = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };)? - // Initialize every field. - $crate::__init_internal!(init_slot($($use_data)?): - @data(data), - @slot(slot), - @guards(), - @munch_fields($($fields)*,), - ); - // We use unreachable code to ensure that all fields have been mentioned exactly - // once, this struct initializer will still be type-checked and complain with a - // very natural error message if a field is forgotten/mentioned more than once. - #[allow(unreachable_code, clippy::diverging_sub_expression)] - let _ = || { - $crate::__init_internal!(make_initializer: - @slot(slot), - @type_name($t), - @munch_fields($($fields)*,), - @acc(), - ); - }; - } - Ok(__InitOk) - } - ); - let init = move |slot| -> ::core::result::Result<(), $err> { - init(slot).map(|__InitOk| ()) - }; - // SAFETY: TODO. - let init = unsafe { $crate::$construct_closure::<_, $err>(init) }; - init - }}; - (init_slot($($use_data:ident)?): - @data($data:ident), - @slot($slot:ident), - @guards($($guards:ident,)*), - @munch_fields($(..Zeroable::init_zeroed())? $(,)?), - ) => { - // Endpoint of munching, no fields are left. If execution reaches this point, all fields - // have been initialized. Therefore we can now dismiss the guards by forgetting them. - $(::core::mem::forget($guards);)* - }; - (init_slot($($use_data:ident)?): - @data($data:ident), - @slot($slot:ident), - @guards($($guards:ident,)*), - // arbitrary code block - @munch_fields(_: { $($code:tt)* }, $($rest:tt)*), - ) => { - { $($code)* } - $crate::__init_internal!(init_slot($($use_data)?): - @data($data), - @slot($slot), - @guards($($guards,)*), - @munch_fields($($rest)*), - ); - }; - (init_slot($use_data:ident): // `use_data` is present, so we use the `data` to init fields. - @data($data:ident), - @slot($slot:ident), - @guards($($guards:ident,)*), - // In-place initialization syntax. - @munch_fields($field:ident <- $val:expr, $($rest:tt)*), - ) => { - let init = $val; - // Call the initializer. - // - // SAFETY: `slot` is valid, because we are inside of an initializer closure, we - // return when an error/panic occurs. - // We also use the `data` to require the correct trait (`Init` or `PinInit`) for `$field`. - unsafe { $data.$field(::core::ptr::addr_of_mut!((*$slot).$field), init)? }; - // SAFETY: - // - the project function does the correct field projection, - // - the field has been initialized, - // - the reference is only valid until the end of the initializer. - #[allow(unused_variables)] - let $field = $crate::macros::paste!(unsafe { $data.[< __project_ $field >](&mut (*$slot).$field) }); - - // Create the drop guard: - // - // We rely on macro hygiene to make it impossible for users to access this local variable. - // We use `paste!` to create new hygiene for `$field`. - $crate::macros::paste! { - // SAFETY: We forget the guard later when initialization has succeeded. - let [< __ $field _guard >] = unsafe { - $crate::__internal::DropGuard::new(::core::ptr::addr_of_mut!((*$slot).$field)) - }; - - $crate::__init_internal!(init_slot($use_data): - @data($data), - @slot($slot), - @guards([< __ $field _guard >], $($guards,)*), - @munch_fields($($rest)*), - ); - } - }; - (init_slot(): // No `use_data`, so we use `Init::__init` directly. - @data($data:ident), - @slot($slot:ident), - @guards($($guards:ident,)*), - // In-place initialization syntax. - @munch_fields($field:ident <- $val:expr, $($rest:tt)*), - ) => { - let init = $val; - // Call the initializer. - // - // SAFETY: `slot` is valid, because we are inside of an initializer closure, we - // return when an error/panic occurs. - unsafe { $crate::Init::__init(init, ::core::ptr::addr_of_mut!((*$slot).$field))? }; - - // SAFETY: - // - the field is not structurally pinned, since the line above must compile, - // - the field has been initialized, - // - the reference is only valid until the end of the initializer. - #[allow(unused_variables)] - let $field = unsafe { &mut (*$slot).$field }; - - // Create the drop guard: - // - // We rely on macro hygiene to make it impossible for users to access this local variable. - // We use `paste!` to create new hygiene for `$field`. - $crate::macros::paste! { - // SAFETY: We forget the guard later when initialization has succeeded. - let [< __ $field _guard >] = unsafe { - $crate::__internal::DropGuard::new(::core::ptr::addr_of_mut!((*$slot).$field)) - }; - - $crate::__init_internal!(init_slot(): - @data($data), - @slot($slot), - @guards([< __ $field _guard >], $($guards,)*), - @munch_fields($($rest)*), - ); - } - }; - (init_slot(): // No `use_data`, so all fields are not structurally pinned - @data($data:ident), - @slot($slot:ident), - @guards($($guards:ident,)*), - // Init by-value. - @munch_fields($field:ident $(: $val:expr)?, $($rest:tt)*), - ) => { - { - $(let $field = $val;)? - // Initialize the field. - // - // SAFETY: The memory at `slot` is uninitialized. - unsafe { ::core::ptr::write(::core::ptr::addr_of_mut!((*$slot).$field), $field) }; - } - - #[allow(unused_variables)] - // SAFETY: - // - the field is not structurally pinned, since no `use_data` was required to create this - // initializer, - // - the field has been initialized, - // - the reference is only valid until the end of the initializer. - let $field = unsafe { &mut (*$slot).$field }; - - // Create the drop guard: - // - // We rely on macro hygiene to make it impossible for users to access this local variable. - // We use `paste!` to create new hygiene for `$field`. - $crate::macros::paste! { - // SAFETY: We forget the guard later when initialization has succeeded. - let [< __ $field _guard >] = unsafe { - $crate::__internal::DropGuard::new(::core::ptr::addr_of_mut!((*$slot).$field)) - }; - - $crate::__init_internal!(init_slot(): - @data($data), - @slot($slot), - @guards([< __ $field _guard >], $($guards,)*), - @munch_fields($($rest)*), - ); - } - }; - (init_slot($use_data:ident): - @data($data:ident), - @slot($slot:ident), - @guards($($guards:ident,)*), - // Init by-value. - @munch_fields($field:ident $(: $val:expr)?, $($rest:tt)*), - ) => { - { - $(let $field = $val;)? - // Initialize the field. - // - // SAFETY: The memory at `slot` is uninitialized. - unsafe { ::core::ptr::write(::core::ptr::addr_of_mut!((*$slot).$field), $field) }; - } - // SAFETY: - // - the project function does the correct field projection, - // - the field has been initialized, - // - the reference is only valid until the end of the initializer. - #[allow(unused_variables)] - let $field = $crate::macros::paste!(unsafe { $data.[< __project_ $field >](&mut (*$slot).$field) }); - - // Create the drop guard: - // - // We rely on macro hygiene to make it impossible for users to access this local variable. - // We use `paste!` to create new hygiene for `$field`. - $crate::macros::paste! { - // SAFETY: We forget the guard later when initialization has succeeded. - let [< __ $field _guard >] = unsafe { - $crate::__internal::DropGuard::new(::core::ptr::addr_of_mut!((*$slot).$field)) - }; - - $crate::__init_internal!(init_slot($use_data): - @data($data), - @slot($slot), - @guards([< __ $field _guard >], $($guards,)*), - @munch_fields($($rest)*), - ); - } - }; - (make_initializer: - @slot($slot:ident), - @type_name($t:path), - @munch_fields(_: { $($code:tt)* }, $($rest:tt)*), - @acc($($acc:tt)*), - ) => { - // code blocks are ignored for the initializer check - $crate::__init_internal!(make_initializer: - @slot($slot), - @type_name($t), - @munch_fields($($rest)*), - @acc($($acc)*), - ); - }; - (make_initializer: - @slot($slot:ident), - @type_name($t:path), - @munch_fields(..Zeroable::init_zeroed() $(,)?), - @acc($($acc:tt)*), - ) => { - // Endpoint, nothing more to munch, create the initializer. Since the users specified - // `..Zeroable::init_zeroed()`, the slot will already have been zeroed and all field that have - // not been overwritten are thus zero and initialized. We still check that all fields are - // actually accessible by using the struct update syntax ourselves. - // We are inside of a closure that is never executed and thus we can abuse `slot` to - // get the correct type inference here: - #[allow(unused_assignments)] - unsafe { - let mut zeroed = ::core::mem::zeroed(); - // We have to use type inference here to make zeroed have the correct type. This does - // not get executed, so it has no effect. - ::core::ptr::write($slot, zeroed); - zeroed = ::core::mem::zeroed(); - // Here we abuse `paste!` to retokenize `$t`. Declarative macros have some internal - // information that is associated to already parsed fragments, so a path fragment - // cannot be used in this position. Doing the retokenization results in valid rust - // code. - $crate::macros::paste!( - ::core::ptr::write($slot, $t { - $($acc)* - ..zeroed - }); - ); - } - }; - (make_initializer: - @slot($slot:ident), - @type_name($t:path), - @munch_fields($(,)?), - @acc($($acc:tt)*), - ) => { - // Endpoint, nothing more to munch, create the initializer. - // Since we are in the closure that is never called, this will never get executed. - // We abuse `slot` to get the correct type inference here: - // - // SAFETY: TODO. - unsafe { - // Here we abuse `paste!` to retokenize `$t`. Declarative macros have some internal - // information that is associated to already parsed fragments, so a path fragment - // cannot be used in this position. Doing the retokenization results in valid rust - // code. - $crate::macros::paste!( - ::core::ptr::write($slot, $t { - $($acc)* - }); - ); - } - }; - (make_initializer: - @slot($slot:ident), - @type_name($t:path), - @munch_fields($field:ident <- $val:expr, $($rest:tt)*), - @acc($($acc:tt)*), - ) => { - $crate::__init_internal!(make_initializer: - @slot($slot), - @type_name($t), - @munch_fields($($rest)*), - @acc($($acc)* $field: ::core::panic!(),), - ); - }; - (make_initializer: - @slot($slot:ident), - @type_name($t:path), - @munch_fields($field:ident $(: $val:expr)?, $($rest:tt)*), - @acc($($acc:tt)*), - ) => { - $crate::__init_internal!(make_initializer: - @slot($slot), - @type_name($t), - @munch_fields($($rest)*), - @acc($($acc)* $field: ::core::panic!(),), - ); - }; -} - -#[doc(hidden)] -#[macro_export] -macro_rules! __derive_zeroable { - (parse_input: - @sig( - $(#[$($struct_attr:tt)*])* - $vis:vis struct $name:ident - $(where $($whr:tt)*)? - ), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @body({ - $( - $(#[$($field_attr:tt)*])* - $field_vis:vis $field:ident : $field_ty:ty - ),* $(,)? - }), - ) => { - // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero. - #[automatically_derived] - unsafe impl<$($impl_generics)*> $crate::Zeroable for $name<$($ty_generics)*> - where - $($($whr)*)? - {} - const _: () = { - fn assert_zeroable<T: ?::core::marker::Sized + $crate::Zeroable>() {} - fn ensure_zeroable<$($impl_generics)*>() - where $($($whr)*)? - { - $(assert_zeroable::<$field_ty>();)* - } - }; - }; - (parse_input: - @sig( - $(#[$($struct_attr:tt)*])* - $vis:vis union $name:ident - $(where $($whr:tt)*)? - ), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @body({ - $( - $(#[$($field_attr:tt)*])* - $field_vis:vis $field:ident : $field_ty:ty - ),* $(,)? - }), - ) => { - // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero. - #[automatically_derived] - unsafe impl<$($impl_generics)*> $crate::Zeroable for $name<$($ty_generics)*> - where - $($($whr)*)? - {} - const _: () = { - fn assert_zeroable<T: ?::core::marker::Sized + $crate::Zeroable>() {} - fn ensure_zeroable<$($impl_generics)*>() - where $($($whr)*)? - { - $(assert_zeroable::<$field_ty>();)* - } - }; - }; -} - -#[doc(hidden)] -#[macro_export] -macro_rules! __maybe_derive_zeroable { - (parse_input: - @sig( - $(#[$($struct_attr:tt)*])* - $vis:vis struct $name:ident - $(where $($whr:tt)*)? - ), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @body({ - $( - $(#[$($field_attr:tt)*])* - $field_vis:vis $field:ident : $field_ty:ty - ),* $(,)? - }), - ) => { - // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero. - #[automatically_derived] - unsafe impl<$($impl_generics)*> $crate::Zeroable for $name<$($ty_generics)*> - where - $( - // the `for<'__dummy>` HRTB makes this not error without the `trivial_bounds` - // feature <https://github.com/rust-lang/rust/issues/48214#issuecomment-2557829956>. - $field_ty: for<'__dummy> $crate::Zeroable, - )* - $($($whr)*)? - {} - }; - (parse_input: - @sig( - $(#[$($struct_attr:tt)*])* - $vis:vis union $name:ident - $(where $($whr:tt)*)? - ), - @impl_generics($($impl_generics:tt)*), - @ty_generics($($ty_generics:tt)*), - @body({ - $( - $(#[$($field_attr:tt)*])* - $field_vis:vis $field:ident : $field_ty:ty - ),* $(,)? - }), - ) => { - // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero. - #[automatically_derived] - unsafe impl<$($impl_generics)*> $crate::Zeroable for $name<$($ty_generics)*> - where - $( - // the `for<'__dummy>` HRTB makes this not error without the `trivial_bounds` - // feature <https://github.com/rust-lang/rust/issues/48214#issuecomment-2557829956>. - $field_ty: for<'__dummy> $crate::Zeroable, - )* - $($($whr)*)? - {} - }; -} |
