summaryrefslogtreecommitdiff
path: root/rust/pin-init/internal
diff options
context:
space:
mode:
Diffstat (limited to 'rust/pin-init/internal')
-rw-r--r--rust/pin-init/internal/src/diagnostics.rs30
-rw-r--r--rust/pin-init/internal/src/helpers.rs152
-rw-r--r--rust/pin-init/internal/src/init.rs548
-rw-r--r--rust/pin-init/internal/src/lib.rs48
-rw-r--r--rust/pin-init/internal/src/pin_data.rs615
-rw-r--r--rust/pin-init/internal/src/pinned_drop.rs88
-rw-r--r--rust/pin-init/internal/src/zeroable.rs157
7 files changed, 1219 insertions, 419 deletions
diff --git a/rust/pin-init/internal/src/diagnostics.rs b/rust/pin-init/internal/src/diagnostics.rs
new file mode 100644
index 000000000000..3bdb477c2f2b
--- /dev/null
+++ b/rust/pin-init/internal/src/diagnostics.rs
@@ -0,0 +1,30 @@
+// SPDX-License-Identifier: Apache-2.0 OR MIT
+
+use std::fmt::Display;
+
+use proc_macro2::TokenStream;
+use syn::{spanned::Spanned, Error};
+
+pub(crate) struct DiagCtxt(TokenStream);
+pub(crate) struct ErrorGuaranteed(());
+
+impl DiagCtxt {
+ pub(crate) fn error(&mut self, span: impl Spanned, msg: impl Display) -> ErrorGuaranteed {
+ let error = Error::new(span.span(), msg);
+ self.0.extend(error.into_compile_error());
+ ErrorGuaranteed(())
+ }
+
+ pub(crate) fn with(
+ fun: impl FnOnce(&mut DiagCtxt) -> Result<TokenStream, ErrorGuaranteed>,
+ ) -> TokenStream {
+ let mut dcx = Self(TokenStream::new());
+ match fun(&mut dcx) {
+ Ok(mut stream) => {
+ stream.extend(dcx.0);
+ stream
+ }
+ Err(ErrorGuaranteed(())) => dcx.0,
+ }
+ }
+}
diff --git a/rust/pin-init/internal/src/helpers.rs b/rust/pin-init/internal/src/helpers.rs
deleted file mode 100644
index 236f989a50f2..000000000000
--- a/rust/pin-init/internal/src/helpers.rs
+++ /dev/null
@@ -1,152 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0 OR MIT
-
-#[cfg(not(kernel))]
-use proc_macro2 as proc_macro;
-
-use proc_macro::{TokenStream, TokenTree};
-
-/// Parsed generics.
-///
-/// See the field documentation for an explanation what each of the fields represents.
-///
-/// # Examples
-///
-/// ```rust,ignore
-/// # let input = todo!();
-/// let (Generics { decl_generics, impl_generics, ty_generics }, rest) = parse_generics(input);
-/// quote! {
-/// struct Foo<$($decl_generics)*> {
-/// // ...
-/// }
-///
-/// impl<$impl_generics> Foo<$ty_generics> {
-/// fn foo() {
-/// // ...
-/// }
-/// }
-/// }
-/// ```
-pub(crate) struct Generics {
- /// The generics with bounds and default values (e.g. `T: Clone, const N: usize = 0`).
- ///
- /// Use this on type definitions e.g. `struct Foo<$decl_generics> ...` (or `union`/`enum`).
- pub(crate) decl_generics: Vec<TokenTree>,
- /// The generics with bounds (e.g. `T: Clone, const N: usize`).
- ///
- /// Use this on `impl` blocks e.g. `impl<$impl_generics> Trait for ...`.
- pub(crate) impl_generics: Vec<TokenTree>,
- /// The generics without bounds and without default values (e.g. `T, N`).
- ///
- /// Use this when you use the type that is declared with these generics e.g.
- /// `Foo<$ty_generics>`.
- pub(crate) ty_generics: Vec<TokenTree>,
-}
-
-/// Parses the given `TokenStream` into `Generics` and the rest.
-///
-/// The generics are not present in the rest, but a where clause might remain.
-pub(crate) fn parse_generics(input: TokenStream) -> (Generics, Vec<TokenTree>) {
- // The generics with bounds and default values.
- let mut decl_generics = vec![];
- // `impl_generics`, the declared generics with their bounds.
- let mut impl_generics = vec![];
- // Only the names of the generics, without any bounds.
- let mut ty_generics = vec![];
- // Tokens not related to the generics e.g. the `where` token and definition.
- let mut rest = vec![];
- // The current level of `<`.
- let mut nesting = 0;
- let mut toks = input.into_iter();
- // If we are at the beginning of a generic parameter.
- let mut at_start = true;
- let mut skip_until_comma = false;
- while let Some(tt) = toks.next() {
- if nesting == 1 && matches!(&tt, TokenTree::Punct(p) if p.as_char() == '>') {
- // Found the end of the generics.
- break;
- } else if nesting >= 1 {
- decl_generics.push(tt.clone());
- }
- match tt.clone() {
- TokenTree::Punct(p) if p.as_char() == '<' => {
- if nesting >= 1 && !skip_until_comma {
- // This is inside of the generics and part of some bound.
- impl_generics.push(tt);
- }
- nesting += 1;
- }
- TokenTree::Punct(p) if p.as_char() == '>' => {
- // This is a parsing error, so we just end it here.
- if nesting == 0 {
- break;
- } else {
- nesting -= 1;
- if nesting >= 1 && !skip_until_comma {
- // We are still inside of the generics and part of some bound.
- impl_generics.push(tt);
- }
- }
- }
- TokenTree::Punct(p) if skip_until_comma && p.as_char() == ',' => {
- if nesting == 1 {
- impl_generics.push(tt.clone());
- impl_generics.push(tt);
- skip_until_comma = false;
- }
- }
- _ if !skip_until_comma => {
- match nesting {
- // If we haven't entered the generics yet, we still want to keep these tokens.
- 0 => rest.push(tt),
- 1 => {
- // Here depending on the token, it might be a generic variable name.
- match tt.clone() {
- TokenTree::Ident(i) if at_start && i.to_string() == "const" => {
- let Some(name) = toks.next() else {
- // Parsing error.
- break;
- };
- impl_generics.push(tt);
- impl_generics.push(name.clone());
- ty_generics.push(name.clone());
- decl_generics.push(name);
- at_start = false;
- }
- TokenTree::Ident(_) if at_start => {
- impl_generics.push(tt.clone());
- ty_generics.push(tt);
- at_start = false;
- }
- TokenTree::Punct(p) if p.as_char() == ',' => {
- impl_generics.push(tt.clone());
- ty_generics.push(tt);
- at_start = true;
- }
- // Lifetimes begin with `'`.
- TokenTree::Punct(p) if p.as_char() == '\'' && at_start => {
- impl_generics.push(tt.clone());
- ty_generics.push(tt);
- }
- // Generics can have default values, we skip these.
- TokenTree::Punct(p) if p.as_char() == '=' => {
- skip_until_comma = true;
- }
- _ => impl_generics.push(tt),
- }
- }
- _ => impl_generics.push(tt),
- }
- }
- _ => {}
- }
- }
- rest.extend(toks);
- (
- Generics {
- impl_generics,
- decl_generics,
- ty_generics,
- },
- rest,
- )
-}
diff --git a/rust/pin-init/internal/src/init.rs b/rust/pin-init/internal/src/init.rs
new file mode 100644
index 000000000000..42936f915a07
--- /dev/null
+++ b/rust/pin-init/internal/src/init.rs
@@ -0,0 +1,548 @@
+// SPDX-License-Identifier: Apache-2.0 OR MIT
+
+use proc_macro2::{Span, TokenStream};
+use quote::{format_ident, quote, quote_spanned};
+use syn::{
+ braced,
+ parse::{End, Parse},
+ parse_quote,
+ punctuated::Punctuated,
+ spanned::Spanned,
+ token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
+};
+
+use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
+
+pub(crate) struct Initializer {
+ attrs: Vec<InitializerAttribute>,
+ this: Option<This>,
+ path: Path,
+ brace_token: token::Brace,
+ fields: Punctuated<InitializerField, Token![,]>,
+ rest: Option<(Token![..], Expr)>,
+ error: Option<(Token![?], Type)>,
+}
+
+struct This {
+ _and_token: Token![&],
+ ident: Ident,
+ _in_token: Token![in],
+}
+
+struct InitializerField {
+ attrs: Vec<Attribute>,
+ kind: InitializerKind,
+}
+
+enum InitializerKind {
+ Value {
+ ident: Ident,
+ value: Option<(Token![:], Expr)>,
+ },
+ Init {
+ ident: Ident,
+ _left_arrow_token: Token![<-],
+ value: Expr,
+ },
+ Code {
+ _underscore_token: Token![_],
+ _colon_token: Token![:],
+ block: Block,
+ },
+}
+
+impl InitializerKind {
+ fn ident(&self) -> Option<&Ident> {
+ match self {
+ Self::Value { ident, .. } | Self::Init { ident, .. } => Some(ident),
+ Self::Code { .. } => None,
+ }
+ }
+}
+
+enum InitializerAttribute {
+ DefaultError(DefaultErrorAttribute),
+ DisableInitializedFieldAccess,
+}
+
+struct DefaultErrorAttribute {
+ ty: Box<Type>,
+}
+
+pub(crate) fn expand(
+ Initializer {
+ attrs,
+ this,
+ path,
+ brace_token,
+ fields,
+ rest,
+ error,
+ }: Initializer,
+ default_error: Option<&'static str>,
+ pinned: bool,
+ dcx: &mut DiagCtxt,
+) -> Result<TokenStream, ErrorGuaranteed> {
+ let error = error.map_or_else(
+ || {
+ if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
+ if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
+ Some(ty.clone())
+ } else {
+ acc
+ }
+ }) {
+ default_error
+ } else if let Some(default_error) = default_error {
+ syn::parse_str(default_error).unwrap()
+ } else {
+ dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
+ parse_quote!(::core::convert::Infallible)
+ }
+ },
+ |(_, err)| Box::new(err),
+ );
+ let slot = format_ident!("slot");
+ let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
+ (
+ format_ident!("HasPinData"),
+ format_ident!("PinData"),
+ format_ident!("__pin_data"),
+ format_ident!("pin_init_from_closure"),
+ )
+ } else {
+ (
+ format_ident!("HasInitData"),
+ format_ident!("InitData"),
+ format_ident!("__init_data"),
+ format_ident!("init_from_closure"),
+ )
+ };
+ let init_kind = get_init_kind(rest, dcx);
+ let zeroable_check = match init_kind {
+ InitKind::Normal => quote!(),
+ InitKind::Zeroing => quote! {
+ // The user specified `..Zeroable::zeroed()` at the end of the list of fields.
+ // Therefore we check if the struct implements `Zeroable` and then zero the memory.
+ // This allows us to also remove the check that all fields are present (since we
+ // already set the memory to zero and that is a valid bit pattern).
+ fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
+ where T: ::pin_init::Zeroable
+ {}
+ // Ensure that the struct is indeed `Zeroable`.
+ assert_zeroable(#slot);
+ // SAFETY: The type implements `Zeroable` by the check above.
+ unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
+ },
+ };
+ let this = match this {
+ None => quote!(),
+ Some(This { ident, .. }) => quote! {
+ // Create the `this` so it can be referenced by the user inside of the
+ // expressions creating the individual fields.
+ let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
+ },
+ };
+ // `mixed_site` ensures that the data is not accessible to the user-controlled code.
+ let data = Ident::new("__data", Span::mixed_site());
+ let init_fields = init_fields(
+ &fields,
+ pinned,
+ !attrs
+ .iter()
+ .any(|attr| matches!(attr, InitializerAttribute::DisableInitializedFieldAccess)),
+ &data,
+ &slot,
+ );
+ let field_check = make_field_check(&fields, init_kind, &path);
+ Ok(quote! {{
+ // We do not want to allow arbitrary returns, so we declare this type as the `Ok` return
+ // type and shadow it later when we insert the arbitrary user code. That way there will be
+ // no possibility of returning without `unsafe`.
+ struct __InitOk;
+
+ // Get the data about fields from the supplied type.
+ // SAFETY: TODO
+ let #data = unsafe {
+ use ::pin_init::__internal::#has_data_trait;
+ // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit
+ // generics (which need to be present with that syntax).
+ #path::#get_data()
+ };
+ // Ensure that `#data` really is of type `#data` and help with type inference:
+ let init = ::pin_init::__internal::#data_trait::make_closure::<_, __InitOk, #error>(
+ #data,
+ move |slot| {
+ {
+ // Shadow the structure so it cannot be used to return early.
+ struct __InitOk;
+ #zeroable_check
+ #this
+ #init_fields
+ #field_check
+ }
+ Ok(__InitOk)
+ }
+ );
+ let init = move |slot| -> ::core::result::Result<(), #error> {
+ init(slot).map(|__InitOk| ())
+ };
+ // SAFETY: TODO
+ let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
+ init
+ }})
+}
+
+enum InitKind {
+ Normal,
+ Zeroing,
+}
+
+fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
+ let Some((dotdot, expr)) = rest else {
+ return InitKind::Normal;
+ };
+ match &expr {
+ Expr::Call(ExprCall { func, args, .. }) if args.is_empty() => match &**func {
+ Expr::Path(ExprPath {
+ attrs,
+ qself: None,
+ path:
+ Path {
+ leading_colon: None,
+ segments,
+ },
+ }) if attrs.is_empty()
+ && segments.len() == 2
+ && segments[0].ident == "Zeroable"
+ && segments[0].arguments.is_none()
+ && segments[1].ident == "init_zeroed"
+ && segments[1].arguments.is_none() =>
+ {
+ return InitKind::Zeroing;
+ }
+ _ => {}
+ },
+ _ => {}
+ }
+ dcx.error(
+ dotdot.span().join(expr.span()).unwrap_or(expr.span()),
+ "expected nothing or `..Zeroable::init_zeroed()`.",
+ );
+ InitKind::Normal
+}
+
+/// Generate the code that initializes the fields of the struct using the initializers in `field`.
+fn init_fields(
+ fields: &Punctuated<InitializerField, Token![,]>,
+ pinned: bool,
+ generate_initialized_accessors: bool,
+ data: &Ident,
+ slot: &Ident,
+) -> TokenStream {
+ let mut guards = vec![];
+ let mut guard_attrs = vec![];
+ let mut res = TokenStream::new();
+ for InitializerField { attrs, kind } in fields {
+ let cfgs = {
+ let mut cfgs = attrs.clone();
+ cfgs.retain(|attr| attr.path().is_ident("cfg"));
+ cfgs
+ };
+ let init = match kind {
+ InitializerKind::Value { ident, value } => {
+ let mut value_ident = ident.clone();
+ let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
+ // Setting the span of `value_ident` to `value`'s span improves error messages
+ // when the type of `value` is wrong.
+ value_ident.set_span(value.span());
+ quote!(let #value_ident = #value;)
+ });
+ // Again span for better diagnostics
+ let write = quote_spanned!(ident.span()=> ::core::ptr::write);
+ let accessor = if pinned {
+ let project_ident = format_ident!("__project_{ident}");
+ quote! {
+ // SAFETY: TODO
+ unsafe { #data.#project_ident(&mut (*#slot).#ident) }
+ }
+ } else {
+ quote! {
+ // SAFETY: TODO
+ unsafe { &mut (*#slot).#ident }
+ }
+ };
+ let accessor = generate_initialized_accessors.then(|| {
+ quote! {
+ #(#cfgs)*
+ #[allow(unused_variables)]
+ let #ident = #accessor;
+ }
+ });
+ quote! {
+ #(#attrs)*
+ {
+ #value_prep
+ // SAFETY: TODO
+ unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) };
+ }
+ #accessor
+ }
+ }
+ InitializerKind::Init { ident, value, .. } => {
+ // Again span for better diagnostics
+ let init = format_ident!("init", span = value.span());
+ let (value_init, accessor) = if pinned {
+ let project_ident = format_ident!("__project_{ident}");
+ (
+ quote! {
+ // SAFETY:
+ // - `slot` is valid, because we are inside of an initializer closure, we
+ // return when an error/panic occurs.
+ // - We also use `#data` to require the correct trait (`Init` or `PinInit`)
+ // for `#ident`.
+ unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? };
+ },
+ quote! {
+ // SAFETY: TODO
+ unsafe { #data.#project_ident(&mut (*#slot).#ident) }
+ },
+ )
+ } else {
+ (
+ quote! {
+ // SAFETY: `slot` is valid, because we are inside of an initializer
+ // closure, we return when an error/panic occurs.
+ unsafe {
+ ::pin_init::Init::__init(
+ #init,
+ ::core::ptr::addr_of_mut!((*#slot).#ident),
+ )?
+ };
+ },
+ quote! {
+ // SAFETY: TODO
+ unsafe { &mut (*#slot).#ident }
+ },
+ )
+ };
+ let accessor = generate_initialized_accessors.then(|| {
+ quote! {
+ #(#cfgs)*
+ #[allow(unused_variables)]
+ let #ident = #accessor;
+ }
+ });
+ quote! {
+ #(#attrs)*
+ {
+ let #init = #value;
+ #value_init
+ }
+ #accessor
+ }
+ }
+ InitializerKind::Code { block: value, .. } => quote! {
+ #(#attrs)*
+ #[allow(unused_braces)]
+ #value
+ },
+ };
+ res.extend(init);
+ if let Some(ident) = kind.ident() {
+ // `mixed_site` ensures that the guard is not accessible to the user-controlled code.
+ let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
+ res.extend(quote! {
+ #(#cfgs)*
+ // Create the drop guard:
+ //
+ // We rely on macro hygiene to make it impossible for users to access this local
+ // variable.
+ // SAFETY: We forget the guard later when initialization has succeeded.
+ let #guard = unsafe {
+ ::pin_init::__internal::DropGuard::new(
+ ::core::ptr::addr_of_mut!((*slot).#ident)
+ )
+ };
+ });
+ guards.push(guard);
+ guard_attrs.push(cfgs);
+ }
+ }
+ quote! {
+ #res
+ // If execution reaches this point, all fields have been initialized. Therefore we can now
+ // dismiss the guards by forgetting them.
+ #(
+ #(#guard_attrs)*
+ ::core::mem::forget(#guards);
+ )*
+ }
+}
+
+/// Generate the check for ensuring that every field has been initialized.
+fn make_field_check(
+ fields: &Punctuated<InitializerField, Token![,]>,
+ init_kind: InitKind,
+ path: &Path,
+) -> TokenStream {
+ let field_attrs = fields
+ .iter()
+ .filter_map(|f| f.kind.ident().map(|_| &f.attrs));
+ let field_name = fields.iter().filter_map(|f| f.kind.ident());
+ match init_kind {
+ InitKind::Normal => quote! {
+ // We use unreachable code to ensure that all fields have been mentioned exactly once,
+ // this struct initializer will still be type-checked and complain with a very natural
+ // error message if a field is forgotten/mentioned more than once.
+ #[allow(unreachable_code, clippy::diverging_sub_expression)]
+ // SAFETY: this code is never executed.
+ let _ = || unsafe {
+ ::core::ptr::write(slot, #path {
+ #(
+ #(#field_attrs)*
+ #field_name: ::core::panic!(),
+ )*
+ })
+ };
+ },
+ InitKind::Zeroing => quote! {
+ // We use unreachable code to ensure that all fields have been mentioned at most once.
+ // Since the user specified `..Zeroable::zeroed()` at the end, all missing fields will
+ // be zeroed. This struct initializer will still be type-checked and complain with a
+ // very natural error message if a field is mentioned more than once, or doesn't exist.
+ #[allow(unreachable_code, clippy::diverging_sub_expression, unused_assignments)]
+ // SAFETY: this code is never executed.
+ let _ = || unsafe {
+ ::core::ptr::write(slot, #path {
+ #(
+ #(#field_attrs)*
+ #field_name: ::core::panic!(),
+ )*
+ ..::core::mem::zeroed()
+ })
+ };
+ },
+ }
+}
+
+impl Parse for Initializer {
+ fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
+ let attrs = input.call(Attribute::parse_outer)?;
+ let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
+ let path = input.parse()?;
+ let content;
+ let brace_token = braced!(content in input);
+ let mut fields = Punctuated::new();
+ loop {
+ let lh = content.lookahead1();
+ if lh.peek(End) || lh.peek(Token![..]) {
+ break;
+ } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) {
+ fields.push_value(content.parse()?);
+ let lh = content.lookahead1();
+ if lh.peek(End) {
+ break;
+ } else if lh.peek(Token![,]) {
+ fields.push_punct(content.parse()?);
+ } else {
+ return Err(lh.error());
+ }
+ } else {
+ return Err(lh.error());
+ }
+ }
+ let rest = content
+ .peek(Token![..])
+ .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
+ .transpose()?;
+ let error = input
+ .peek(Token![?])
+ .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
+ .transpose()?;
+ let attrs = attrs
+ .into_iter()
+ .map(|a| {
+ if a.path().is_ident("default_error") {
+ a.parse_args::<DefaultErrorAttribute>()
+ .map(InitializerAttribute::DefaultError)
+ } else if a.path().is_ident("disable_initialized_field_access") {
+ a.meta
+ .require_path_only()
+ .map(|_| InitializerAttribute::DisableInitializedFieldAccess)
+ } else {
+ Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
+ }
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ Ok(Self {
+ attrs,
+ this,
+ path,
+ brace_token,
+ fields,
+ rest,
+ error,
+ })
+ }
+}
+
+impl Parse for DefaultErrorAttribute {
+ fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
+ Ok(Self { ty: input.parse()? })
+ }
+}
+
+impl Parse for This {
+ fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
+ Ok(Self {
+ _and_token: input.parse()?,
+ ident: input.parse()?,
+ _in_token: input.parse()?,
+ })
+ }
+}
+
+impl Parse for InitializerField {
+ fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
+ let attrs = input.call(Attribute::parse_outer)?;
+ Ok(Self {
+ attrs,
+ kind: input.parse()?,
+ })
+ }
+}
+
+impl Parse for InitializerKind {
+ fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
+ let lh = input.lookahead1();
+ if lh.peek(Token![_]) {
+ Ok(Self::Code {
+ _underscore_token: input.parse()?,
+ _colon_token: input.parse()?,
+ block: input.parse()?,
+ })
+ } else if lh.peek(Ident) {
+ let ident = input.parse()?;
+ let lh = input.lookahead1();
+ if lh.peek(Token![<-]) {
+ Ok(Self::Init {
+ ident,
+ _left_arrow_token: input.parse()?,
+ value: input.parse()?,
+ })
+ } else if lh.peek(Token![:]) {
+ Ok(Self::Value {
+ ident,
+ value: Some((input.parse()?, input.parse()?)),
+ })
+ } else if lh.peek(Token![,]) || lh.peek(End) {
+ Ok(Self::Value { ident, value: None })
+ } else {
+ Err(lh.error())
+ }
+ } else {
+ Err(lh.error())
+ }
+ }
+}
diff --git a/rust/pin-init/internal/src/lib.rs b/rust/pin-init/internal/src/lib.rs
index 297b0129a5bf..08372c8f65f0 100644
--- a/rust/pin-init/internal/src/lib.rs
+++ b/rust/pin-init/internal/src/lib.rs
@@ -7,48 +7,54 @@
//! `pin-init` proc macros.
#![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
-// Allow `.into()` to convert
-// - `proc_macro2::TokenStream` into `proc_macro::TokenStream` in the user-space version.
-// - `proc_macro::TokenStream` into `proc_macro::TokenStream` in the kernel version.
-// Clippy warns on this conversion, but it's required by the user-space version.
-//
-// Remove once we have `proc_macro2` in the kernel.
-#![allow(clippy::useless_conversion)]
// Documentation is done in the pin-init crate instead.
#![allow(missing_docs)]
use proc_macro::TokenStream;
+use syn::parse_macro_input;
-#[cfg(kernel)]
-#[path = "../../../macros/quote.rs"]
-#[macro_use]
-#[cfg_attr(not(kernel), rustfmt::skip)]
-mod quote;
-#[cfg(not(kernel))]
-#[macro_use]
-extern crate quote;
+use crate::diagnostics::DiagCtxt;
-mod helpers;
+mod diagnostics;
+mod init;
mod pin_data;
mod pinned_drop;
mod zeroable;
#[proc_macro_attribute]
-pub fn pin_data(inner: TokenStream, item: TokenStream) -> TokenStream {
- pin_data::pin_data(inner.into(), item.into()).into()
+pub fn pin_data(args: TokenStream, input: TokenStream) -> TokenStream {
+ let args = parse_macro_input!(args);
+ let input = parse_macro_input!(input);
+ DiagCtxt::with(|dcx| pin_data::pin_data(args, input, dcx)).into()
}
#[proc_macro_attribute]
pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream {
- pinned_drop::pinned_drop(args.into(), input.into()).into()
+ let args = parse_macro_input!(args);
+ let input = parse_macro_input!(input);
+ DiagCtxt::with(|dcx| pinned_drop::pinned_drop(args, input, dcx)).into()
}
#[proc_macro_derive(Zeroable)]
pub fn derive_zeroable(input: TokenStream) -> TokenStream {
- zeroable::derive(input.into()).into()
+ let input = parse_macro_input!(input);
+ DiagCtxt::with(|dcx| zeroable::derive(input, dcx)).into()
}
#[proc_macro_derive(MaybeZeroable)]
pub fn maybe_derive_zeroable(input: TokenStream) -> TokenStream {
- zeroable::maybe_derive(input.into()).into()
+ let input = parse_macro_input!(input);
+ DiagCtxt::with(|dcx| zeroable::maybe_derive(input, dcx)).into()
+}
+#[proc_macro]
+pub fn init(input: TokenStream) -> TokenStream {
+ let input = parse_macro_input!(input);
+ DiagCtxt::with(|dcx| init::expand(input, Some("::core::convert::Infallible"), false, dcx))
+ .into()
+}
+
+#[proc_macro]
+pub fn pin_init(input: TokenStream) -> TokenStream {
+ let input = parse_macro_input!(input);
+ DiagCtxt::with(|dcx| init::expand(input, Some("::core::convert::Infallible"), true, dcx)).into()
}
diff --git a/rust/pin-init/internal/src/pin_data.rs b/rust/pin-init/internal/src/pin_data.rs
index 87d4a7eb1d35..7d871236b49c 100644
--- a/rust/pin-init/internal/src/pin_data.rs
+++ b/rust/pin-init/internal/src/pin_data.rs
@@ -1,132 +1,513 @@
// SPDX-License-Identifier: Apache-2.0 OR MIT
-#[cfg(not(kernel))]
-use proc_macro2 as proc_macro;
+use proc_macro2::TokenStream;
+use quote::{format_ident, quote};
+use syn::{
+ parse::{End, Nothing, Parse},
+ parse_quote, parse_quote_spanned,
+ spanned::Spanned,
+ visit_mut::VisitMut,
+ Field, Generics, Ident, Item, PathSegment, Type, TypePath, Visibility, WhereClause,
+};
-use crate::helpers::{parse_generics, Generics};
-use proc_macro::{Group, Punct, Spacing, TokenStream, TokenTree};
+use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
-pub(crate) fn pin_data(args: TokenStream, input: TokenStream) -> TokenStream {
- // This proc-macro only does some pre-parsing and then delegates the actual parsing to
- // `pin_init::__pin_data!`.
+pub(crate) mod kw {
+ syn::custom_keyword!(PinnedDrop);
+}
+
+pub(crate) enum Args {
+ Nothing(Nothing),
+ #[allow(dead_code)]
+ PinnedDrop(kw::PinnedDrop),
+}
+
+impl Parse for Args {
+ fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
+ let lh = input.lookahead1();
+ if lh.peek(End) {
+ input.parse().map(Self::Nothing)
+ } else if lh.peek(kw::PinnedDrop) {
+ input.parse().map(Self::PinnedDrop)
+ } else {
+ Err(lh.error())
+ }
+ }
+}
+
+pub(crate) fn pin_data(
+ args: Args,
+ input: Item,
+ dcx: &mut DiagCtxt,
+) -> Result<TokenStream, ErrorGuaranteed> {
+ let mut struct_ = match input {
+ Item::Struct(struct_) => struct_,
+ Item::Enum(enum_) => {
+ return Err(dcx.error(
+ enum_.enum_token,
+ "`#[pin_data]` only supports structs for now",
+ ));
+ }
+ Item::Union(union) => {
+ return Err(dcx.error(
+ union.union_token,
+ "`#[pin_data]` only supports structs for now",
+ ));
+ }
+ rest => {
+ return Err(dcx.error(
+ rest,
+ "`#[pin_data]` can only be applied to struct, enum and union definitions",
+ ));
+ }
+ };
+
+ // The generics might contain the `Self` type. Since this macro will define a new type with the
+ // same generics and bounds, this poses a problem: `Self` will refer to the new type as opposed
+ // to this struct definition. Therefore we have to replace `Self` with the concrete name.
+ let mut replacer = {
+ let name = &struct_.ident;
+ let (_, ty_generics, _) = struct_.generics.split_for_impl();
+ SelfReplacer(parse_quote!(#name #ty_generics))
+ };
+ replacer.visit_generics_mut(&mut struct_.generics);
+ replacer.visit_fields_mut(&mut struct_.fields);
+
+ let fields: Vec<(bool, &Field)> = struct_
+ .fields
+ .iter_mut()
+ .map(|field| {
+ let len = field.attrs.len();
+ field.attrs.retain(|a| !a.path().is_ident("pin"));
+ (len != field.attrs.len(), &*field)
+ })
+ .collect();
+
+ for (pinned, field) in &fields {
+ if !pinned && is_phantom_pinned(&field.ty) {
+ dcx.error(
+ field,
+ format!(
+ "The field `{}` of type `PhantomPinned` only has an effect \
+ if it has the `#[pin]` attribute",
+ field.ident.as_ref().unwrap(),
+ ),
+ );
+ }
+ }
+
+ let unpin_impl = generate_unpin_impl(&struct_.ident, &struct_.generics, &fields);
+ let drop_impl = generate_drop_impl(&struct_.ident, &struct_.generics, args);
+ let projections =
+ generate_projections(&struct_.vis, &struct_.ident, &struct_.generics, &fields);
+ let the_pin_data =
+ generate_the_pin_data(&struct_.vis, &struct_.ident, &struct_.generics, &fields);
+
+ Ok(quote! {
+ #struct_
+ #projections
+ // We put the rest into this const item, because it then will not be accessible to anything
+ // outside.
+ const _: () = {
+ #the_pin_data
+ #unpin_impl
+ #drop_impl
+ };
+ })
+}
+
+fn is_phantom_pinned(ty: &Type) -> bool {
+ match ty {
+ Type::Path(TypePath { qself: None, path }) => {
+ // Cannot possibly refer to `PhantomPinned` (except alias, but that's on the user).
+ if path.segments.len() > 3 {
+ return false;
+ }
+ // If there is a `::`, then the path needs to be `::core::marker::PhantomPinned` or
+ // `::std::marker::PhantomPinned`.
+ if path.leading_colon.is_some() && path.segments.len() != 3 {
+ return false;
+ }
+ let expected: Vec<&[&str]> = vec![&["PhantomPinned"], &["marker"], &["core", "std"]];
+ for (actual, expected) in path.segments.iter().rev().zip(expected) {
+ if !actual.arguments.is_empty() || expected.iter().all(|e| actual.ident != e) {
+ return false;
+ }
+ }
+ true
+ }
+ _ => false,
+ }
+}
+fn generate_unpin_impl(
+ ident: &Ident,
+ generics: &Generics,
+ fields: &[(bool, &Field)],
+) -> TokenStream {
+ let (_, ty_generics, _) = generics.split_for_impl();
+ let mut generics_with_pin_lt = generics.clone();
+ generics_with_pin_lt.params.insert(0, parse_quote!('__pin));
+ generics_with_pin_lt.make_where_clause();
let (
- Generics {
- impl_generics,
- decl_generics,
- ty_generics,
- },
- rest,
- ) = parse_generics(input);
- // The struct definition might contain the `Self` type. Since `__pin_data!` will define a new
- // type with the same generics and bounds, this poses a problem, since `Self` will refer to the
- // new type as opposed to this struct definition. Therefore we have to replace `Self` with the
- // concrete name.
-
- // Errors that occur when replacing `Self` with `struct_name`.
- let mut errs = TokenStream::new();
- // The name of the struct with ty_generics.
- let struct_name = rest
- .iter()
- .skip_while(|tt| !matches!(tt, TokenTree::Ident(i) if i.to_string() == "struct"))
- .nth(1)
- .and_then(|tt| match tt {
- TokenTree::Ident(_) => {
- let tt = tt.clone();
- let mut res = vec![tt];
- if !ty_generics.is_empty() {
- // We add this, so it is maximally compatible with e.g. `Self::CONST` which
- // will be replaced by `StructName::<$generics>::CONST`.
- res.push(TokenTree::Punct(Punct::new(':', Spacing::Joint)));
- res.push(TokenTree::Punct(Punct::new(':', Spacing::Alone)));
- res.push(TokenTree::Punct(Punct::new('<', Spacing::Alone)));
- res.extend(ty_generics.iter().cloned());
- res.push(TokenTree::Punct(Punct::new('>', Spacing::Alone)));
+ impl_generics_with_pin_lt,
+ ty_generics_with_pin_lt,
+ Some(WhereClause {
+ where_token,
+ predicates,
+ }),
+ ) = generics_with_pin_lt.split_for_impl()
+ else {
+ unreachable!()
+ };
+ let pinned_fields = fields.iter().filter_map(|(b, f)| b.then_some(f));
+ quote! {
+ // This struct will be used for the unpin analysis. It is needed, because only structurally
+ // pinned fields are relevant whether the struct should implement `Unpin`.
+ #[allow(dead_code)] // The fields below are never used.
+ struct __Unpin #generics_with_pin_lt
+ #where_token
+ #predicates
+ {
+ __phantom_pin: ::core::marker::PhantomData<fn(&'__pin ()) -> &'__pin ()>,
+ __phantom: ::core::marker::PhantomData<
+ fn(#ident #ty_generics) -> #ident #ty_generics
+ >,
+ #(#pinned_fields),*
+ }
+
+ #[doc(hidden)]
+ impl #impl_generics_with_pin_lt ::core::marker::Unpin for #ident #ty_generics
+ #where_token
+ __Unpin #ty_generics_with_pin_lt: ::core::marker::Unpin,
+ #predicates
+ {}
+ }
+}
+
+fn generate_drop_impl(ident: &Ident, generics: &Generics, args: Args) -> TokenStream {
+ let (impl_generics, ty_generics, whr) = generics.split_for_impl();
+ let has_pinned_drop = matches!(args, Args::PinnedDrop(_));
+ // We need to disallow normal `Drop` implementation, the exact behavior depends on whether
+ // `PinnedDrop` was specified in `args`.
+ if has_pinned_drop {
+ // When `PinnedDrop` was specified we just implement `Drop` and delegate.
+ quote! {
+ impl #impl_generics ::core::ops::Drop for #ident #ty_generics
+ #whr
+ {
+ fn drop(&mut self) {
+ // SAFETY: Since this is a destructor, `self` will not move after this function
+ // terminates, since it is inaccessible.
+ let pinned = unsafe { ::core::pin::Pin::new_unchecked(self) };
+ // SAFETY: Since this is a drop function, we can create this token to call the
+ // pinned destructor of this type.
+ let token = unsafe { ::pin_init::__internal::OnlyCallFromDrop::new() };
+ ::pin_init::PinnedDrop::drop(pinned, token);
}
- Some(res)
}
- _ => None,
- })
- .unwrap_or_else(|| {
- // If we did not find the name of the struct then we will use `Self` as the replacement
- // and add a compile error to ensure it does not compile.
- errs.extend(
- "::core::compile_error!(\"Could not locate type name.\");"
- .parse::<TokenStream>()
- .unwrap(),
- );
- "Self".parse::<TokenStream>().unwrap().into_iter().collect()
- });
- let impl_generics = impl_generics
- .into_iter()
- .flat_map(|tt| replace_self_and_deny_type_defs(&struct_name, tt, &mut errs))
- .collect::<Vec<_>>();
- let mut rest = rest
- .into_iter()
- .flat_map(|tt| {
- // We ignore top level `struct` tokens, since they would emit a compile error.
- if matches!(&tt, TokenTree::Ident(i) if i.to_string() == "struct") {
- vec![tt]
+ }
+ } else {
+ // When no `PinnedDrop` was specified, then we have to prevent implementing drop.
+ quote! {
+ // We prevent this by creating a trait that will be implemented for all types implementing
+ // `Drop`. Additionally we will implement this trait for the struct leading to a conflict,
+ // if it also implements `Drop`
+ trait MustNotImplDrop {}
+ #[expect(drop_bounds)]
+ impl<T: ::core::ops::Drop + ?::core::marker::Sized> MustNotImplDrop for T {}
+ impl #impl_generics MustNotImplDrop for #ident #ty_generics
+ #whr
+ {}
+ // We also take care to prevent users from writing a useless `PinnedDrop` implementation.
+ // They might implement `PinnedDrop` correctly for the struct, but forget to give
+ // `PinnedDrop` as the parameter to `#[pin_data]`.
+ #[expect(non_camel_case_types)]
+ trait UselessPinnedDropImpl_you_need_to_specify_PinnedDrop {}
+ impl<T: ::pin_init::PinnedDrop + ?::core::marker::Sized>
+ UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for T {}
+ impl #impl_generics
+ UselessPinnedDropImpl_you_need_to_specify_PinnedDrop for #ident #ty_generics
+ #whr
+ {}
+ }
+ }
+}
+
+fn generate_projections(
+ vis: &Visibility,
+ ident: &Ident,
+ generics: &Generics,
+ fields: &[(bool, &Field)],
+) -> TokenStream {
+ let (impl_generics, ty_generics, _) = generics.split_for_impl();
+ let mut generics_with_pin_lt = generics.clone();
+ generics_with_pin_lt.params.insert(0, parse_quote!('__pin));
+ let (_, ty_generics_with_pin_lt, whr) = generics_with_pin_lt.split_for_impl();
+ let projection = format_ident!("{ident}Projection");
+ let this = format_ident!("this");
+
+ let (fields_decl, fields_proj) = collect_tuple(fields.iter().map(
+ |(
+ pinned,
+ Field {
+ vis,
+ ident,
+ ty,
+ attrs,
+ ..
+ },
+ )| {
+ let mut attrs = attrs.clone();
+ attrs.retain(|a| !a.path().is_ident("pin"));
+ let mut no_doc_attrs = attrs.clone();
+ no_doc_attrs.retain(|a| !a.path().is_ident("doc"));
+ let ident = ident
+ .as_ref()
+ .expect("only structs with named fields are supported");
+ if *pinned {
+ (
+ quote!(
+ #(#attrs)*
+ #vis #ident: ::core::pin::Pin<&'__pin mut #ty>,
+ ),
+ quote!(
+ #(#no_doc_attrs)*
+ // SAFETY: this field is structurally pinned.
+ #ident: unsafe { ::core::pin::Pin::new_unchecked(&mut #this.#ident) },
+ ),
+ )
} else {
- replace_self_and_deny_type_defs(&struct_name, tt, &mut errs)
+ (
+ quote!(
+ #(#attrs)*
+ #vis #ident: &'__pin mut #ty,
+ ),
+ quote!(
+ #(#no_doc_attrs)*
+ #ident: &mut #this.#ident,
+ ),
+ )
}
- })
- .collect::<Vec<_>>();
- // This should be the body of the struct `{...}`.
- let last = rest.pop();
- let mut quoted = quote!(::pin_init::__pin_data! {
- parse_input:
- @args(#args),
- @sig(#(#rest)*),
- @impl_generics(#(#impl_generics)*),
- @ty_generics(#(#ty_generics)*),
- @decl_generics(#(#decl_generics)*),
- @body(#last),
- });
- quoted.extend(errs);
- quoted
+ },
+ ));
+ let structurally_pinned_fields_docs = fields
+ .iter()
+ .filter_map(|(pinned, field)| pinned.then_some(field))
+ .map(|Field { ident, .. }| format!(" - `{}`", ident.as_ref().unwrap()));
+ let not_structurally_pinned_fields_docs = fields
+ .iter()
+ .filter_map(|(pinned, field)| (!pinned).then_some(field))
+ .map(|Field { ident, .. }| format!(" - `{}`", ident.as_ref().unwrap()));
+ let docs = format!(" Pin-projections of [`{ident}`]");
+ quote! {
+ #[doc = #docs]
+ #[allow(dead_code)]
+ #[doc(hidden)]
+ #vis struct #projection #generics_with_pin_lt {
+ #(#fields_decl)*
+ ___pin_phantom_data: ::core::marker::PhantomData<&'__pin mut ()>,
+ }
+
+ impl #impl_generics #ident #ty_generics
+ #whr
+ {
+ /// Pin-projects all fields of `Self`.
+ ///
+ /// These fields are structurally pinned:
+ #(#[doc = #structurally_pinned_fields_docs])*
+ ///
+ /// These fields are **not** structurally pinned:
+ #(#[doc = #not_structurally_pinned_fields_docs])*
+ #[inline]
+ #vis fn project<'__pin>(
+ self: ::core::pin::Pin<&'__pin mut Self>,
+ ) -> #projection #ty_generics_with_pin_lt {
+ // SAFETY: we only give access to `&mut` for fields not structurally pinned.
+ let #this = unsafe { ::core::pin::Pin::get_unchecked_mut(self) };
+ #projection {
+ #(#fields_proj)*
+ ___pin_phantom_data: ::core::marker::PhantomData,
+ }
+ }
+ }
+ }
}
-/// Replaces `Self` with `struct_name` and errors on `enum`, `trait`, `struct` `union` and `impl`
-/// keywords.
-///
-/// The error is appended to `errs` to allow normal parsing to continue.
-fn replace_self_and_deny_type_defs(
- struct_name: &Vec<TokenTree>,
- tt: TokenTree,
- errs: &mut TokenStream,
-) -> Vec<TokenTree> {
- match tt {
- TokenTree::Ident(ref i)
- if i.to_string() == "enum"
- || i.to_string() == "trait"
- || i.to_string() == "struct"
- || i.to_string() == "union"
- || i.to_string() == "impl" =>
+fn generate_the_pin_data(
+ vis: &Visibility,
+ ident: &Ident,
+ generics: &Generics,
+ fields: &[(bool, &Field)],
+) -> TokenStream {
+ let (impl_generics, ty_generics, whr) = generics.split_for_impl();
+
+ // For every field, we create an initializing projection function according to its projection
+ // type. If a field is structurally pinned, then it must be initialized via `PinInit`, if it is
+ // not structurally pinned, then it can be initialized via `Init`.
+ //
+ // The functions are `unsafe` to prevent accidentally calling them.
+ fn handle_field(
+ Field {
+ vis,
+ ident,
+ ty,
+ attrs,
+ ..
+ }: &Field,
+ struct_ident: &Ident,
+ pinned: bool,
+ ) -> TokenStream {
+ let mut attrs = attrs.clone();
+ attrs.retain(|a| !a.path().is_ident("pin"));
+ let ident = ident
+ .as_ref()
+ .expect("only structs with named fields are supported");
+ let project_ident = format_ident!("__project_{ident}");
+ let (init_ty, init_fn, project_ty, project_body, pin_safety) = if pinned {
+ (
+ quote!(PinInit),
+ quote!(__pinned_init),
+ quote!(::core::pin::Pin<&'__slot mut #ty>),
+ // SAFETY: this field is structurally pinned.
+ quote!(unsafe { ::core::pin::Pin::new_unchecked(slot) }),
+ quote!(
+ /// - `slot` will not move until it is dropped, i.e. it will be pinned.
+ ),
+ )
+ } else {
+ (
+ quote!(Init),
+ quote!(__init),
+ quote!(&'__slot mut #ty),
+ quote!(slot),
+ quote!(),
+ )
+ };
+ let slot_safety = format!(
+ " `slot` points at the field `{ident}` inside of `{struct_ident}`, which is pinned.",
+ );
+ quote! {
+ /// # Safety
+ ///
+ /// - `slot` is a valid pointer to uninitialized memory.
+ /// - the caller does not touch `slot` when `Err` is returned, they are only permitted
+ /// to deallocate.
+ #pin_safety
+ #(#attrs)*
+ #vis unsafe fn #ident<E>(
+ self,
+ slot: *mut #ty,
+ init: impl ::pin_init::#init_ty<#ty, E>,
+ ) -> ::core::result::Result<(), E> {
+ // SAFETY: this function has the same safety requirements as the __init function
+ // called below.
+ unsafe { ::pin_init::#init_ty::#init_fn(init, slot) }
+ }
+
+ /// # Safety
+ ///
+ #[doc = #slot_safety]
+ #(#attrs)*
+ #vis unsafe fn #project_ident<'__slot>(
+ self,
+ slot: &'__slot mut #ty,
+ ) -> #project_ty {
+ #project_body
+ }
+ }
+ }
+
+ let field_accessors = fields
+ .iter()
+ .map(|(pinned, field)| handle_field(field, ident, *pinned))
+ .collect::<TokenStream>();
+ quote! {
+ // We declare this struct which will host all of the projection function for our type. It
+ // will be invariant over all generic parameters which are inherited from the struct.
+ #[doc(hidden)]
+ #vis struct __ThePinData #generics
+ #whr
{
- errs.extend(
- format!(
- "::core::compile_error!(\"Cannot use `{i}` inside of struct definition with \
- `#[pin_data]`.\");"
- )
- .parse::<TokenStream>()
- .unwrap()
- .into_iter()
- .map(|mut tok| {
- tok.set_span(tt.span());
- tok
- }),
- );
- vec![tt]
- }
- TokenTree::Ident(i) if i.to_string() == "Self" => struct_name.clone(),
- TokenTree::Literal(_) | TokenTree::Punct(_) | TokenTree::Ident(_) => vec![tt],
- TokenTree::Group(g) => vec![TokenTree::Group(Group::new(
- g.delimiter(),
- g.stream()
- .into_iter()
- .flat_map(|tt| replace_self_and_deny_type_defs(struct_name, tt, errs))
- .collect(),
- ))],
+ __phantom: ::core::marker::PhantomData<
+ fn(#ident #ty_generics) -> #ident #ty_generics
+ >,
+ }
+
+ impl #impl_generics ::core::clone::Clone for __ThePinData #ty_generics
+ #whr
+ {
+ fn clone(&self) -> Self { *self }
+ }
+
+ impl #impl_generics ::core::marker::Copy for __ThePinData #ty_generics
+ #whr
+ {}
+
+ #[allow(dead_code)] // Some functions might never be used and private.
+ #[expect(clippy::missing_safety_doc)]
+ impl #impl_generics __ThePinData #ty_generics
+ #whr
+ {
+ #field_accessors
+ }
+
+ // SAFETY: We have added the correct projection functions above to `__ThePinData` and
+ // we also use the least restrictive generics possible.
+ unsafe impl #impl_generics ::pin_init::__internal::HasPinData for #ident #ty_generics
+ #whr
+ {
+ type PinData = __ThePinData #ty_generics;
+
+ unsafe fn __pin_data() -> Self::PinData {
+ __ThePinData { __phantom: ::core::marker::PhantomData }
+ }
+ }
+
+ // SAFETY: TODO
+ unsafe impl #impl_generics ::pin_init::__internal::PinData for __ThePinData #ty_generics
+ #whr
+ {
+ type Datee = #ident #ty_generics;
+ }
+ }
+}
+
+struct SelfReplacer(PathSegment);
+
+impl VisitMut for SelfReplacer {
+ fn visit_path_mut(&mut self, i: &mut syn::Path) {
+ if i.is_ident("Self") {
+ let span = i.span();
+ let seg = &self.0;
+ *i = parse_quote_spanned!(span=> #seg);
+ } else {
+ syn::visit_mut::visit_path_mut(self, i);
+ }
+ }
+
+ fn visit_path_segment_mut(&mut self, seg: &mut PathSegment) {
+ if seg.ident == "Self" {
+ let span = seg.span();
+ let this = &self.0;
+ *seg = parse_quote_spanned!(span=> #this);
+ } else {
+ syn::visit_mut::visit_path_segment_mut(self, seg);
+ }
+ }
+
+ fn visit_item_mut(&mut self, _: &mut Item) {
+ // Do not descend into items, since items reset/change what `Self` refers to.
+ }
+}
+
+// replace with `.collect()` once MSRV is above 1.79
+fn collect_tuple<A, B>(iter: impl Iterator<Item = (A, B)>) -> (Vec<A>, Vec<B>) {
+ let mut res_a = vec![];
+ let mut res_b = vec![];
+ for (a, b) in iter {
+ res_a.push(a);
+ res_b.push(b);
}
+ (res_a, res_b)
}
diff --git a/rust/pin-init/internal/src/pinned_drop.rs b/rust/pin-init/internal/src/pinned_drop.rs
index c4ca7a70b726..a20ac314ca82 100644
--- a/rust/pin-init/internal/src/pinned_drop.rs
+++ b/rust/pin-init/internal/src/pinned_drop.rs
@@ -1,51 +1,61 @@
// SPDX-License-Identifier: Apache-2.0 OR MIT
-#[cfg(not(kernel))]
-use proc_macro2 as proc_macro;
+use proc_macro2::TokenStream;
+use quote::quote;
+use syn::{parse::Nothing, parse_quote, spanned::Spanned, ImplItem, ItemImpl, Token};
-use proc_macro::{TokenStream, TokenTree};
+use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};
-pub(crate) fn pinned_drop(_args: TokenStream, input: TokenStream) -> TokenStream {
- let mut toks = input.into_iter().collect::<Vec<_>>();
- assert!(!toks.is_empty());
- // Ensure that we have an `impl` item.
- assert!(matches!(&toks[0], TokenTree::Ident(i) if i.to_string() == "impl"));
- // Ensure that we are implementing `PinnedDrop`.
- let mut nesting: usize = 0;
- let mut pinned_drop_idx = None;
- for (i, tt) in toks.iter().enumerate() {
- match tt {
- TokenTree::Punct(p) if p.as_char() == '<' => {
- nesting += 1;
+pub(crate) fn pinned_drop(
+ _args: Nothing,
+ mut input: ItemImpl,
+ dcx: &mut DiagCtxt,
+) -> Result<TokenStream, ErrorGuaranteed> {
+ if let Some(unsafety) = input.unsafety {
+ dcx.error(unsafety, "implementing `PinnedDrop` is safe");
+ }
+ input.unsafety = Some(Token![unsafe](input.impl_token.span));
+ match &mut input.trait_ {
+ Some((not, path, _for)) => {
+ if let Some(not) = not {
+ dcx.error(not, "cannot implement `!PinnedDrop`");
}
- TokenTree::Punct(p) if p.as_char() == '>' => {
- nesting = nesting.checked_sub(1).unwrap();
- continue;
+ for (seg, expected) in path
+ .segments
+ .iter()
+ .rev()
+ .zip(["PinnedDrop", "pin_init", ""])
+ {
+ if expected.is_empty() || seg.ident != expected {
+ dcx.error(seg, "bad import path for `PinnedDrop`");
+ }
+ if !seg.arguments.is_none() {
+ dcx.error(&seg.arguments, "unexpected arguments for `PinnedDrop` path");
+ }
}
- _ => {}
+ *path = parse_quote!(::pin_init::PinnedDrop);
}
- if i >= 1 && nesting == 0 {
- // Found the end of the generics, this should be `PinnedDrop`.
- assert!(
- matches!(tt, TokenTree::Ident(i) if i.to_string() == "PinnedDrop"),
- "expected 'PinnedDrop', found: '{tt:?}'"
+ None => {
+ let span = input
+ .impl_token
+ .span
+ .join(input.self_ty.span())
+ .unwrap_or(input.impl_token.span);
+ dcx.error(
+ span,
+ "expected `impl ... PinnedDrop for ...`, got inherent impl",
);
- pinned_drop_idx = Some(i);
- break;
}
}
- let idx = pinned_drop_idx
- .unwrap_or_else(|| panic!("Expected an `impl` block implementing `PinnedDrop`."));
- // Fully qualify the `PinnedDrop`, as to avoid any tampering.
- toks.splice(idx..idx, quote!(::pin_init::));
- // Take the `{}` body and call the declarative macro.
- if let Some(TokenTree::Group(last)) = toks.pop() {
- let last = last.stream();
- quote!(::pin_init::__pinned_drop! {
- @impl_sig(#(#toks)*),
- @impl_body(#last),
- })
- } else {
- TokenStream::from_iter(toks)
+ for item in &mut input.items {
+ if let ImplItem::Fn(fn_item) = item {
+ if fn_item.sig.ident == "drop" {
+ fn_item
+ .sig
+ .inputs
+ .push(parse_quote!(_: ::pin_init::__internal::OnlyCallFromDrop));
+ }
+ }
}
+ Ok(quote!(#input))
}
diff --git a/rust/pin-init/internal/src/zeroable.rs b/rust/pin-init/internal/src/zeroable.rs
index e0ed3998445c..05683319b0f7 100644
--- a/rust/pin-init/internal/src/zeroable.rs
+++ b/rust/pin-init/internal/src/zeroable.rs
@@ -1,101 +1,78 @@
// SPDX-License-Identifier: GPL-2.0
-#[cfg(not(kernel))]
-use proc_macro2 as proc_macro;
+use proc_macro2::TokenStream;
+use quote::quote;
+use syn::{parse_quote, Data, DeriveInput, Field, Fields};
-use crate::helpers::{parse_generics, Generics};
-use proc_macro::{TokenStream, TokenTree};
+use crate::{diagnostics::ErrorGuaranteed, DiagCtxt};
-pub(crate) fn parse_zeroable_derive_input(
- input: TokenStream,
-) -> (
- Vec<TokenTree>,
- Vec<TokenTree>,
- Vec<TokenTree>,
- Option<TokenTree>,
-) {
- let (
- Generics {
- impl_generics,
- decl_generics: _,
- ty_generics,
- },
- mut rest,
- ) = parse_generics(input);
- // This should be the body of the struct `{...}`.
- let last = rest.pop();
- // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`.
- let mut new_impl_generics = Vec::with_capacity(impl_generics.len());
- // Are we inside of a generic where we want to add `Zeroable`?
- let mut in_generic = !impl_generics.is_empty();
- // Have we already inserted `Zeroable`?
- let mut inserted = false;
- // Level of `<>` nestings.
- let mut nested = 0;
- for tt in impl_generics {
- match &tt {
- // If we find a `,`, then we have finished a generic/constant/lifetime parameter.
- TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => {
- if in_generic && !inserted {
- new_impl_generics.extend(quote! { : ::pin_init::Zeroable });
- }
- in_generic = true;
- inserted = false;
- new_impl_generics.push(tt);
- }
- // If we find `'`, then we are entering a lifetime.
- TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => {
- in_generic = false;
- new_impl_generics.push(tt);
- }
- TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => {
- new_impl_generics.push(tt);
- if in_generic {
- new_impl_generics.extend(quote! { ::pin_init::Zeroable + });
- inserted = true;
- }
- }
- TokenTree::Punct(p) if p.as_char() == '<' => {
- nested += 1;
- new_impl_generics.push(tt);
- }
- TokenTree::Punct(p) if p.as_char() == '>' => {
- assert!(nested > 0);
- nested -= 1;
- new_impl_generics.push(tt);
- }
- _ => new_impl_generics.push(tt),
+pub(crate) fn derive(
+ input: DeriveInput,
+ dcx: &mut DiagCtxt,
+) -> Result<TokenStream, ErrorGuaranteed> {
+ let fields = match input.data {
+ Data::Struct(data_struct) => data_struct.fields,
+ Data::Union(data_union) => Fields::Named(data_union.fields),
+ Data::Enum(data_enum) => {
+ return Err(dcx.error(data_enum.enum_token, "cannot derive `Zeroable` for an enum"));
}
+ };
+ let name = input.ident;
+ let mut generics = input.generics;
+ for param in generics.type_params_mut() {
+ param.bounds.insert(0, parse_quote!(::pin_init::Zeroable));
}
- assert_eq!(nested, 0);
- if in_generic && !inserted {
- new_impl_generics.extend(quote! { : ::pin_init::Zeroable });
- }
- (rest, new_impl_generics, ty_generics, last)
+ let (impl_gen, ty_gen, whr) = generics.split_for_impl();
+ let field_type = fields.iter().map(|field| &field.ty);
+ Ok(quote! {
+ // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero.
+ #[automatically_derived]
+ unsafe impl #impl_gen ::pin_init::Zeroable for #name #ty_gen
+ #whr
+ {}
+ const _: () = {
+ fn assert_zeroable<T: ?::core::marker::Sized + ::pin_init::Zeroable>() {}
+ fn ensure_zeroable #impl_gen ()
+ #whr
+ {
+ #(
+ assert_zeroable::<#field_type>();
+ )*
+ }
+ };
+ })
}
-pub(crate) fn derive(input: TokenStream) -> TokenStream {
- let (rest, new_impl_generics, ty_generics, last) = parse_zeroable_derive_input(input);
- quote! {
- ::pin_init::__derive_zeroable!(
- parse_input:
- @sig(#(#rest)*),
- @impl_generics(#(#new_impl_generics)*),
- @ty_generics(#(#ty_generics)*),
- @body(#last),
- );
+pub(crate) fn maybe_derive(
+ input: DeriveInput,
+ dcx: &mut DiagCtxt,
+) -> Result<TokenStream, ErrorGuaranteed> {
+ let fields = match input.data {
+ Data::Struct(data_struct) => data_struct.fields,
+ Data::Union(data_union) => Fields::Named(data_union.fields),
+ Data::Enum(data_enum) => {
+ return Err(dcx.error(data_enum.enum_token, "cannot derive `Zeroable` for an enum"));
+ }
+ };
+ let name = input.ident;
+ let mut generics = input.generics;
+ for param in generics.type_params_mut() {
+ param.bounds.insert(0, parse_quote!(::pin_init::Zeroable));
}
-}
-
-pub(crate) fn maybe_derive(input: TokenStream) -> TokenStream {
- let (rest, new_impl_generics, ty_generics, last) = parse_zeroable_derive_input(input);
- quote! {
- ::pin_init::__maybe_derive_zeroable!(
- parse_input:
- @sig(#(#rest)*),
- @impl_generics(#(#new_impl_generics)*),
- @ty_generics(#(#ty_generics)*),
- @body(#last),
- );
+ for Field { ty, .. } in fields {
+ generics
+ .make_where_clause()
+ .predicates
+ // the `for<'__dummy>` HRTB makes this not error without the `trivial_bounds`
+ // feature <https://github.com/rust-lang/rust/issues/48214#issuecomment-2557829956>.
+ .push(parse_quote!(#ty: for<'__dummy> ::pin_init::Zeroable));
}
+ let (impl_gen, ty_gen, whr) = generics.split_for_impl();
+ Ok(quote! {
+ // SAFETY: Every field type implements `Zeroable` and padding bytes may be zero.
+ #[automatically_derived]
+ unsafe impl #impl_gen ::pin_init::Zeroable for #name #ty_gen
+ #whr
+ {}
+ })
}