diff options
| author | Benno Lossin <lossin@kernel.org> | 2026-01-16 11:54:24 +0100 |
|---|---|---|
| committer | Benno Lossin <lossin@kernel.org> | 2026-01-17 10:51:42 +0100 |
| commit | 4883830e9784bdf6223fe0e5f1ea36d4a4ab4fef (patch) | |
| tree | 293be1e715075c20816e02a7f6b58f87350b710f /rust/pin-init/internal | |
| parent | dae5466c4aa5b43a6cda4282bf9ff8e6b42ece0e (diff) | |
rust: pin-init: rewrite the initializer macros using `syn`
Rewrite the initializer macros `[pin_]init!` using `syn`. No functional
changes intended aside from improved error messages on syntactic and
semantical errors. For example if one forgets to use `<-` with an
initializer (and instead uses `:`):
impl Bar {
fn new() -> impl PinInit<Self> { ... }
}
impl Foo {
fn new() -> impl PinInit<Self> {
pin_init!(Self { bar: Bar::new() })
}
}
Then the declarative macro would report:
error[E0308]: mismatched types
--> tests/ui/compile-fail/init/colon_instead_of_arrow.rs:21:9
|
14 | fn new() -> impl PinInit<Self> {
| ------------------ the found opaque type
...
21 | pin_init!(Self { bar: Bar::new() })
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
| |
| expected `Bar`, found opaque type
| arguments to this function are incorrect
|
= note: expected struct `Bar`
found opaque type `impl pin_init::PinInit<Bar>`
note: function defined here
--> $RUST/core/src/ptr/mod.rs
|
| pub const unsafe fn write<T>(dst: *mut T, src: T) {
| ^^^^^
= note: this error originates in the macro `$crate::__init_internal` which comes from the expansion of the macro `pin_init` (in Nightly builds, run with -Z macro-backtrace for more info)
And the new error is:
error[E0308]: mismatched types
--> tests/ui/compile-fail/init/colon_instead_of_arrow.rs:21:31
|
14 | fn new() -> impl PinInit<Self> {
| ------------------ the found opaque type
...
21 | pin_init!(Self { bar: Bar::new() })
| --- ^^^^^^^^^^ expected `Bar`, found opaque type
| |
| arguments to this function are incorrect
|
= note: expected struct `Bar`
found opaque type `impl pin_init::PinInit<Bar>`
note: function defined here
--> $RUST/core/src/ptr/mod.rs
|
| pub const unsafe fn write<T>(dst: *mut T, src: T) {
| ^^^^^
Importantly, this error gives much more accurate span locations,
pointing to the offending field, rather than the entire macro
invocation.
Tested-by: Andreas Hindborg <a.hindborg@kernel.org>
Reviewed-by: Gary Guo <gary@garyguo.net>
Signed-off-by: Benno Lossin <lossin@kernel.org>
Diffstat (limited to 'rust/pin-init/internal')
| -rw-r--r-- | rust/pin-init/internal/src/init.rs | 445 | ||||
| -rw-r--r-- | rust/pin-init/internal/src/lib.rs | 13 |
2 files changed, 458 insertions, 0 deletions
diff --git a/rust/pin-init/internal/src/init.rs b/rust/pin-init/internal/src/init.rs new file mode 100644 index 000000000000..b0eb66224341 --- /dev/null +++ b/rust/pin-init/internal/src/init.rs @@ -0,0 +1,445 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote, quote_spanned}; +use syn::{ + braced, + parse::{End, Parse}, + parse_quote, + punctuated::Punctuated, + spanned::Spanned, + token, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type, +}; + +use crate::diagnostics::{DiagCtxt, ErrorGuaranteed}; + +pub(crate) struct Initializer { + this: Option<This>, + path: Path, + brace_token: token::Brace, + fields: Punctuated<InitializerField, Token![,]>, + rest: Option<(Token![..], Expr)>, + error: Option<(Token![?], Type)>, +} + +struct This { + _and_token: Token![&], + ident: Ident, + _in_token: Token![in], +} + +enum InitializerField { + Value { + ident: Ident, + value: Option<(Token![:], Expr)>, + }, + Init { + ident: Ident, + _left_arrow_token: Token![<-], + value: Expr, + }, + Code { + _underscore_token: Token![_], + _colon_token: Token![:], + block: Block, + }, +} + +impl InitializerField { + fn ident(&self) -> Option<&Ident> { + match self { + Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident), + Self::Code { .. } => None, + } + } +} + +pub(crate) fn expand( + Initializer { + this, + path, + brace_token, + fields, + rest, + error, + }: Initializer, + default_error: Option<&'static str>, + pinned: bool, + dcx: &mut DiagCtxt, +) -> Result<TokenStream, ErrorGuaranteed> { + let error = error.map_or_else( + || { + if let Some(default_error) = default_error { + syn::parse_str(default_error).unwrap() + } else { + dcx.error(brace_token.span.close(), "expected `? <type>` after `}`"); + parse_quote!(::core::convert::Infallible) + } + }, + |(_, err)| err, + ); + let slot = format_ident!("slot"); + let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned { + ( + format_ident!("HasPinData"), + format_ident!("PinData"), + format_ident!("__pin_data"), + format_ident!("pin_init_from_closure"), + ) + } else { + ( + format_ident!("HasInitData"), + format_ident!("InitData"), + format_ident!("__init_data"), + format_ident!("init_from_closure"), + ) + }; + let init_kind = get_init_kind(rest, dcx); + let zeroable_check = match init_kind { + InitKind::Normal => quote!(), + InitKind::Zeroing => quote! { + // The user specified `..Zeroable::zeroed()` at the end of the list of fields. + // Therefore we check if the struct implements `Zeroable` and then zero the memory. + // This allows us to also remove the check that all fields are present (since we + // already set the memory to zero and that is a valid bit pattern). + fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T) + where T: ::pin_init::Zeroable + {} + // Ensure that the struct is indeed `Zeroable`. + assert_zeroable(#slot); + // SAFETY: The type implements `Zeroable` by the check above. + unsafe { ::core::ptr::write_bytes(#slot, 0, 1) }; + }, + }; + let this = match this { + None => quote!(), + Some(This { ident, .. }) => quote! { + // Create the `this` so it can be referenced by the user inside of the + // expressions creating the individual fields. + let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) }; + }, + }; + // `mixed_site` ensures that the data is not accessible to the user-controlled code. + let data = Ident::new("__data", Span::mixed_site()); + let init_fields = init_fields(&fields, pinned, &data, &slot); + let field_check = make_field_check(&fields, init_kind, &path); + Ok(quote! {{ + // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return + // type and shadow it later when we insert the arbitrary user code. That way there will be + // no possibility of returning without `unsafe`. + struct __InitOk; + + // Get the data about fields from the supplied type. + // SAFETY: TODO + let #data = unsafe { + use ::pin_init::__internal::#has_data_trait; + // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit + // generics (which need to be present with that syntax). + #path::#get_data() + }; + // Ensure that `#data` really is of type `#data` and help with type inference: + let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>( + #data, + move |slot| { + { + // Shadow the structure so it cannot be used to return early. + struct __InitOk; + #zeroable_check + #this + #init_fields + #field_check + } + Ok(__InitOk) + } + ); + let init = move |slot| -> ::core::result::Result<(), #error> { + init(slot).map(|__InitOk| ()) + }; + // SAFETY: TODO + let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) }; + init + }}) +} + +enum InitKind { + Normal, + Zeroing, +} + +fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind { + let Some((dotdot, expr)) = rest else { + return InitKind::Normal; + }; + match &expr { + Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func { + Expr::Path(ExprPath { + attrs, + qself: None, + path: + Path { + leading_colon: None, + segments, + }, + }) if attrs.is_empty() + && segments.len() == 2 + && segments[0].ident == "Zeroable" + && segments[0].arguments.is_none() + && segments[1].ident == "init_zeroed" + && segments[1].arguments.is_none() => + { + return InitKind::Zeroing; + } + _ => {} + }, + _ => {} + } + dcx.error( + dotdot.span().join(expr.span()).unwrap_or(expr.span()), + "expected nothing or `..Zeroable::init_zeroed()`.", + ); + InitKind::Normal +} + +/// Generate the code that initializes the fields of the struct using the initializers in `field`. +fn init_fields( + fields: &Punctuated<InitializerField, Token![,]>, + pinned: bool, + data: &Ident, + slot: &Ident, +) -> TokenStream { + let mut guards = vec![]; + let mut res = TokenStream::new(); + for field in fields { + let init = match field { + InitializerField::Value { ident, value } => { + let mut value_ident = ident.clone(); + let value_prep = value.as_ref().map(|value| &value.1).map(|value| { + // Setting the span of `value_ident` to `value`'s span improves error messages + // when the type of `value` is wrong. + value_ident.set_span(value.span()); + quote!(let #value_ident = #value;) + }); + // Again span for better diagnostics + let write = quote_spanned!(ident.span()=> ::core::ptr::write); + let accessor = if pinned { + let project_ident = format_ident!("__project_{ident}"); + quote! { + // SAFETY: TODO + unsafe { #data.#project_ident(&mut (*#slot).#ident) } + } + } else { + quote! { + // SAFETY: TODO + unsafe { &mut (*#slot).#ident } + } + }; + quote! { + { + #value_prep + // SAFETY: TODO + unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) }; + } + #[allow(unused_variables)] + let #ident = #accessor; + } + } + InitializerField::Init { ident, value, .. } => { + // Again span for better diagnostics + let init = format_ident!("init", span = value.span()); + if pinned { + let project_ident = format_ident!("__project_{ident}"); + quote! { + { + let #init = #value; + // SAFETY: + // - `slot` is valid, because we are inside of an initializer closure, we + // return when an error/panic occurs. + // - We also use `#data` to require the correct trait (`Init` or `PinInit`) + // for `#ident`. + unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? }; + } + // SAFETY: TODO + #[allow(unused_variables)] + let #ident = unsafe { #data.#project_ident(&mut (*#slot).#ident) }; + } + } else { + quote! { + { + let #init = #value; + // SAFETY: `slot` is valid, because we are inside of an initializer + // closure, we return when an error/panic occurs. + unsafe { + ::pin_init::Init::__init( + #init, + ::core::ptr::addr_of_mut!((*#slot).#ident), + )? + }; + } + // SAFETY: TODO + #[allow(unused_variables)] + let #ident = unsafe { &mut (*#slot).#ident }; + } + } + } + InitializerField::Code { block: value, .. } => quote!(#[allow(unused_braces)] #value), + }; + res.extend(init); + if let Some(ident) = field.ident() { + // `mixed_site` ensures that the guard is not accessible to the user-controlled code. + let guard = format_ident!("__{ident}_guard", span = Span::mixed_site()); + guards.push(guard.clone()); + res.extend(quote! { + // Create the drop guard: + // + // We rely on macro hygiene to make it impossible for users to access this local + // variable. + // SAFETY: We forget the guard later when initialization has succeeded. + let #guard = unsafe { + ::pin_init::__internal::DropGuard::new( + ::core::ptr::addr_of_mut!((*slot).#ident) + ) + }; + }); + } + } + quote! { + #res + // If execution reaches this point, all fields have been initialized. Therefore we can now + // dismiss the guards by forgetting them. + #(::core::mem::forget(#guards);)* + } +} + +/// Generate the check for ensuring that every field has been initialized. +fn make_field_check( + fields: &Punctuated<InitializerField, Token![,]>, + init_kind: InitKind, + path: &Path, +) -> TokenStream { + let fields = fields.iter().filter_map(|f| f.ident()); + match init_kind { + InitKind::Normal => quote! { + // We use unreachable code to ensure that all fields have been mentioned exactly once, + // this struct initializer will still be type-checked and complain with a very natural + // error message if a field is forgotten/mentioned more than once. + #[allow(unreachable_code, clippy::diverging_sub_expression)] + // SAFETY: this code is never executed. + let _ = || unsafe { + ::core::ptr::write(slot, #path { + #( + #fields: ::core::panic!(), + )* + }) + }; + }, + InitKind::Zeroing => quote! { + // We use unreachable code to ensure that all fields have been mentioned at most once. + // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will + // be zeroed. This struct initializer will still be type-checked and complain with a + // very natural error message if a field is mentioned more than once, or doesn't exist. + #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)] + // SAFETY: this code is never executed. + let _ = || unsafe { + let mut zeroed = ::core::mem::zeroed(); + // We have to use type inference here to make zeroed have the correct type. This + // does not get executed, so it has no effect. + ::core::ptr::write(slot, zeroed); + zeroed = ::core::mem::zeroed(); + ::core::ptr::write(slot, #path { + #( + #fields: ::core::panic!(), + )* + ..zeroed + }) + }; + }, + } +} + +impl Parse for Initializer { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { + let this = input.peek(Token![&]).then(|| input.parse()).transpose()?; + let path = input.parse()?; + let content; + let brace_token = braced!(content in input); + let mut fields = Punctuated::new(); + loop { + let lh = content.lookahead1(); + if lh.peek(End) || lh.peek(Token![..]) { + break; + } else if lh.peek(Ident) || lh.peek(Token![_]) { + fields.push_value(content.parse()?); + let lh = content.lookahead1(); + if lh.peek(End) { + break; + } else if lh.peek(Token![,]) { + fields.push_punct(content.parse()?); + } else { + return Err(lh.error()); + } + } else { + return Err(lh.error()); + } + } + let rest = content + .peek(Token![..]) + .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?))) + .transpose()?; + let error = input + .peek(Token![?]) + .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?))) + .transpose()?; + Ok(Self { + this, + path, + brace_token, + fields, + rest, + error, + }) + } +} + +impl Parse for This { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { + Ok(Self { + _and_token: input.parse()?, + ident: input.parse()?, + _in_token: input.parse()?, + }) + } +} + +impl Parse for InitializerField { + fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> { + let lh = input.lookahead1(); + if lh.peek(Token![_]) { + Ok(Self::Code { + _underscore_token: input.parse()?, + _colon_token: input.parse()?, + block: input.parse()?, + }) + } else if lh.peek(Ident) { + let ident = input.parse()?; + let lh = input.lookahead1(); + if lh.peek(Token![<-]) { + Ok(Self::Init { + ident, + _left_arrow_token: input.parse()?, + value: input.parse()?, + }) + } else if lh.peek(Token![:]) { + Ok(Self::Value { + ident, + value: Some((input.parse()?, input.parse()?)), + }) + } else if lh.peek(Token![,]) || lh.peek(End) { + Ok(Self::Value { ident, value: None }) + } else { + Err(lh.error()) + } + } else { + Err(lh.error()) + } + } +} diff --git a/rust/pin-init/internal/src/lib.rs b/rust/pin-init/internal/src/lib.rs index 56dc306e04a9..08372c8f65f0 100644 --- a/rust/pin-init/internal/src/lib.rs +++ b/rust/pin-init/internal/src/lib.rs @@ -16,6 +16,7 @@ use syn::parse_macro_input; use crate::diagnostics::DiagCtxt; mod diagnostics; +mod init; mod pin_data; mod pinned_drop; mod zeroable; @@ -45,3 +46,15 @@ pub fn maybe_derive_zeroable(input: TokenStream) -> TokenStream { let input = parse_macro_input!(input); DiagCtxt::with(|dcx| zeroable::maybe_derive(input, dcx)).into() } +#[proc_macro] +pub fn init(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input); + DiagCtxt::with(|dcx| init::expand(input, Some("::core::convert::Infallible"), false, dcx)) + .into() +} + +#[proc_macro] +pub fn pin_init(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input); + DiagCtxt::with(|dcx| init::expand(input, Some("::core::convert::Infallible"), true, dcx)).into() +} |
