summaryrefslogtreecommitdiff
path: root/rust/pin-init/internal
diff options
context:
space:
mode:
authorBenno Lossin <lossin@kernel.org>2026-01-16 11:54:24 +0100
committerBenno Lossin <lossin@kernel.org>2026-01-17 10:51:42 +0100
commit4883830e9784bdf6223fe0e5f1ea36d4a4ab4fef (patch)
tree293be1e715075c20816e02a7f6b58f87350b710f /rust/pin-init/internal
parentdae5466c4aa5b43a6cda4282bf9ff8e6b42ece0e (diff)
rust: pin-init: rewrite the initializer macros using `syn`
Rewrite the initializer macros `[pin_]init!` using `syn`. No functional changes intended aside from improved error messages on syntactic and semantical errors. For example if one forgets to use `<-` with an initializer (and instead uses `:`): impl Bar { fn new() -> impl PinInit<Self> { ... } } impl Foo { fn new() -> impl PinInit<Self> { pin_init!(Self { bar: Bar::new() }) } } Then the declarative macro would report: error[E0308]: mismatched types --> tests/ui/compile-fail/init/colon_instead_of_arrow.rs:21:9 | 14 | fn new() -> impl PinInit<Self> { | ------------------ the found opaque type ... 21 | pin_init!(Self { bar: Bar::new() }) | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | | | expected `Bar`, found opaque type | arguments to this function are incorrect | = note: expected struct `Bar` found opaque type `impl pin_init::PinInit<Bar>` note: function defined here --> $RUST/core/src/ptr/mod.rs | | pub const unsafe fn write<T>(dst: *mut T, src: T) { | ^^^^^ = note: this error originates in the macro `$crate::__init_internal` which comes from the expansion of the macro `pin_init` (in Nightly builds, run with -Z macro-backtrace for more info) And the new error is: error[E0308]: mismatched types --> tests/ui/compile-fail/init/colon_instead_of_arrow.rs:21:31 | 14 | fn new() -> impl PinInit<Self> { | ------------------ the found opaque type ... 21 | pin_init!(Self { bar: Bar::new() }) | --- ^^^^^^^^^^ expected `Bar`, found opaque type | | | arguments to this function are incorrect | = note: expected struct `Bar` found opaque type `impl pin_init::PinInit<Bar>` note: function defined here --> $RUST/core/src/ptr/mod.rs | | pub const unsafe fn write<T>(dst: *mut T, src: T) { | ^^^^^ Importantly, this error gives much more accurate span locations, pointing to the offending field, rather than the entire macro invocation. Tested-by: Andreas Hindborg <a.hindborg@kernel.org> Reviewed-by: Gary Guo <gary@garyguo.net> Signed-off-by: Benno Lossin <lossin@kernel.org>
Diffstat (limited to 'rust/pin-init/internal')
-rw-r--r--rust/pin-init/internal/src/init.rs445
-rw-r--r--rust/pin-init/internal/src/lib.rs13
2 files changed, 458 insertions, 0 deletions
diff --git a/rust/pin-init/internal/src/init.rs b/rust/pin-init/internal/src/init.rs
new file mode 100644
index 000000000000..b0eb66224341
--- /dev/null
+++ b/rust/pin-init/internal/src/init.rs
@@ -0,0 +1,445 @@
+// SPDX-License-Identifier: Apache-2.0 OR MIT
+
+use proc_macro2::{Span, TokenStream};
+use quote::{format_ident, quote, quote_spanned};
+use syn::{
+ braced,
+ parse::{End, Parse},
+ parse_quote,
+ punctuated::Punctuated,
+ spanned::Spanned,
+ token, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
+};
+
+use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
+
+pub(crate) struct Initializer {
+ this: Option<This>,
+ path: Path,
+ brace_token: token::Brace,
+ fields: Punctuated<InitializerField, Token![,]>,
+ rest: Option<(Token![..], Expr)>,
+ error: Option<(Token![?], Type)>,
+}
+
+struct This {
+ _and_token: Token![&],
+ ident: Ident,
+ _in_token: Token![in],
+}
+
+enum InitializerField {
+ Value {
+ ident: Ident,
+ value: Option<(Token![:], Expr)>,
+ },
+ Init {
+ ident: Ident,
+ _left_arrow_token: Token![<-],
+ value: Expr,
+ },
+ Code {
+ _underscore_token: Token![_],
+ _colon_token: Token![:],
+ block: Block,
+ },
+}
+
+impl InitializerField {
+ fn ident(&self) -> Option<&Ident> {
+ match self {
+ Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
+ Self::Code { .. } => None,
+ }
+ }
+}
+
+pub(crate) fn expand(
+ Initializer {
+ this,
+ path,
+ brace_token,
+ fields,
+ rest,
+ error,
+ }: Initializer,
+ default_error: Option<&'static str>,
+ pinned: bool,
+ dcx: &mut DiagCtxt,
+) -> Result<TokenStream, ErrorGuaranteed> {
+ let error = error.map_or_else(
+ || {
+ if let Some(default_error) = default_error {
+ syn::parse_str(default_error).unwrap()
+ } else {
+ dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
+ parse_quote!(::core::convert::Infallible)
+ }
+ },
+ |(_, err)| err,
+ );
+ let slot = format_ident!("slot");
+ let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
+ (
+ format_ident!("HasPinData"),
+ format_ident!("PinData"),
+ format_ident!("__pin_data"),
+ format_ident!("pin_init_from_closure"),
+ )
+ } else {
+ (
+ format_ident!("HasInitData"),
+ format_ident!("InitData"),
+ format_ident!("__init_data"),
+ format_ident!("init_from_closure"),
+ )
+ };
+ let init_kind = get_init_kind(rest, dcx);
+ let zeroable_check = match init_kind {
+ InitKind::Normal => quote!(),
+ InitKind::Zeroing => quote! {
+ // The user specified `..Zeroable::zeroed()` at the end of the list of fields.
+ // Therefore we check if the struct implements `Zeroable` and then zero the memory.
+ // This allows us to also remove the check that all fields are present (since we
+ // already set the memory to zero and that is a valid bit pattern).
+ fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
+ where T: ::pin_init::Zeroable
+ {}
+ // Ensure that the struct is indeed `Zeroable`.
+ assert_zeroable(#slot);
+ // SAFETY: The type implements `Zeroable` by the check above.
+ unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
+ },
+ };
+ let this = match this {
+ None => quote!(),
+ Some(This { ident, .. }) => quote! {
+ // Create the `this` so it can be referenced by the user inside of the
+ // expressions creating the individual fields.
+ let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
+ },
+ };
+ // `mixed_site` ensures that the data is not accessible to the user-controlled code.
+ let data = Ident::new("__data", Span::mixed_site());
+ let init_fields = init_fields(&fields, pinned, &data, &slot);
+ let field_check = make_field_check(&fields, init_kind, &path);
+ Ok(quote! {{
+ // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return
+ // type and shadow it later when we insert the arbitrary user code. That way there will be
+ // no possibility of returning without `unsafe`.
+ struct __InitOk;
+
+ // Get the data about fields from the supplied type.
+ // SAFETY: TODO
+ let #data = unsafe {
+ use ::pin_init::__internal::#has_data_trait;
+ // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit
+ // generics (which need to be present with that syntax).
+ #path::#get_data()
+ };
+ // Ensure that `#data` really is of type `#data` and help with type inference:
+ let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>(
+ #data,
+ move |slot| {
+ {
+ // Shadow the structure so it cannot be used to return early.
+ struct __InitOk;
+ #zeroable_check
+ #this
+ #init_fields
+ #field_check
+ }
+ Ok(__InitOk)
+ }
+ );
+ let init = move |slot| -> ::core::result::Result<(), #error> {
+ init(slot).map(|__InitOk| ())
+ };
+ // SAFETY: TODO
+ let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
+ init
+ }})
+}
+
+enum InitKind {
+ Normal,
+ Zeroing,
+}
+
+fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
+ let Some((dotdot, expr)) = rest else {
+ return InitKind::Normal;
+ };
+ match &expr {
+ Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
+ Expr::Path(ExprPath {
+ attrs,
+ qself: None,
+ path:
+ Path {
+ leading_colon: None,
+ segments,
+ },
+ }) if attrs.is_empty()
+ && segments.len() == 2
+ && segments[0].ident == "Zeroable"
+ && segments[0].arguments.is_none()
+ && segments[1].ident == "init_zeroed"
+ && segments[1].arguments.is_none() =>
+ {
+ return InitKind::Zeroing;
+ }
+ _ => {}
+ },
+ _ => {}
+ }
+ dcx.error(
+ dotdot.span().join(expr.span()).unwrap_or(expr.span()),
+ "expected nothing or `..Zeroable::init_zeroed()`.",
+ );
+ InitKind::Normal
+}
+
+/// Generate the code that initializes the fields of the struct using the initializers in `field`.
+fn init_fields(
+ fields: &Punctuated<InitializerField, Token![,]>,
+ pinned: bool,
+ data: &Ident,
+ slot: &Ident,
+) -> TokenStream {
+ let mut guards = vec![];
+ let mut res = TokenStream::new();
+ for field in fields {
+ let init = match field {
+ InitializerField::Value { ident, value } => {
+ let mut value_ident = ident.clone();
+ let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
+ // Setting the span of `value_ident` to `value`'s span improves error messages
+ // when the type of `value` is wrong.
+ value_ident.set_span(value.span());
+ quote!(let #value_ident = #value;)
+ });
+ // Again span for better diagnostics
+ let write = quote_spanned!(ident.span()=> ::core::ptr::write);
+ let accessor = if pinned {
+ let project_ident = format_ident!("__project_{ident}");
+ quote! {
+ // SAFETY: TODO
+ unsafe { #data.#project_ident(&mut (*#slot).#ident) }
+ }
+ } else {
+ quote! {
+ // SAFETY: TODO
+ unsafe { &mut (*#slot).#ident }
+ }
+ };
+ quote! {
+ {
+ #value_prep
+ // SAFETY: TODO
+ unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) };
+ }
+ #[allow(unused_variables)]
+ let #ident = #accessor;
+ }
+ }
+ InitializerField::Init { ident, value, .. } => {
+ // Again span for better diagnostics
+ let init = format_ident!("init", span = value.span());
+ if pinned {
+ let project_ident = format_ident!("__project_{ident}");
+ quote! {
+ {
+ let #init = #value;
+ // SAFETY:
+ // - `slot` is valid, because we are inside of an initializer closure, we
+ // return when an error/panic occurs.
+ // - We also use `#data` to require the correct trait (`Init` or `PinInit`)
+ // for `#ident`.
+ unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? };
+ }
+ // SAFETY: TODO
+ #[allow(unused_variables)]
+ let #ident = unsafe { #data.#project_ident(&mut (*#slot).#ident) };
+ }
+ } else {
+ quote! {
+ {
+ let #init = #value;
+ // SAFETY: `slot` is valid, because we are inside of an initializer
+ // closure, we return when an error/panic occurs.
+ unsafe {
+ ::pin_init::Init::__init(
+ #init,
+ ::core::ptr::addr_of_mut!((*#slot).#ident),
+ )?
+ };
+ }
+ // SAFETY: TODO
+ #[allow(unused_variables)]
+ let #ident = unsafe { &mut (*#slot).#ident };
+ }
+ }
+ }
+ InitializerField::Code { block: value, .. } => quote!(#[allow(unused_braces)] #value),
+ };
+ res.extend(init);
+ if let Some(ident) = field.ident() {
+ // `mixed_site` ensures that the guard is not accessible to the user-controlled code.
+ let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
+ guards.push(guard.clone());
+ res.extend(quote! {
+ // Create the drop guard:
+ //
+ // We rely on macro hygiene to make it impossible for users to access this local
+ // variable.
+ // SAFETY: We forget the guard later when initialization has succeeded.
+ let #guard = unsafe {
+ ::pin_init::__internal::DropGuard::new(
+ ::core::ptr::addr_of_mut!((*slot).#ident)
+ )
+ };
+ });
+ }
+ }
+ quote! {
+ #res
+ // If execution reaches this point, all fields have been initialized. Therefore we can now
+ // dismiss the guards by forgetting them.
+ #(::core::mem::forget(#guards);)*
+ }
+}
+
+/// Generate the check for ensuring that every field has been initialized.
+fn make_field_check(
+ fields: &Punctuated<InitializerField, Token![,]>,
+ init_kind: InitKind,
+ path: &Path,
+) -> TokenStream {
+ let fields = fields.iter().filter_map(|f| f.ident());
+ match init_kind {
+ InitKind::Normal => quote! {
+ // We use unreachable code to ensure that all fields have been mentioned exactly once,
+ // this struct initializer will still be type-checked and complain with a very natural
+ // error message if a field is forgotten/mentioned more than once.
+ #[allow(unreachable_code, clippy::diverging_sub_expression)]
+ // SAFETY: this code is never executed.
+ let _ = || unsafe {
+ ::core::ptr::write(slot, #path {
+ #(
+ #fields: ::core::panic!(),
+ )*
+ })
+ };
+ },
+ InitKind::Zeroing => quote! {
+ // We use unreachable code to ensure that all fields have been mentioned at most once.
+ // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will
+ // be zeroed. This struct initializer will still be type-checked and complain with a
+ // very natural error message if a field is mentioned more than once, or doesn't exist.
+ #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)]
+ // SAFETY: this code is never executed.
+ let _ = || unsafe {
+ let mut zeroed = ::core::mem::zeroed();
+ // We have to use type inference here to make zeroed have the correct type. This
+ // does not get executed, so it has no effect.
+ ::core::ptr::write(slot, zeroed);
+ zeroed = ::core::mem::zeroed();
+ ::core::ptr::write(slot, #path {
+ #(
+ #fields: ::core::panic!(),
+ )*
+ ..zeroed
+ })
+ };
+ },
+ }
+}
+
+impl Parse for Initializer {
+ fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
+ let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
+ let path = input.parse()?;
+ let content;
+ let brace_token = braced!(content in input);
+ let mut fields = Punctuated::new();
+ loop {
+ let lh = content.lookahead1();
+ if lh.peek(End) || lh.peek(Token![..]) {
+ break;
+ } else if lh.peek(Ident) || lh.peek(Token![_]) {
+ fields.push_value(content.parse()?);
+ let lh = content.lookahead1();
+ if lh.peek(End) {
+ break;
+ } else if lh.peek(Token![,]) {
+ fields.push_punct(content.parse()?);
+ } else {
+ return Err(lh.error());
+ }
+ } else {
+ return Err(lh.error());
+ }
+ }
+ let rest = content
+ .peek(Token![..])
+ .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
+ .transpose()?;
+ let error = input
+ .peek(Token![?])
+ .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
+ .transpose()?;
+ Ok(Self {
+ this,
+ path,
+ brace_token,
+ fields,
+ rest,
+ error,
+ })
+ }
+}
+
+impl Parse for This {
+ fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
+ Ok(Self {
+ _and_token: input.parse()?,
+ ident: input.parse()?,
+ _in_token: input.parse()?,
+ })
+ }
+}
+
+impl Parse for InitializerField {
+ fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
+ let lh = input.lookahead1();
+ if lh.peek(Token![_]) {
+ Ok(Self::Code {
+ _underscore_token: input.parse()?,
+ _colon_token: input.parse()?,
+ block: input.parse()?,
+ })
+ } else if lh.peek(Ident) {
+ let ident = input.parse()?;
+ let lh = input.lookahead1();
+ if lh.peek(Token![<-]) {
+ Ok(Self::Init {
+ ident,
+ _left_arrow_token: input.parse()?,
+ value: input.parse()?,
+ })
+ } else if lh.peek(Token![:]) {
+ Ok(Self::Value {
+ ident,
+ value: Some((input.parse()?, input.parse()?)),
+ })
+ } else if lh.peek(Token![,]) || lh.peek(End) {
+ Ok(Self::Value { ident, value: None })
+ } else {
+ Err(lh.error())
+ }
+ } else {
+ Err(lh.error())
+ }
+ }
+}
diff --git a/rust/pin-init/internal/src/lib.rs b/rust/pin-init/internal/src/lib.rs
index 56dc306e04a9..08372c8f65f0 100644
--- a/rust/pin-init/internal/src/lib.rs
+++ b/rust/pin-init/internal/src/lib.rs
@@ -16,6 +16,7 @@ use syn::parse_macro_input;
use crate::diagnostics::DiagCtxt;
mod diagnostics;
+mod init;
mod pin_data;
mod pinned_drop;
mod zeroable;
@@ -45,3 +46,15 @@ pub fn maybe_derive_zeroable(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input);
DiagCtxt::with(|dcx| zeroable::maybe_derive(input, dcx)).into()
}
+#[proc_macro]
+pub fn init(input: TokenStream) -> TokenStream {
+ let input = parse_macro_input!(input);
+ DiagCtxt::with(|dcx| init::expand(input, Some("::core::convert::Infallible"), false, dcx))
+ .into()
+}
+
+#[proc_macro]
+pub fn pin_init(input: TokenStream) -> TokenStream {
+ let input = parse_macro_input!(input);
+ DiagCtxt::with(|dcx| init::expand(input, Some("::core::convert::Infallible"), true, dcx)).into()
+}