diff options
Diffstat (limited to 'rust/kernel')
63 files changed, 3662 insertions, 880 deletions
diff --git a/rust/kernel/Kconfig.test b/rust/kernel/Kconfig.test new file mode 100644 index 000000000000..e6a5c7a795f0 --- /dev/null +++ b/rust/kernel/Kconfig.test @@ -0,0 +1,86 @@ +# SPDX-License-Identifier: GPL-2.0-only +menuconfig RUST_KUNIT_TESTS + bool "Rust KUnit tests" + depends on KUNIT && RUST + default KUNIT_ALL_TESTS + help + This menu collects all options for Rust KUnit tests. + See Documentation/rust/testing.rst for how to protect + unit tests with these options. + + Say Y here to enable Rust KUnit tests. + + If unsure, say N. + +if RUST_KUNIT_TESTS +config RUST_ALLOCATOR_KUNIT_TEST + bool "KUnit tests for Rust allocator API" if !KUNIT_ALL_TESTS + default KUNIT_ALL_TESTS + help + This option enables KUnit tests for the Rust allocator API. + These are only for development and testing, not for regular + kernel use cases. + + If unsure, say N. + +config RUST_KVEC_KUNIT_TEST + bool "KUnit tests for Rust KVec API" if !KUNIT_ALL_TESTS + default KUNIT_ALL_TESTS + help + This option enables KUnit tests for the Rust KVec API. + These are only for development and testing, not for + regular kernel use cases. + + If unsure, say N. + +config RUST_BITMAP_KUNIT_TEST + bool "KUnit tests for Rust bitmap API" if !KUNIT_ALL_TESTS + default KUNIT_ALL_TESTS + help + This option enables KUnit tests for the Rust bitmap API. + These are only for development and testing, not for regular + kernel use cases. + + If unsure, say N. + +config RUST_KUNIT_SELFTEST + bool "KUnit selftests for Rust" if !KUNIT_ALL_TESTS + default KUNIT_ALL_TESTS + help + This option enables KUnit selftests. These are only + for development and testing, not for regular kernel + use cases. + + If unsure, say N. + +config RUST_STR_KUNIT_TEST + bool "KUnit tests for Rust strings API" if !KUNIT_ALL_TESTS + default KUNIT_ALL_TESTS + help + This option enables KUnit tests for the Rust strings API. + These are only for development and testing, not for regular + kernel use cases. + + If unsure, say N. + +config RUST_ATOMICS_KUNIT_TEST + bool "KUnit tests for Rust atomics API" if !KUNIT_ALL_TESTS + default KUNIT_ALL_TESTS + help + This option enables KUnit tests for the Rust atomics API. + These are only for development and testing, not for regular + kernel use cases. + + If unsure, say N. + +config RUST_BITFIELD_KUNIT_TEST + bool "KUnit tests for the Rust `bitfield!` macro" if !KUNIT_ALL_TESTS + default KUNIT_ALL_TESTS + help + This option enables KUnit tests for the Rust `bitfield!` macro. + These are only for development and testing, not for regular + kernel use cases. + + If unsure, say N. + +endif diff --git a/rust/kernel/alloc.rs b/rust/kernel/alloc.rs index e38720349dcf..21067bde6860 100644 --- a/rust/kernel/alloc.rs +++ b/rust/kernel/alloc.rs @@ -22,8 +22,12 @@ pub use self::kvec::Vec; #[derive(Copy, Clone, PartialEq, Eq, Debug)] pub struct AllocError; -use crate::error::{code::EINVAL, Result}; -use core::{alloc::Layout, ptr::NonNull}; +use crate::prelude::*; + +use core::{ + alloc::Layout, + ptr::NonNull, // +}; /// Flags to be used when allocating memory. /// diff --git a/rust/kernel/alloc/allocator.rs b/rust/kernel/alloc/allocator.rs index 63bfb91b3671..cd4203f27aed 100644 --- a/rust/kernel/alloc/allocator.rs +++ b/rust/kernel/alloc/allocator.rs @@ -8,14 +8,25 @@ //! //! Reference: <https://docs.kernel.org/core-api/memory-allocation.html> -use super::Flags; -use core::alloc::Layout; -use core::ptr; -use core::ptr::NonNull; - -use crate::alloc::{AllocError, Allocator, NumaNode}; -use crate::bindings; -use crate::page; +use super::{ + AllocError, + Allocator, + Flags, + NumaNode, // +}; + +use crate::{ + bindings, + page, // +}; + +use core::{ + alloc::Layout, + ptr::{ + self, + NonNull, // + }, // +}; const ARCH_KMALLOC_MINALIGN: usize = bindings::ARCH_KMALLOC_MINALIGN; @@ -163,8 +174,11 @@ impl Vmalloc { /// # Examples /// /// ``` - /// # use core::ptr::{NonNull, from_mut}; - /// # use kernel::{page, prelude::*}; + /// # use core::ptr::{ + /// # from_mut, + /// # NonNull, // + /// # }; + /// # use kernel::page; /// use kernel::alloc::allocator::Vmalloc; /// /// let mut vbox = VBox::<[u8; page::PAGE_SIZE]>::new_uninit(GFP_KERNEL)?; @@ -251,6 +265,7 @@ unsafe impl Allocator for KVmalloc { } } +#[cfg(CONFIG_RUST_ALLOCATOR_KUNIT_TEST)] #[macros::kunit_tests(rust_allocator)] mod tests { use super::*; diff --git a/rust/kernel/alloc/allocator/iter.rs b/rust/kernel/alloc/allocator/iter.rs index e0a70b7a744a..02fda3ea5cae 100644 --- a/rust/kernel/alloc/allocator/iter.rs +++ b/rust/kernel/alloc/allocator/iter.rs @@ -1,9 +1,13 @@ // SPDX-License-Identifier: GPL-2.0 use super::Vmalloc; + use crate::page; -use core::marker::PhantomData; -use core::ptr::NonNull; + +use core::{ + marker::PhantomData, + ptr::NonNull, // +}; /// An [`Iterator`] of [`page::BorrowedPage`] items owned by a [`Vmalloc`] allocation. /// diff --git a/rust/kernel/alloc/kbox.rs b/rust/kernel/alloc/kbox.rs index bd6da02c7ab8..35d1e015848d 100644 --- a/rust/kernel/alloc/kbox.rs +++ b/rust/kernel/alloc/kbox.rs @@ -3,24 +3,47 @@ //! Implementation of [`Box`]. #[allow(unused_imports)] // Used in doc comments. -use super::allocator::{KVmalloc, Kmalloc, Vmalloc, VmallocPageIter}; -use super::{AllocError, Allocator, Flags, NumaNode}; -use core::alloc::Layout; -use core::borrow::{Borrow, BorrowMut}; -use core::marker::PhantomData; -use core::mem::ManuallyDrop; -use core::mem::MaybeUninit; -use core::ops::{Deref, DerefMut}; -use core::pin::Pin; -use core::ptr::NonNull; -use core::result::Result; - -use crate::ffi::c_void; -use crate::fmt; -use crate::init::InPlaceInit; -use crate::page::AsPageIter; -use crate::types::ForeignOwnable; -use pin_init::{InPlaceWrite, Init, PinInit, ZeroableOption}; +use super::allocator::{ + KVmalloc, + Kmalloc, + Vmalloc, + VmallocPageIter, // +}; + +use super::{ + AllocError, + Allocator, + Flags, + NumaNode, // +}; + +use crate::{ + fmt, + page::AsPageIter, + prelude::*, + types::ForeignOwnable, // +}; + +use core::{ + alloc::Layout, + borrow::{ + Borrow, + BorrowMut, // + }, + marker::PhantomData, + mem::{ + ManuallyDrop, + MaybeUninit, // + }, + ops::{ + Deref, + DerefMut, // + }, + ptr::NonNull, + result::Result, // +}; + +use pin_init::ZeroableOption; /// The kernel's [`Box`] type -- a heap allocation for a single value of type `T`. /// @@ -256,6 +279,27 @@ where Ok(Box(ptr.cast(), PhantomData)) } + /// Creates a new zero-initialized `Box<T, A>`. + /// + /// New memory is allocated with `A` and the [`__GFP_ZERO`] flag. The allocation may fail, in + /// which case an error is returned. For ZSTs no memory is allocated. + /// + /// # Examples + /// + /// ``` + /// let b = KBox::<[u8; 128]>::zeroed(GFP_KERNEL)?; + /// assert_eq!(*b, [0; 128]); + /// # Ok::<(), Error>(()) + /// ``` + pub fn zeroed(flags: Flags) -> Result<Self, AllocError> + where + T: Zeroable, + { + // SAFETY: `__GFP_ZERO` guarantees the memory is zeroed; `T: Zeroable` guarantees that + // all-zeroes is a valid bit pattern for `T`. + Ok(unsafe { Self::new_uninit(flags | __GFP_ZERO)?.assume_init() }) + } + /// Constructs a new `Pin<Box<T, A>>`. If `T` does not implement [`Unpin`], then `x` will be /// pinned in memory and can't be moved. #[inline] @@ -274,7 +318,10 @@ where /// # Examples /// /// ``` - /// use kernel::sync::{new_spinlock, SpinLock}; + /// use kernel::sync::{ + /// new_spinlock, + /// SpinLock, // + /// }; /// /// struct Inner { /// a: u32, @@ -411,6 +458,7 @@ where { type Initialized = Box<T, A>; + #[inline] fn write_init<E>(mut self, init: impl Init<T, E>) -> Result<Self::Initialized, E> { let slot = self.as_mut_ptr(); // SAFETY: When init errors/panics, slot will get deallocated but not dropped, @@ -420,6 +468,7 @@ where Ok(unsafe { Box::assume_init(self) }) } + #[inline] fn write_pin_init<E>(mut self, init: impl PinInit<T, E>) -> Result<Pin<Self::Initialized>, E> { let slot = self.as_mut_ptr(); // SAFETY: When init errors/panics, slot will get deallocated but not dropped, @@ -455,7 +504,7 @@ where // SAFETY: The pointer returned by `into_foreign` comes from a well aligned // pointer to `T` allocated by `A`. -unsafe impl<T: 'static, A> ForeignOwnable for Box<T, A> +unsafe impl<T, A> ForeignOwnable for Box<T, A> where A: Allocator, { @@ -465,8 +514,14 @@ where core::mem::align_of::<T>() }; - type Borrowed<'a> = &'a T; - type BorrowedMut<'a> = &'a mut T; + type Borrowed<'a> + = &'a T + where + Self: 'a; + type BorrowedMut<'a> + = &'a mut T + where + Self: 'a; fn into_foreign(self) -> *mut c_void { Box::into_raw(self).cast() @@ -494,13 +549,19 @@ where // SAFETY: The pointer returned by `into_foreign` comes from a well aligned // pointer to `T` allocated by `A`. -unsafe impl<T: 'static, A> ForeignOwnable for Pin<Box<T, A>> +unsafe impl<T, A> ForeignOwnable for Pin<Box<T, A>> where A: Allocator, { const FOREIGN_ALIGN: usize = <Box<T, A> as ForeignOwnable>::FOREIGN_ALIGN; - type Borrowed<'a> = Pin<&'a T>; - type BorrowedMut<'a> = Pin<&'a mut T>; + type Borrowed<'a> + = Pin<&'a T> + where + Self: 'a; + type BorrowedMut<'a> + = Pin<&'a mut T> + where + Self: 'a; fn into_foreign(self) -> *mut c_void { // SAFETY: We are still treating the box as pinned. @@ -567,7 +628,6 @@ where /// /// ``` /// # use core::borrow::Borrow; -/// # use kernel::alloc::KBox; /// struct Foo<B: Borrow<u32>>(B); /// /// // Owned instance. @@ -595,7 +655,6 @@ where /// /// ``` /// # use core::borrow::BorrowMut; -/// # use kernel::alloc::KBox; /// struct Foo<B: BorrowMut<u32>>(B); /// /// // Owned instance. @@ -660,9 +719,13 @@ where /// # Examples /// /// ``` -/// # use kernel::prelude::*; -/// use kernel::alloc::allocator::VmallocPageIter; -/// use kernel::page::{AsPageIter, PAGE_SIZE}; +/// use kernel::{ +/// alloc::allocator::VmallocPageIter, +/// page::{ +/// AsPageIter, +/// PAGE_SIZE, // +/// }, // +/// }; /// /// let mut vbox = VBox::new((), GFP_KERNEL)?; /// diff --git a/rust/kernel/alloc/kvec.rs b/rust/kernel/alloc/kvec.rs index 6438385e4322..f7af62835aa8 100644 --- a/rust/kernel/alloc/kvec.rs +++ b/rust/kernel/alloc/kvec.rs @@ -3,29 +3,52 @@ //! Implementation of [`Vec`]. use super::{ - allocator::{KVmalloc, Kmalloc, Vmalloc, VmallocPageIter}, + allocator::{ + KVmalloc, + Kmalloc, + Vmalloc, + VmallocPageIter, // + }, layout::ArrayLayout, - AllocError, Allocator, Box, Flags, NumaNode, + AllocError, + Allocator, + Box, + Flags, + NumaNode, // }; + use crate::{ fmt, page::{ AsPageIter, PAGE_SIZE, // - }, + }, // }; + use core::{ - borrow::{Borrow, BorrowMut}, + borrow::{ + Borrow, + BorrowMut, // + }, marker::PhantomData, - mem::{ManuallyDrop, MaybeUninit}, - ops::Deref, - ops::DerefMut, - ops::Index, - ops::IndexMut, - ptr, - ptr::NonNull, - slice, - slice::SliceIndex, + mem::{ + ManuallyDrop, + MaybeUninit, // + }, + ops::{ + Deref, + DerefMut, + Index, + IndexMut, // + }, + ptr::{ + self, + NonNull, // + }, + slice::{ + self, + SliceIndex, // + }, // }; mod errors; @@ -614,7 +637,7 @@ where /// /// v.reserve(10, GFP_KERNEL)?; /// let cap = v.capacity(); - /// assert!(cap >= 10); + /// assert!(cap >= v.len() + 10); /// /// v.reserve(10, GFP_KERNEL)?; /// let new_cap = v.capacity(); @@ -849,6 +872,24 @@ impl<T> Vec<T, KVmalloc> { impl<T: Clone, A: Allocator> Vec<T, A> { /// Extend the vector by `n` clones of `value`. + /// + /// # Examples + /// + /// ``` + /// let mut v = KVec::new(); + /// v.push(1, GFP_KERNEL)?; + /// + /// v.extend_with(3, 5, GFP_KERNEL)?; + /// assert_eq!(&v, &[1, 5, 5, 5]); + /// + /// v.extend_with(2, 8, GFP_KERNEL)?; + /// assert_eq!(&v, &[1, 5, 5, 5, 8, 8]); + /// + /// v.extend_with(0, 3, GFP_KERNEL)?; + /// assert_eq!(&v, &[1, 5, 5, 5, 8, 8]); + /// + /// # Ok::<(), Error>(()) + /// ``` pub fn extend_with(&mut self, n: usize, value: T, flags: Flags) -> Result<(), AllocError> { if n == 0 { return Ok(()); @@ -866,7 +907,7 @@ impl<T: Clone, A: Allocator> Vec<T, A> { spare[n - 1].write(value); // SAFETY: - // - `self.len() + n < self.capacity()` due to the call to reserve above, + // - `self.len() + n <= self.capacity()` due to the call to reserve above, // - the loop and the line above initialized the next `n` elements. unsafe { self.inc_len(n) }; @@ -1146,9 +1187,13 @@ where /// # Examples /// /// ``` -/// # use kernel::prelude::*; -/// use kernel::alloc::allocator::VmallocPageIter; -/// use kernel::page::{AsPageIter, PAGE_SIZE}; +/// use kernel::{ +/// alloc::allocator::VmallocPageIter, +/// page::{ +/// AsPageIter, +/// PAGE_SIZE, // +/// }, // +/// }; /// /// let mut vec = VVec::<u8>::new(); /// @@ -1463,6 +1508,7 @@ impl<'vec, T> Drop for DrainAll<'vec, T> { } } +#[cfg(CONFIG_RUST_KVEC_KUNIT_TEST)] #[macros::kunit_tests(rust_kvec)] mod tests { use super::*; diff --git a/rust/kernel/alloc/kvec/errors.rs b/rust/kernel/alloc/kvec/errors.rs index 985c5f2c3962..aaca6446516a 100644 --- a/rust/kernel/alloc/kvec/errors.rs +++ b/rust/kernel/alloc/kvec/errors.rs @@ -2,8 +2,10 @@ //! Errors for the [`Vec`] type. -use kernel::fmt; -use kernel::prelude::*; +use crate::{ + fmt, + prelude::*, // +}; /// Error type for [`Vec::push_within_capacity`]. pub struct PushError<T>(pub T); diff --git a/rust/kernel/alloc/layout.rs b/rust/kernel/alloc/layout.rs index 9f8be72feb7a..62a459c66baf 100644 --- a/rust/kernel/alloc/layout.rs +++ b/rust/kernel/alloc/layout.rs @@ -4,7 +4,10 @@ //! //! Custom layout types extending or improving [`Layout`]. -use core::{alloc::Layout, marker::PhantomData}; +use core::{ + alloc::Layout, + marker::PhantomData, // +}; /// Error when constructing an [`ArrayLayout`]. pub struct LayoutError; @@ -47,7 +50,10 @@ impl<T> ArrayLayout<T> { /// # Examples /// /// ``` - /// # use kernel::alloc::layout::{ArrayLayout, LayoutError}; + /// # use kernel::alloc::layout::{ + /// # ArrayLayout, + /// # LayoutError, // + /// # }; /// let layout = ArrayLayout::<i32>::new(15)?; /// assert_eq!(layout.len(), 15); /// diff --git a/rust/kernel/auxiliary.rs b/rust/kernel/auxiliary.rs index 93c0db1f6655..c42928d5a239 100644 --- a/rust/kernel/auxiliary.rs +++ b/rust/kernel/auxiliary.rs @@ -12,19 +12,25 @@ use crate::{ RawDeviceId, RawDeviceIdIndex, // }, - devres::Devres, + driver, error::{ from_result, to_result, // }, prelude::*, - types::Opaque, + types::{ + ForLt, + ForeignOwnable, + Opaque, // + }, ThisModule, // }; use core::{ + any::TypeId, marker::PhantomData, mem::offset_of, + pin::Pin, ptr::{ addr_of_mut, NonNull, // @@ -36,18 +42,18 @@ pub struct Adapter<T: Driver>(T); // SAFETY: // - `bindings::auxiliary_driver` is a C type declared as `repr(C)`. -// - `T` is the type of the driver's device private data. +// - `T::Data` is the type of the driver's device private data. // - `struct auxiliary_driver` embeds a `struct device_driver`. // - `DEVICE_DRIVER_OFFSET` is the correct byte offset to the embedded `struct device_driver`. -unsafe impl<T: Driver + 'static> driver::DriverLayout for Adapter<T> { +unsafe impl<T: Driver> driver::DriverLayout for Adapter<T> { type DriverType = bindings::auxiliary_driver; - type DriverData = T; + type DriverData<'bound> = T::Data<'bound>; const DEVICE_DRIVER_OFFSET: usize = core::mem::offset_of!(Self::DriverType, driver); } // SAFETY: A call to `unregister` for a given instance of `DriverType` is guaranteed to be valid if // a preceding call to `register` has been successful. -unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { +unsafe impl<T: Driver> driver::RegistrationOps for Adapter<T> { unsafe fn register( adrv: &Opaque<Self::DriverType>, name: &'static CStr, @@ -73,7 +79,7 @@ unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { } } -impl<T: Driver + 'static> Adapter<T> { +impl<T: Driver> Adapter<T> { extern "C" fn probe_callback( adev: *mut bindings::auxiliary_device, id: *const bindings::auxiliary_device_id, @@ -82,7 +88,7 @@ impl<T: Driver + 'static> Adapter<T> { // `struct auxiliary_device`. // // INVARIANT: `adev` is valid for the duration of `probe_callback()`. - let adev = unsafe { &*adev.cast::<Device<device::CoreInternal>>() }; + let adev = unsafe { &*adev.cast::<Device<device::CoreInternal<'_>>>() }; // SAFETY: `DeviceId` is a `#[repr(transparent)`] wrapper of `struct auxiliary_device_id` // and does not add additional invariants, so it's safe to transmute. @@ -102,12 +108,12 @@ impl<T: Driver + 'static> Adapter<T> { // `struct auxiliary_device`. // // INVARIANT: `adev` is valid for the duration of `remove_callback()`. - let adev = unsafe { &*adev.cast::<Device<device::CoreInternal>>() }; + let adev = unsafe { &*adev.cast::<Device<device::CoreInternal<'_>>>() }; // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called - // and stored a `Pin<KBox<T>>`. - let data = unsafe { adev.as_ref().drvdata_borrow::<T>() }; + // and stored a `Pin<KBox<T::Data<'_>>>`. + let data = unsafe { adev.as_ref().drvdata_borrow::<T::Data<'_>>() }; T::unbind(adev, data); } @@ -197,13 +203,19 @@ pub trait Driver { /// type IdInfo: 'static = (); type IdInfo: 'static; + /// The type of the driver's bus device private data. + type Data<'bound>: Send + 'bound; + /// The table of device ids supported by the driver. const ID_TABLE: IdTable<Self::IdInfo>; /// Auxiliary driver probe. /// /// Called when an auxiliary device is matches a corresponding driver. - fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> impl PinInit<Self, Error>; + fn probe<'bound>( + dev: &'bound Device<device::Core<'_>>, + id_info: &'bound Self::IdInfo, + ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound; /// Auxiliary driver unbind. /// @@ -214,8 +226,8 @@ pub trait Driver { /// `&Device<Core>` or `&Device<Bound>` reference. For instance, drivers may try to perform I/O /// operations to gracefully tear down the device. /// - /// Otherwise, release operations for driver resources should be performed in `Self::drop`. - fn unbind(dev: &Device<device::Core>, this: Pin<&Self>) { + /// Otherwise, release operations for driver resources should be performed in `Drop`. + fn unbind<'bound>(dev: &'bound Device<device::Core<'_>>, this: Pin<&Self::Data<'bound>>) { let _ = (dev, this); } } @@ -257,6 +269,49 @@ impl Device<device::Bound> { // SAFETY: A bound auxiliary device always has a bound parent device. unsafe { parent.as_bound() } } + + /// Returns a pinned reference to the registration data set by the registering (parent) driver. + /// + /// `F` is the [`ForLt`](trait@ForLt) encoding of the data type. The returned + /// reference has its lifetime shortened from `'static` to `&self`'s borrow lifetime via + /// [`ForLt::cast_ref`]. + /// + /// Returns [`EINVAL`] if `F` does not match the type used by the parent driver when calling + /// [`Registration::new()`]. + /// + /// Returns [`ENOENT`] if no registration data has been set, e.g. when the device was + /// registered by a C driver. + pub fn registration_data<F: ForLt + 'static>(&self) -> Result<Pin<&F::Of<'_>>> { + // SAFETY: By the type invariant, `self.as_raw()` is a valid `struct auxiliary_device`. + let ptr = unsafe { (*self.as_raw()).registration_data_rust }; + if ptr.is_null() { + dev_warn!( + self.as_ref(), + "No registration data set; parent is not a Rust driver.\n" + ); + return Err(ENOENT); + } + + // SAFETY: `ptr` is non-null and was set via `into_foreign()` in `Registration::new()`; + // `RegistrationData` is `#[repr(C)]` with `type_id` at offset 0, so reading a `TypeId` + // at the start of the allocation is valid regardless of `F`. + let type_id = unsafe { ptr.cast::<TypeId>().read() }; + if type_id != TypeId::of::<F>() { + return Err(EINVAL); + } + + // SAFETY: The `TypeId` check above confirms that the stored type matches + // `F::Of<'static>`; `ptr` remains valid until `Registration::drop()` calls + // `from_foreign()`. + let wrapper = unsafe { Pin::<KBox<RegistrationData<F::Of<'static>>>>::borrow(ptr) }; + + // SAFETY: `data` is a structurally pinned field of `RegistrationData`. + let pinned: Pin<&F::Of<'_>> = unsafe { wrapper.map_unchecked(|w| &w.data) }; + + // SAFETY: The data was pinned when stored; `cast_ref` only shortens + // the lifetime, so the pinning guarantee is preserved. + Ok(unsafe { Pin::new_unchecked(F::cast_ref(pinned.get_ref())) }) + } } impl Device { @@ -326,87 +381,173 @@ unsafe impl Send for Device {} // (i.e. `Device<Normal>) are thread safe. unsafe impl Sync for Device {} +// SAFETY: Same as `Device<Normal>` -- the underlying `struct auxiliary_device` is the same; +// `Bound` is a zero-sized type-state marker that does not affect thread safety. +unsafe impl Sync for Device<device::Bound> {} + +/// Wrapper that stores a [`TypeId`] alongside the registration data for runtime type checking. +#[repr(C)] +#[pin_data] +struct RegistrationData<T> { + type_id: TypeId, + #[pin] + data: T, +} + /// The registration of an auxiliary device. /// /// This type represents the registration of a [`struct auxiliary_device`]. When its parent device /// is unbound, the corresponding auxiliary device will be unregistered from the system. /// +/// The type parameter `F` is a [`ForLt`](trait@ForLt) encoding of the registration +/// data type. For non-lifetime-parameterized types, use [`ForLt!(T)`](macro@ForLt). +/// The data can be accessed by the auxiliary driver through [`Device::registration_data()`]. +/// /// # Invariants /// -/// `self.0` always holds a valid pointer to an initialized and registered -/// [`struct auxiliary_device`]. -pub struct Registration(NonNull<bindings::auxiliary_device>); +/// `self.adev` always holds a valid pointer to an initialized and registered +/// [`struct auxiliary_device`] whose `registration_data_rust` field points to a +/// valid `Pin<KBox<RegistrationData<F::Of<'static>>>>`. +pub struct Registration<'a, F: ForLt + 'static> { + adev: NonNull<bindings::auxiliary_device>, + _phantom: PhantomData<F::Of<'a>>, +} -impl Registration { - /// Create and register a new auxiliary device. - pub fn new<'a>( +impl<'a, F: ForLt> Registration<'a, F> +where + for<'b> F::Of<'b>: Send + Sync, +{ + /// Create and register a new auxiliary device with the given registration data. + /// + /// The `data` is owned by the registration and can be accessed through the auxiliary device + /// via [`Device::registration_data()`]. + /// + /// # Safety + /// + /// The caller must not `mem::forget()` the returned [`Registration`] or otherwise prevent its + /// [`Drop`] implementation from running, since the registration data may contain borrowed + /// references that become invalid after `'a` ends. + /// + /// If the registration data is `'static`, use the safe [`Registration::new()`] instead. + pub unsafe fn new_with_lt<E>( parent: &'a device::Device<device::Bound>, - name: &'a CStr, + name: &CStr, id: u32, - modname: &'a CStr, - ) -> impl PinInit<Devres<Self>, Error> + 'a { - pin_init::pin_init_scope(move || { - let boxed = KBox::new(Opaque::<bindings::auxiliary_device>::zeroed(), GFP_KERNEL)?; - let adev = boxed.get(); - - // SAFETY: It's safe to set the fields of `struct auxiliary_device` on initialization. - unsafe { - (*adev).dev.parent = parent.as_raw(); - (*adev).dev.release = Some(Device::release); - (*adev).name = name.as_char_ptr(); - (*adev).id = id; - } - - // SAFETY: `adev` is guaranteed to be a valid pointer to a `struct auxiliary_device`, - // which has not been initialized yet. - unsafe { bindings::auxiliary_device_init(adev) }; - - // Now that `adev` is initialized, leak the `Box`; the corresponding memory will be - // freed by `Device::release` when the last reference to the `struct auxiliary_device` - // is dropped. - let _ = KBox::into_raw(boxed); - - // SAFETY: - // - `adev` is guaranteed to be a valid pointer to a `struct auxiliary_device`, which - // has been initialized, - // - `modname.as_char_ptr()` is a NULL terminated string. - let ret = unsafe { bindings::__auxiliary_device_add(adev, modname.as_char_ptr()) }; - if ret != 0 { - // SAFETY: `adev` is guaranteed to be a valid pointer to a - // `struct auxiliary_device`, which has been initialized. - unsafe { bindings::auxiliary_device_uninit(adev) }; - - return Err(Error::from_errno(ret)); - } - - // INVARIANT: The device will remain registered until `auxiliary_device_delete()` is - // called, which happens in `Self::drop()`. - Ok(Devres::new( - parent, - // SAFETY: `adev` is guaranteed to be non-null, since the `KBox` was allocated - // successfully. - Self(unsafe { NonNull::new_unchecked(adev) }), - )) + modname: &CStr, + data: impl PinInit<F::Of<'a>, E>, + ) -> Result<Self> + where + Error: From<E>, + { + let data = KBox::pin_init::<Error>( + try_pin_init!(RegistrationData { + type_id: TypeId::of::<F>(), + data <- data, + }), + GFP_KERNEL, + )?; + + // SAFETY: `'a` is invariant (via `Registration`'s `PhantomData`). Lifetimes do not + // affect layout, so RegistrationData<F::Of<'a>> and RegistrationData<F::Of<'static>> + // have identical representation. + let data: Pin<KBox<RegistrationData<F::Of<'static>>>> = + unsafe { core::mem::transmute(data) }; + + let boxed: KBox<Opaque<bindings::auxiliary_device>> = KBox::zeroed(GFP_KERNEL)?; + let adev = boxed.get(); + + // SAFETY: It's safe to set the fields of `struct auxiliary_device` on initialization. + unsafe { + (*adev).dev.parent = parent.as_raw(); + (*adev).dev.release = Some(Device::release); + (*adev).name = name.as_char_ptr(); + (*adev).id = id; + (*adev).registration_data_rust = data.into_foreign(); + } + + // SAFETY: `adev` is guaranteed to be a valid pointer to a `struct auxiliary_device`, + // which has not been initialized yet. + unsafe { bindings::auxiliary_device_init(adev) }; + + // Now that `adev` is initialized, leak the `Box`; the corresponding memory will be + // freed by `Device::release` when the last reference to the `struct auxiliary_device` + // is dropped. + let _ = KBox::into_raw(boxed); + + // SAFETY: + // - `adev` is guaranteed to be a valid pointer to a `struct auxiliary_device`, which + // has been initialized, + // - `modname.as_char_ptr()` is a NULL terminated string. + let ret = unsafe { bindings::__auxiliary_device_add(adev, modname.as_char_ptr()) }; + if ret != 0 { + // SAFETY: `registration_data` was set above via `into_foreign()`. + drop(unsafe { + Pin::<KBox<RegistrationData<F::Of<'static>>>>::from_foreign( + (*adev).registration_data_rust, + ) + }); + + // SAFETY: `adev` is guaranteed to be a valid pointer to a + // `struct auxiliary_device`, which has been initialized. + unsafe { bindings::auxiliary_device_uninit(adev) }; + + return Err(Error::from_errno(ret)); + } + + // INVARIANT: The device will remain registered until `auxiliary_device_delete()` is + // called, which happens in `Self::drop()`. + Ok(Self { + // SAFETY: `adev` is guaranteed to be non-null, since the `KBox` was allocated + // successfully. + adev: unsafe { NonNull::new_unchecked(adev) }, + _phantom: PhantomData, }) } + + /// Create and register a new auxiliary device with `'static` registration data. + /// + /// Safe variant of [`Registration::new_with_lt()`] for registration data that does not contain + /// borrowed references. + pub fn new<E>( + parent: &'a device::Device<device::Bound>, + name: &CStr, + id: u32, + modname: &CStr, + data: impl PinInit<F::Of<'a>, E>, + ) -> Result<Self> + where + F::Of<'a>: 'static, + Error: From<E>, + { + // SAFETY: `F::Of<'a>: 'static` guarantees the data contains no borrowed references, + // so forgetting the `Registration` cannot cause use-after-free. + unsafe { Self::new_with_lt(parent, name, id, modname, data) } + } } -impl Drop for Registration { +impl<F: ForLt> Drop for Registration<'_, F> { fn drop(&mut self) { - // SAFETY: By the type invariant of `Self`, `self.0.as_ptr()` is a valid registered + // SAFETY: By the type invariant of `Self`, `self.adev.as_ptr()` is a valid registered // `struct auxiliary_device`. - unsafe { bindings::auxiliary_device_delete(self.0.as_ptr()) }; + unsafe { bindings::auxiliary_device_delete(self.adev.as_ptr()) }; + + // SAFETY: `registration_data` was set in `new()` via `into_foreign()`. + drop(unsafe { + Pin::<KBox<RegistrationData<F::Of<'static>>>>::from_foreign( + (*self.adev.as_ptr()).registration_data_rust, + ) + }); // This drops the reference we acquired through `auxiliary_device_init()`. // - // SAFETY: By the type invariant of `Self`, `self.0.as_ptr()` is a valid registered + // SAFETY: By the type invariant of `Self`, `self.adev.as_ptr()` is a valid registered // `struct auxiliary_device`. - unsafe { bindings::auxiliary_device_uninit(self.0.as_ptr()) }; + unsafe { bindings::auxiliary_device_uninit(self.adev.as_ptr()) }; } } // SAFETY: A `Registration` of a `struct auxiliary_device` can be released from any thread. -unsafe impl Send for Registration {} +unsafe impl<F: ForLt> Send for Registration<'_, F> where for<'a> F::Of<'a>: Send {} // SAFETY: `Registration` does not expose any methods or fields that need synchronization. -unsafe impl Sync for Registration {} +unsafe impl<F: ForLt> Sync for Registration<'_, F> where for<'a> F::Of<'a>: Send {} diff --git a/rust/kernel/bitfield.rs b/rust/kernel/bitfield.rs new file mode 100644 index 000000000000..35ede53f2b8e --- /dev/null +++ b/rust/kernel/bitfield.rs @@ -0,0 +1,863 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Support for defining bitfields as Rust structures. +//! +//! The [`bitfield!`](kernel::bitfield!) macro declares integer types that are split into distinct +//! bit fields of arbitrary length. Each field is typed using [`Bounded`](kernel::num::Bounded) to +//! ensure values are properly validated and to avoid implicit data loss. +//! +//! # Example +//! +//! ```rust +//! use kernel::bitfield; +//! use kernel::num::Bounded; +//! +//! bitfield! { +//! pub struct Rgb(u16) { +//! 15:11 blue; +//! 10:5 green; +//! 4:0 red; +//! } +//! } +//! +//! // Valid value for the `blue` field. +//! let blue = Bounded::<u16, 5>::new::<0x18>(); +//! +//! // Setters can be chained. Values ranges are checked at compile-time. +//! let color = Rgb::zeroed() +//! // Compile-time bounds check of constant value. +//! .with_const_red::<0x10>() +//! .with_const_green::<0x1f>() +//! // A `Bounded` can also be passed. +//! .with_blue(blue); +//! +//! assert_eq!(color.red(), 0x10); +//! assert_eq!(color.green(), 0x1f); +//! assert_eq!(color.blue(), 0x18); +//! assert_eq!( +//! color.into_raw(), +//! (0x18 << Rgb::BLUE_SHIFT) + (0x1f << Rgb::GREEN_SHIFT) + 0x10, +//! ); +//! +//! // Convert to/from the backing storage type. +//! let raw: u16 = color.into(); +//! assert_eq!(Rgb::from(raw), color); +//! ``` +//! +//! # Syntax +//! +//! ```text +//! bitfield! { +//! #[attributes] +//! // Documentation for `Name`. +//! pub struct Name(storage_type) { +//! // `field_1` documentation. +//! hi:lo field_1; +//! // `field_2` documentation. +//! hi:lo field_2 => ConvertedType; +//! // `field_3` documentation. +//! hi:lo field_3 ?=> ConvertedType; +//! ... +//! } +//! } +//! ``` +//! +//! - `storage_type`: The underlying unsigned integer type ([`u8`], [`u16`], [`u32`], [`u64`]). +//! Signed integer storage types are not supported. +//! - `hi:lo`: Bit range (inclusive), where `hi >= lo`. +//! - `=> Type`: Optional infallible conversion (see [below](#infallible-conversion-)). +//! - `?=> Type`: Optional fallible conversion (see [below](#fallible-conversion-)). +//! - Documentation strings and attributes are optional. +//! +//! # Generated code +//! +//! Each field is internally represented as a [`Bounded`] parameterized by its bit width. Field +//! values can either be set/retrieved directly, or converted from/to another type. +//! +//! The use of [`Bounded`] for each field enforces bounds-checking (at build time or runtime) of +//! every value assigned to a field. This ensures that data is never accidentally truncated. +//! +//! The macro generates the bitfield type, [`From`] and [`Into`] implementations for its storage +//! type, as well as [`Debug`] and [`Zeroable`](pin_init::Zeroable) implementations. +//! +//! For each field, it also generates: +//! +//! - `field()`: Getter method for the field value. +//! - `with_field(value)`: Infallible setter; the argument type must fit within the field's width. +//! - `with_const_field::<VALUE>()`: `const` setter; the value is validated at compile time. +//! Usually shorter to use than `with_field` for constant values as it doesn't require +//! constructing a [`Bounded`]. +//! - `try_with_field(value)`: Fallible setter. Returns an error if the value is out of range. +//! - `FIELD_MASK`, `FIELD_SHIFT`, `FIELD_RANGE`: Constants for manual bit manipulation. +//! +//! # Reserved names for field identifiers +//! +//! Field identifiers are used to generate methods and associated constants on the bitfield type. +//! For a field named `field`, the macro may generate methods named `field`, `with_field`, +//! `with_const_field`, `try_with_field`, `__field` and `__with_field`, as well as constants named +//! `FIELD_MASK`, `FIELD_SHIFT` and `FIELD_RANGE`. +//! +//! Therefore, field identifiers must not use names that would collide with generated items for +//! any field in the same bitfield. The following prefixes are thus reserved for field identifiers: +//! +//! - `with_` +//! - `const_` +//! - `try_with_` +//! - `__` +//! +//! The field identifiers `from_raw`, `into_raw`, and `into` are also reserved. +//! +//! In addition, field identifiers should follow Rust `snake_case` conventions, since the associated +//! constants are generated by uppercasing the field name. +//! +//! # Implicit conversions +//! +//! Types that fit entirely within a field's bit width can be used directly with setters. For +//! example, [`bool`] works with single-bit fields, and [`u8`] works with 8-bit fields: +//! +//! ```rust +//! use kernel::bitfield; +//! +//! bitfield! { +//! pub struct Flags(u32) { +//! 15:8 byte_field; +//! 0:0 flag; +//! } +//! } +//! +//! let flags = Flags::zeroed() +//! .with_byte_field(0x42_u8) +//! .with_flag(true); +//! +//! assert_eq!(flags.into_raw(), (0x42 << Flags::BYTE_FIELD_SHIFT) | 1); +//! ``` +//! +//! # Runtime bounds checking +//! +//! When a value is not known at compile time, use `try_with_field()` to check bounds at runtime: +//! +//! ```rust +//! use kernel::bitfield; +//! +//! bitfield! { +//! pub struct Config(u8) { +//! 3:0 nibble; +//! } +//! } +//! +//! fn set_nibble(config: Config, value: u8) -> Result<Config, Error> { +//! // Returns `EOVERFLOW` if `value > 0xf`. +//! config.try_with_nibble(value) +//! } +//! # Ok::<(), Error>(()) +//! ``` +//! +//! # Type conversion +//! +//! Fields can be automatically converted to/from a custom type using `=>` (infallible) or `?=>` +//! (fallible). The custom type must implement the appropriate [`From`] or [`TryFrom`] traits with +//! [`Bounded`]. +//! +//! ## Infallible conversion (`=>`) +//! +//! Use this when all possible bit patterns of a field map to valid values: +//! +//! ```rust +//! use kernel::bitfield; +//! use kernel::num::Bounded; +//! +//! #[derive(Debug, Clone, Copy, PartialEq)] +//! enum Power { +//! Off, +//! On, +//! } +//! +//! impl From<Bounded<u32, 1>> for Power { +//! fn from(v: Bounded<u32, 1>) -> Self { +//! match *v { +//! 0 => Power::Off, +//! _ => Power::On, +//! } +//! } +//! } +//! +//! impl From<Power> for Bounded<u32, 1> { +//! fn from(p: Power) -> Self { +//! (p as u32 != 0).into() +//! } +//! } +//! +//! bitfield! { +//! pub struct Control(u32) { +//! 0:0 power => Power; +//! } +//! } +//! +//! let ctrl = Control::zeroed().with_power(Power::On); +//! assert_eq!(ctrl.power(), Power::On); +//! ``` +//! +//! ## Fallible conversion (`?=>`) +//! +//! Use this when some bit patterns of a field are invalid. The getter returns a [`Result`]: +//! +//! ```rust +//! use kernel::bitfield; +//! use kernel::num::Bounded; +//! +//! #[derive(Debug, Clone, Copy, PartialEq)] +//! enum Mode { +//! Low = 0, +//! High = 1, +//! Auto = 2, +//! // 3 is invalid +//! } +//! +//! impl TryFrom<Bounded<u32, 2>> for Mode { +//! type Error = u32; +//! +//! fn try_from(v: Bounded<u32, 2>) -> Result<Self, u32> { +//! match *v { +//! 0 => Ok(Mode::Low), +//! 1 => Ok(Mode::High), +//! 2 => Ok(Mode::Auto), +//! n => Err(n), +//! } +//! } +//! } +//! +//! impl From<Mode> for Bounded<u32, 2> { +//! fn from(m: Mode) -> Self { +//! match m { +//! Mode::Low => Bounded::<u32, _>::new::<0>(), +//! Mode::High => Bounded::<u32, _>::new::<1>(), +//! Mode::Auto => Bounded::<u32, _>::new::<2>(), +//! } +//! } +//! } +//! +//! bitfield! { +//! pub struct Config(u32) { +//! 1:0 mode ?=> Mode; +//! } +//! } +//! +//! let cfg = Config::zeroed().with_mode(Mode::Auto); +//! assert_eq!(cfg.mode(), Ok(Mode::Auto)); +//! +//! // Invalid bit pattern returns an error. +//! assert_eq!(Config::from(0b11).mode(), Err(3)); +//! ``` +//! +//! # Bits outside of declared fields +//! +//! Bits of the storage type that are not part of any declared field are preserved by the setter +//! methods, and can only be modified through `from_raw` or the [`From`] implementation from the +//! storage type. +//! +//! ```rust +//! use kernel::bitfield; +//! +//! bitfield! { +//! pub struct Sparse(u8) { +//! 7:6 high; +//! // Bits 5:1 are not covered by any field. +//! 0:0 low; +//! } +//! } +//! +//! // Set the gap bits via `from_raw`, then mutate the declared fields. +//! let val = Sparse::from_raw(0b0010_1010) +//! .with_const_high::<0b11>() +//! .with_low(true); +//! +//! // Bits 5:1 are unchanged. +//! assert_eq!(val.into_raw(), 0b1110_1011); +//! ``` +//! +//! # Signed field values +//! +//! Bitfield storage types are unsigned. Since field getter methods return a [`Bounded`] of the +//! storage type, fields are also unsigned by default. +//! +//! If a field needs to encode a signed value, use a custom conversion type with `=>` or `?=>` to +//! perform the sign interpretation explicitly. +//! +//! [`Bounded`]: kernel::num::Bounded + +/// Defines a bitfield struct with bounds-checked accessors for individual bit ranges. +/// +/// See the [`mod@kernel::bitfield`] module for full documentation and examples. +#[macro_export] +macro_rules! bitfield { + // Entry point defining the bitfield struct, its implementations and its field accessors. + ( + $(#[$attr:meta])* $vis:vis struct $name:ident($storage:ty) { $($fields:tt)* } + ) => { + $crate::bitfield!(@core + #[allow(non_camel_case_types)] + $(#[$attr])* $vis $name $storage + ); + $crate::bitfield!(@fields $vis $name $storage { $($fields)* }); + }; + + // All rules below are helpers. + + // Defines the wrapper `$name` type and its conversions from/to the storage type. + (@core $(#[$attr:meta])* $vis:vis $name:ident $storage:ty) => { + $(#[$attr])* + #[repr(transparent)] + #[derive(Clone, Copy, PartialEq, Eq)] + $vis struct $name { + inner: $storage, + } + + #[allow(dead_code)] + impl $name { + /// Creates a bitfield from a raw value. + #[inline(always)] + $vis const fn from_raw(value: $storage) -> Self { + Self{ inner: value } + } + + /// Turns this bitfield into its raw value. + /// + /// This is similar to the [`From`] implementation, but is shorter to invoke in + /// most cases. + #[inline(always)] + $vis const fn into_raw(self) -> $storage { + self.inner + } + } + + // SAFETY: `$storage` is `Zeroable` and `$name` is transparent. + unsafe impl ::pin_init::Zeroable for $name {} + + impl ::core::convert::From<$name> for $storage { + #[inline(always)] + fn from(val: $name) -> $storage { + val.into_raw() + } + } + + impl ::core::convert::From<$storage> for $name { + #[inline(always)] + fn from(val: $storage) -> $name { + Self::from_raw(val) + } + } + }; + + // Definitions requiring knowledge of individual fields: private and public field accessors, + // and `Debug` implementation. + (@fields $vis:vis $name:ident $storage:ty { + $($(#[doc = $doc:expr])* $hi:literal:$lo:literal $field:ident + $(?=> $try_into_type:ty)? + $(=> $into_type:ty)? + ; + )* + } + ) => { + #[allow(dead_code)] + impl $name { + $( + $crate::bitfield!(@private_field_accessors $vis $name $storage : $hi:$lo $field); + $crate::bitfield!( + @public_field_accessors $(#[doc = $doc])* $vis $name $storage : $hi:$lo $field + $(?=> $try_into_type)? + $(=> $into_type)? + ); + )* + } + + $crate::bitfield!(@debug $name { $($field;)* }); + }; + + // Private field accessors working with the exact `Bounded` type for the field. + ( + @private_field_accessors $vis:vis $name:ident $storage:ty : $hi:tt:$lo:tt $field:ident + ) => { + ::kernel::macros::paste!( + $vis const [<$field:upper _RANGE>]: ::core::ops::RangeInclusive<u8> = $lo..=$hi; + $vis const [<$field:upper _MASK>]: $storage = + ((((1 << $hi) - 1) << 1) + 1) - ((1 << $lo) - 1); + $vis const [<$field:upper _SHIFT>]: u32 = $lo; + ); + + ::kernel::macros::paste!( + #[inline(always)] + fn [<__ $field>](self) -> + ::kernel::num::Bounded<$storage, { $hi + 1 - $lo }> { + // Left shift to align the field's MSB with the storage MSB. + const ALIGN_TOP: u32 = $storage::BITS - ($hi + 1); + // Right shift to move the top-aligned field to bit 0 of the storage. + const ALIGN_BOTTOM: u32 = ALIGN_TOP + $lo; + + // Extract the field using two shifts. `Bounded::shr` produces the correctly-sized + // output type. + let val = ::kernel::num::Bounded::<$storage, { $storage::BITS }>::from( + self.inner << ALIGN_TOP + ); + val.shr::<ALIGN_BOTTOM, { $hi + 1 - $lo } >() + } + + #[inline(always)] + const fn [<__with_ $field>]( + mut self, + value: ::kernel::num::Bounded<$storage, { $hi + 1 - $lo }>, + ) -> Self + { + const MASK: $storage = <$name>::[<$field:upper _MASK>]; + const SHIFT: u32 = <$name>::[<$field:upper _SHIFT>]; + + let value = value.get() << SHIFT; + self.inner = (self.inner & !MASK) | value; + + self + } + ); + }; + + // Public accessors for fields infallibly (`=>`) converted to a type. + ( + @public_field_accessors $(#[doc = $doc:expr])* $vis:vis $name:ident $storage:ty : + $hi:literal:$lo:literal $field:ident => $into_type:ty + ) => { + ::kernel::macros::paste!( + + $(#[doc = $doc])* + #[doc = "Returns the value of this field."] + #[inline(always)] + $vis fn $field(self) -> $into_type + { + self.[<__ $field>]().into() + } + + $(#[doc = $doc])* + #[doc = "Sets this field to the given `value`."] + #[inline(always)] + $vis fn [<with_ $field>](self, value: $into_type) -> Self + { + self.[<__with_ $field>](value.into()) + } + + ); + }; + + // Public accessors for fields fallibly (`?=>`) converted to a type. + ( + @public_field_accessors $(#[doc = $doc:expr])* $vis:vis $name:ident $storage:ty : + $hi:tt:$lo:tt $field:ident ?=> $try_into_type:ty + ) => { + ::kernel::macros::paste!( + + $(#[doc = $doc])* + #[doc = "Returns the value of this field."] + #[inline(always)] + $vis fn $field(self) -> + ::core::result::Result< + $try_into_type, + <$try_into_type as ::core::convert::TryFrom< + ::kernel::num::Bounded<$storage, { $hi + 1 - $lo }> + >>::Error + > + { + self.[<__ $field>]().try_into() + } + + $(#[doc = $doc])* + #[doc = "Sets this field to the given `value`."] + #[inline(always)] + $vis fn [<with_ $field>](self, value: $try_into_type) -> Self + { + self.[<__with_ $field>](value.into()) + } + + ); + }; + + // Public accessors for fields not converted to a type. + ( + @public_field_accessors $(#[doc = $doc:expr])* $vis:vis $name:ident $storage:ty : + $hi:tt:$lo:tt $field:ident + ) => { + ::kernel::macros::paste!( + + $(#[doc = $doc])* + #[doc = "Returns the value of this field."] + #[inline(always)] + $vis fn $field(self) -> + ::kernel::num::Bounded<$storage, { $hi + 1 - $lo }> + { + self.[<__ $field>]() + } + + $(#[doc = $doc])* + #[doc = "Sets this field to the compile-time constant `VALUE`."] + #[inline(always)] + $vis const fn [<with_const_ $field>]<const VALUE: $storage>(self) -> Self { + self.[<__with_ $field>]( + ::kernel::num::Bounded::<$storage, { $hi + 1 - $lo }>::new::<VALUE>() + ) + } + + $(#[doc = $doc])* + #[doc = "Sets this field to the given `value`."] + #[inline(always)] + $vis fn [<with_ $field>]<T>( + self, + value: T, + ) -> Self + where T: ::core::convert::Into<::kernel::num::Bounded<$storage, { $hi + 1 - $lo }>>, + { + self.[<__with_ $field>](value.into()) + } + + $(#[doc = $doc])* + #[doc = "Tries to set this field to `value`, returning an error if it is out of range."] + #[inline(always)] + $vis fn [<try_with_ $field>]<T>( + self, + value: T, + ) -> ::kernel::error::Result<Self> + where T: ::kernel::num::TryIntoBounded<$storage, { $hi + 1 - $lo }>, + { + Ok( + self.[<__with_ $field>]( + value.try_into_bounded().ok_or(::kernel::error::code::EOVERFLOW)? + ) + ) + } + + ); + }; + + // `Debug` implementation. + (@debug $name:ident { $($field:ident;)* }) => { + impl ::kernel::fmt::Debug for $name { + #[inline] + fn fmt(&self, f: &mut ::kernel::fmt::Formatter<'_>) -> ::kernel::fmt::Result { + f.debug_struct(stringify!($name)) + .field("<raw>", &::kernel::prelude::fmt!("{:#x}", self.inner)) + $( + .field(stringify!($field), &self.$field()) + )* + .finish() + } + } + }; +} + +#[cfg(CONFIG_RUST_BITFIELD_KUNIT_TEST)] +#[::kernel::macros::kunit_tests(rust_kernel_bitfield)] +mod tests { + use core::convert::TryFrom; + + use pin_init::Zeroable; + + use kernel::num::Bounded; + + // Enum types for testing `=>` and `?=>` conversions. + + #[derive(Debug, Clone, Copy, PartialEq)] + enum MemoryType { + Unmapped = 0, + Normal = 1, + Device = 2, + Reserved = 3, + } + + impl TryFrom<Bounded<u64, 4>> for MemoryType { + type Error = u64; + fn try_from(value: Bounded<u64, 4>) -> Result<Self, Self::Error> { + match value.get() { + 0 => Ok(MemoryType::Unmapped), + 1 => Ok(MemoryType::Normal), + 2 => Ok(MemoryType::Device), + 3 => Ok(MemoryType::Reserved), + _ => Err(value.get()), + } + } + } + + impl From<MemoryType> for Bounded<u64, 4> { + fn from(mt: MemoryType) -> Bounded<u64, 4> { + Bounded::from_expr(mt as u64) + } + } + + #[derive(Debug, Clone, Copy, PartialEq)] + enum Priority { + Low = 0, + Medium = 1, + High = 2, + Critical = 3, + } + + impl From<Bounded<u16, 2>> for Priority { + fn from(value: Bounded<u16, 2>) -> Self { + match value & 0x3 { + 0 => Priority::Low, + 1 => Priority::Medium, + 2 => Priority::High, + _ => Priority::Critical, + } + } + } + + impl From<Priority> for Bounded<u16, 2> { + fn from(p: Priority) -> Bounded<u16, 2> { + Bounded::from_expr(p as u16) + } + } + + bitfield! { + struct TestU64(u64) { + 63:63 field_63; + 61:52 field_61_52; + 51:16 field_51_16; + 15:12 field_15_12 ?=> MemoryType; + 11:9 field_11_9; + 1:1 field_1; + 0:0 field_0; + } + } + + bitfield! { + struct TestU16(u16) { + 15:8 field_15_8; + 7:4 field_7_4; // Partial overlap with `field_5_4`. + 5:4 field_5_4 => Priority; + 3:1 field_3_1; + 0:0 field_0; + } + } + + bitfield! { + struct TestU8(u8) { + 7:0 field_7_0; // Full byte overlap. + 7:4 field_7_4; + 3:2 field_3_2; + 1:1 field_1; + 0:0 field_0; + } + } + + // Single and multi-bit fields basic access. + #[test] + fn test_basic_access() { + // `TestU64`. + let mut val = TestU64::zeroed(); + assert_eq!(val.into_raw(), 0x0); + + val = val.with_field_0(true); + assert!(val.field_0().into_bool()); + assert_eq!(val.into_raw(), 0x1); + + val = val.with_field_1(true); + assert!(val.field_1().into_bool()); + val = val.with_field_1(false); + assert!(!val.field_1().into_bool()); + assert_eq!(val.into_raw(), 0x1); + + val = val.with_const_field_11_9::<0x5>(); + assert_eq!(val.field_11_9(), 0x5); + assert_eq!(val.into_raw(), 0xA01); + + val = val.with_const_field_51_16::<0x123456>(); + assert_eq!(val.field_51_16(), 0x123456); + assert_eq!(val.into_raw(), 0x0012_3456_0A01); + + const MAX_FIELD_51_16: u64 = ::kernel::bits::genmask_u64(0..=35); + val = val.with_const_field_51_16::<{ MAX_FIELD_51_16 }>(); + assert_eq!(val.field_51_16(), MAX_FIELD_51_16); + + val = val.with_const_field_61_52::<0x3FF>(); + assert_eq!(val.field_61_52(), 0x3FF); + + val = val.with_field_63(true); + assert!(val.field_63().into_bool()); + + // `TestU16`. + let mut val = TestU16::zeroed(); + assert_eq!(val.into_raw(), 0x0); + + val = val.with_field_0(true); + assert!(val.field_0().into_bool()); + assert_eq!(val.into_raw(), 0x1); + + val = val.with_const_field_3_1::<0x5>(); + assert_eq!(val.field_3_1(), 0x5); + assert_eq!(val.into_raw(), 0xB); + + val = val.with_const_field_7_4::<0xA>(); + assert_eq!(val.field_7_4(), 0xA); + assert_eq!(val.into_raw(), 0xAB); + + val = val.with_const_field_15_8::<0x42>(); + assert_eq!(val.field_15_8(), 0x42); + assert_eq!(val.into_raw(), 0x42AB); + + // `TestU8`. + let mut val = TestU8::zeroed(); + assert_eq!(val.into_raw(), 0x0); + + val = val.with_field_0(true); + assert!(val.field_0().into_bool()); + assert_eq!(val.into_raw(), 0x1); + + val = val.with_field_1(true); + assert!(val.field_1().into_bool()); + assert_eq!(val.into_raw(), 0x3); + + val = val.with_const_field_3_2::<0x3>(); + assert_eq!(val.field_3_2(), 0x3); + assert_eq!(val.into_raw(), 0xF); + + val = val.with_const_field_7_4::<0xA>(); + assert_eq!(val.field_7_4(), 0xA); + assert_eq!(val.into_raw(), 0xAF); + } + + // `=>` infallible conversion. + #[test] + fn test_infallible_conversion() { + let mut val = TestU16::zeroed(); + + val = val.with_field_5_4(Priority::Low); + assert_eq!(val.field_5_4(), Priority::Low); + assert_eq!(val.into_raw() & 0x30, 0x00); + + val = val.with_field_5_4(Priority::Medium); + assert_eq!(val.field_5_4(), Priority::Medium); + assert_eq!(val.into_raw() & 0x30, 0x10); + + val = val.with_field_5_4(Priority::High); + assert_eq!(val.field_5_4(), Priority::High); + assert_eq!(val.into_raw() & 0x30, 0x20); + + val = val.with_field_5_4(Priority::Critical); + assert_eq!(val.field_5_4(), Priority::Critical); + assert_eq!(val.into_raw() & 0x30, 0x30); + } + + // `?=>` fallible conversion. + #[test] + fn test_fallible_conversion() { + let mut val = TestU64::zeroed(); + + val = val.with_field_15_12(MemoryType::Unmapped); + assert_eq!(val.field_15_12(), Ok(MemoryType::Unmapped)); + val = val.with_field_15_12(MemoryType::Normal); + assert_eq!(val.field_15_12(), Ok(MemoryType::Normal)); + val = val.with_field_15_12(MemoryType::Device); + assert_eq!(val.field_15_12(), Ok(MemoryType::Device)); + val = val.with_field_15_12(MemoryType::Reserved); + assert_eq!(val.field_15_12(), Ok(MemoryType::Reserved)); + + // `field_15_12` is 4 bits wide (0-15); `MemoryType` only covers 0-3, so 4-15 return `Err`. + let raw = (val.into_raw() & !::kernel::bits::genmask_u64(12..=15)) | (0x7 << 12); + assert_eq!(TestU64::from_raw(raw).field_15_12(), Err(0x7)); + } + + // Test that setting an overlapping field affects the overlapped one as expected. + #[test] + fn test_overlapping_fields() { + let mut val = TestU16::zeroed(); + + val = val.with_field_5_4(Priority::High); // High == 2 == 0b10. + assert_eq!(val.field_5_4(), Priority::High); + assert_eq!(val.field_7_4(), 0x2); // Bits 7:6 == 0, bits 5:4 == 0b10. + + val = val.with_const_field_7_4::<0xF>(); + assert_eq!(val.field_7_4(), 0xF); + assert_eq!(val.field_5_4(), Priority::Critical); // Bits 5:4 == 0b11. + + // `field_7_0` should encompass all other fields. + let mut val = TestU8::zeroed() + .with_field_0(true) + .with_field_1(true) + .with_const_field_3_2::<0x3>() + .with_const_field_7_4::<0xA>(); + assert_eq!(val.into_raw(), 0xAF); + + val = val.with_field_7_0(0x55); + assert_eq!(val.field_7_0(), 0x55); + assert!(val.field_0().into_bool()); + assert!(!val.field_1().into_bool()); + assert_eq!(val.field_3_2(), 0x1); + assert_eq!(val.field_7_4(), 0x5); + } + + // Checks that bits not mapped to any field are left untouched. + #[test] + fn test_unallocated_bits() { + let gap_bits = (1u64 << 62) | 0x1FC; + + let set_all_fields = |val: TestU64| { + val.with_field_63(true) + .with_const_field_61_52::<0x155>() + .with_const_field_51_16::<0x123456>() + .with_field_15_12(MemoryType::Device) + .with_const_field_11_9::<0x5>() + .with_field_1(true) + .with_field_0(true) + }; + + // Gap bits to 0. + let val = set_all_fields(TestU64::from_raw(0)); + assert_eq!(val.into_raw() & gap_bits, 0); + + // Gap bits to 1. + let val = set_all_fields(TestU64::from_raw(gap_bits)); + assert_eq!(val.into_raw() & gap_bits, gap_bits); + } + + #[test] + fn test_try_with() { + let val = TestU64::zeroed().try_with_field_51_16(0x123456).unwrap(); + assert_eq!(val.field_51_16(), 0x123456); + + let err = TestU64::zeroed().try_with_field_51_16(u64::MAX); + assert_eq!(err, Err(::kernel::error::code::EOVERFLOW)); + + let val = TestU64::zeroed() + .try_with_field_51_16(0xABCDEF) + .and_then(|p| p.try_with_field_0(1)) + .unwrap(); + assert_eq!(val.field_51_16(), 0xABCDEF); + assert!(val.field_0().into_bool()); + } + + // `from_raw`/`into_raw` and `From`/`Into` round-trips. + #[test] + fn test_raw() { + let raw: u64 = 0xBFF0_0000_3123_3E03; + let val = TestU64::from_raw(raw); + assert_eq!(u64::from(val), raw); + assert!(val.field_0().into_bool()); + assert!(val.field_1().into_bool()); + assert_eq!(val.field_11_9(), 0x7); + assert_eq!(val.field_51_16(), 0x3123); + assert_eq!(val.field_15_12(), Ok(MemoryType::Reserved)); + assert_eq!(val.field_61_52(), 0x3FF); + assert!(val.field_63().into_bool()); + + let raw: u16 = 0x42AB; + let val = TestU16::from_raw(raw); + assert_eq!(u16::from(val), raw); + assert!(val.field_0().into_bool()); + assert_eq!(val.field_3_1(), 0x5); + assert_eq!(val.field_7_4(), 0xA); + assert_eq!(val.field_15_8(), 0x42); + + let raw: u8 = 0xAF; + let val = TestU8::from_raw(raw); + assert_eq!(u8::from(val), raw); + assert!(val.field_0().into_bool()); + assert!(val.field_1().into_bool()); + assert_eq!(val.field_3_2(), 0x3); + assert_eq!(val.field_7_4(), 0xA); + assert_eq!(val.field_7_0(), 0xAF); + } +} diff --git a/rust/kernel/bitmap.rs b/rust/kernel/bitmap.rs index 83d7dea99137..b27e0ec80d64 100644 --- a/rust/kernel/bitmap.rs +++ b/rust/kernel/bitmap.rs @@ -499,9 +499,8 @@ impl Bitmap { } } -use macros::kunit_tests; - -#[kunit_tests(rust_kernel_bitmap)] +#[cfg(CONFIG_RUST_BITMAP_KUNIT_TEST)] +#[macros::kunit_tests(rust_kernel_bitmap)] mod tests { use super::*; use kernel::alloc::flags::GFP_KERNEL; diff --git a/rust/kernel/block/mq/gen_disk.rs b/rust/kernel/block/mq/gen_disk.rs index 912cb805caf5..fc97dd873974 100644 --- a/rust/kernel/block/mq/gen_disk.rs +++ b/rust/kernel/block/mq/gen_disk.rs @@ -150,6 +150,19 @@ impl GenDiskBuilder { // SAFETY: `gendisk` is a valid pointer as we initialized it above unsafe { (*gendisk).fops = &TABLE }; + let cleanup_failure = ScopeGuard::new_with_data((gendisk, data), |(gendisk, data)| { + // SAFETY: `gendisk` came from `__blk_mq_alloc_disk()` above and + // has not been added to the VFS on this cleanup path. + unsafe { bindings::put_disk(gendisk) }; + // SAFETY: `data` came from `into_foreign()` above and has not been + // converted back on this cleanup path. + drop(unsafe { T::QueueData::from_foreign(data) }); + }); + + // The failure guard now owns both pieces of cleanup; the early guard + // must not run on this path anymore. + recover_data.dismiss(); + let mut writer = NullTerminatedFormatter::new( // SAFETY: `gendisk` points to a valid and initialized instance. We // have exclusive access, since the disk is not added to the VFS @@ -172,7 +185,7 @@ impl GenDiskBuilder { }, )?; - recover_data.dismiss(); + cleanup_failure.dismiss(); // INVARIANT: `gendisk` was initialized above. // INVARIANT: `gendisk` was added to the VFS via `device_add_disk` above. @@ -215,6 +228,11 @@ impl<T: Operations> Drop for GenDisk<T> { // to the VFS. unsafe { bindings::del_gendisk(self.gendisk) }; + // SAFETY: By type invariant, `self.gendisk` was added to the VFS, so + // `put_disk()` must follow `del_gendisk()` to drop the final gendisk + // reference and trigger the remaining release path. + unsafe { bindings::put_disk(self.gendisk) }; + // SAFETY: `queue.queuedata` was created by `GenDiskBuilder::build` with // a call to `ForeignOwnable::into_foreign` to create `queuedata`. // `ForeignOwnable::from_foreign` is only called here. diff --git a/rust/kernel/block/mq/operations.rs b/rust/kernel/block/mq/operations.rs index 8ad46129a52c..861903e18fbf 100644 --- a/rust/kernel/block/mq/operations.rs +++ b/rust/kernel/block/mq/operations.rs @@ -218,7 +218,7 @@ impl<T: Operations> OperationsVTable<T> { _set: *mut bindings::blk_mq_tag_set, rq: *mut bindings::request, _hctx_idx: crate::ffi::c_uint, - _numa_node: crate::ffi::c_uint, + _numa_node: crate::ffi::c_int, ) -> crate::ffi::c_int { from_result(|| { // SAFETY: By the safety requirements of this function, `rq` points diff --git a/rust/kernel/build_assert.rs b/rust/kernel/build_assert.rs index 2ea2154ec30c..c3acb9b68a65 100644 --- a/rust/kernel/build_assert.rs +++ b/rust/kernel/build_assert.rs @@ -61,15 +61,16 @@ //! undefined symbols and linker errors, it is not developer friendly to debug, so it is recommended //! to avoid it and prefer other two assertions where possible. +#[doc(inline)] pub use crate::{ - build_assert, + build_assert_macro as build_assert, build_error, const_assert, static_assert, // }; #[doc(hidden)] -pub use build_error::build_error; +pub use build_error::build_error as build_error_fn; /// Static assert (i.e. compile-time assert). /// @@ -105,6 +106,7 @@ pub use build_error::build_error; /// static_assert!(f(40) == 42, "f(x) must add 2 to the given input."); /// ``` #[macro_export] +#[doc(hidden)] macro_rules! static_assert { ($condition:expr $(,$arg:literal)?) => { const _: () = ::core::assert!($condition $(,$arg)?); @@ -133,6 +135,7 @@ macro_rules! static_assert { /// } /// ``` #[macro_export] +#[doc(hidden)] macro_rules! const_assert { ($condition:expr $(,$arg:literal)?) => { const { ::core::assert!($condition $(,$arg)?) }; @@ -157,12 +160,13 @@ macro_rules! const_assert { /// // foo(usize::MAX); // Fails to compile. /// ``` #[macro_export] +#[doc(hidden)] macro_rules! build_error { () => {{ - $crate::build_assert::build_error("") + $crate::build_assert::build_error_fn("") }}; ($msg:expr) => {{ - $crate::build_assert::build_error($msg) + $crate::build_assert::build_error_fn($msg) }}; } @@ -200,15 +204,16 @@ macro_rules! build_error { /// const _: () = const_bar(2); /// ``` #[macro_export] -macro_rules! build_assert { +#[doc(hidden)] +macro_rules! build_assert_macro { ($cond:expr $(,)?) => {{ if !$cond { - $crate::build_assert::build_error(concat!("assertion failed: ", stringify!($cond))); + $crate::build_assert::build_error_fn(concat!("assertion failed: ", stringify!($cond))); } }}; ($cond:expr, $msg:expr) => {{ if !$cond { - $crate::build_assert::build_error($msg); + $crate::build_assert::build_error_fn($msg); } }}; } diff --git a/rust/kernel/cpufreq.rs b/rust/kernel/cpufreq.rs index d8d26870bea2..58ac04c650a1 100644 --- a/rust/kernel/cpufreq.rs +++ b/rust/kernel/cpufreq.rs @@ -888,12 +888,13 @@ pub trait Driver { /// /// impl platform::Driver for SampleDriver { /// type IdInfo = (); +/// type Data<'bound> = Self; /// const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = None; /// -/// fn probe( -/// pdev: &platform::Device<Core>, -/// _id_info: Option<&Self::IdInfo>, -/// ) -> impl PinInit<Self, Error> { +/// fn probe<'bound>( +/// pdev: &'bound platform::Device<Core<'_>>, +/// _id_info: Option<&'bound Self::IdInfo>, +/// ) -> impl PinInit<Self, Error> + 'bound { /// cpufreq::Registration::<SampleDriver>::new_foreign_owned(pdev.as_ref())?; /// Ok(Self {}) /// } @@ -1323,7 +1324,7 @@ impl<T: Driver> Registration<T> { // SAFETY: The C API guarantees that `cpu` refers to a valid CPU number. let cpu_id = unsafe { CpuId::from_u32_unchecked(cpu) }; - PolicyCpu::from_cpu(cpu_id).map_or(0, |mut policy| T::get(&mut policy).map_or(0, |f| f)) + PolicyCpu::from_cpu(cpu_id).map_or(0, |mut policy| T::get(&mut policy).unwrap_or(0)) } /// Driver's `update_limit` callback. diff --git a/rust/kernel/device.rs b/rust/kernel/device.rs index 6d5396a43ebe..645afc49a27d 100644 --- a/rust/kernel/device.rs +++ b/rust/kernel/device.rs @@ -15,16 +15,12 @@ use crate::{ }, // }; use core::{ - any::TypeId, marker::PhantomData, ptr, // }; pub mod property; -// Assert that we can `read()` / `write()` a `TypeId` instance from / into `struct driver_type`. -static_assert!(core::mem::size_of::<bindings::driver_type>() >= core::mem::size_of::<TypeId>()); - /// The core representation of a device in the kernel's driver model. /// /// This structure represents the Rust abstraction for a C `struct device`. A [`Device`] can either @@ -205,30 +201,13 @@ impl Device { } } -impl Device<CoreInternal> { - fn set_type_id<T: 'static>(&self) { - // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. - let private = unsafe { (*self.as_raw()).p }; - - // SAFETY: For a bound device (implied by the `CoreInternal` device context), `private` is - // guaranteed to be a valid pointer to a `struct device_private`. - let driver_type = unsafe { &raw mut (*private).driver_type }; - - // SAFETY: `driver_type` is valid for (unaligned) writes of a `TypeId`. - unsafe { - driver_type - .cast::<TypeId>() - .write_unaligned(TypeId::of::<T>()) - }; - } - +impl<'a> Device<CoreInternal<'a>> { /// Store a pointer to the bound driver's private data. - pub fn set_drvdata<T: 'static>(&self, data: impl PinInit<T, Error>) -> Result { + pub fn set_drvdata<T>(&self, data: impl PinInit<T, Error>) -> Result { let data = KBox::pin_init(data, GFP_KERNEL)?; // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. unsafe { bindings::dev_set_drvdata(self.as_raw(), data.into_foreign().cast()) }; - self.set_type_id::<T>(); Ok(()) } @@ -239,7 +218,7 @@ impl Device<CoreInternal> { /// /// - The type `T` must match the type of the `ForeignOwnable` previously stored by /// [`Device::set_drvdata`]. - pub(crate) unsafe fn drvdata_obtain<T: 'static>(&self) -> Option<Pin<KBox<T>>> { + pub(crate) unsafe fn drvdata_obtain<T>(&self) -> Option<Pin<KBox<T>>> { // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. let ptr = unsafe { bindings::dev_get_drvdata(self.as_raw()) }; @@ -265,7 +244,7 @@ impl Device<CoreInternal> { /// device is fully unbound. /// - The type `T` must match the type of the `ForeignOwnable` previously stored by /// [`Device::set_drvdata`]. - pub unsafe fn drvdata_borrow<T: 'static>(&self) -> Pin<&T> { + pub unsafe fn drvdata_borrow<T>(&self) -> Pin<&T> { // SAFETY: `drvdata_unchecked()` has the exact same safety requirements as the ones // required by this method. unsafe { self.drvdata_unchecked() } @@ -281,7 +260,7 @@ impl Device<Bound> { /// the device is fully unbound. /// - The type `T` must match the type of the `ForeignOwnable` previously stored by /// [`Device::set_drvdata`]. - unsafe fn drvdata_unchecked<T: 'static>(&self) -> Pin<&T> { + unsafe fn drvdata_unchecked<T>(&self) -> Pin<&T> { // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. let ptr = unsafe { bindings::dev_get_drvdata(self.as_raw()) }; @@ -292,45 +271,6 @@ impl Device<Bound> { // in `into_foreign()`. unsafe { Pin::<KBox<T>>::borrow(ptr.cast()) } } - - fn match_type_id<T: 'static>(&self) -> Result { - // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. - let private = unsafe { (*self.as_raw()).p }; - - // SAFETY: For a bound device, `private` is guaranteed to be a valid pointer to a - // `struct device_private`. - let driver_type = unsafe { &raw mut (*private).driver_type }; - - // SAFETY: - // - `driver_type` is valid for (unaligned) reads of a `TypeId`. - // - A bound device guarantees that `driver_type` contains a valid `TypeId` value. - let type_id = unsafe { driver_type.cast::<TypeId>().read_unaligned() }; - - if type_id != TypeId::of::<T>() { - return Err(EINVAL); - } - - Ok(()) - } - - /// Access a driver's private data. - /// - /// Returns a pinned reference to the driver's private data or [`EINVAL`] if it doesn't match - /// the asserted type `T`. - pub fn drvdata<T: 'static>(&self) -> Result<Pin<&T>> { - // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`. - if unsafe { bindings::dev_get_drvdata(self.as_raw()) }.is_null() { - return Err(ENOENT); - } - - self.match_type_id::<T>()?; - - // SAFETY: - // - The above check of `dev_get_drvdata()` guarantees that we are called after - // `set_drvdata()`. - // - We've just checked that the type of the driver's private data is in fact `T`. - Ok(unsafe { self.drvdata_unchecked() }) - } } impl<Ctx: DeviceContext> Device<Ctx> { @@ -527,6 +467,10 @@ unsafe impl Send for Device {} // synchronization in `struct device`. unsafe impl Sync for Device {} +// SAFETY: Same as `Device<Normal>` -- the underlying `struct device` is the same; `Bound` is a +// zero-sized type-state marker that does not affect thread safety. +unsafe impl Sync for Device<Bound> {} + /// Marker trait for the context or scope of a bus specific device. /// /// [`DeviceContext`] is a marker trait for types representing the context of a bus specific @@ -567,7 +511,7 @@ pub struct Normal; /// callback it appears in. It is intended to be used for synchronization purposes. Bus device /// implementations can implement methods for [`Device<Core>`], such that they can only be called /// from bus callbacks. -pub struct Core; +pub struct Core<'a>(PhantomData<&'a ()>); /// Semantically the same as [`Core`], but reserved for internal usage of the corresponding bus /// abstraction. @@ -578,7 +522,7 @@ pub struct Core; /// /// This context mainly exists to share generic [`Device`] infrastructure that should only be called /// from bus callbacks with bus abstractions, but without making them accessible for drivers. -pub struct CoreInternal; +pub struct CoreInternal<'a>(PhantomData<&'a ()>); /// The [`Bound`] context is the [`DeviceContext`] of a bus specific device when it is guaranteed to /// be bound to a driver. @@ -602,14 +546,14 @@ mod private { pub trait Sealed {} impl Sealed for super::Bound {} - impl Sealed for super::Core {} - impl Sealed for super::CoreInternal {} + impl<'a> Sealed for super::Core<'a> {} + impl<'a> Sealed for super::CoreInternal<'a> {} impl Sealed for super::Normal {} } impl DeviceContext for Bound {} -impl DeviceContext for Core {} -impl DeviceContext for CoreInternal {} +impl<'a> DeviceContext for Core<'a> {} +impl<'a> DeviceContext for CoreInternal<'a> {} impl DeviceContext for Normal {} impl<Ctx: DeviceContext> AsRef<Device<Ctx>> for Device<Ctx> { @@ -659,6 +603,22 @@ pub unsafe trait AsBusDevice<Ctx: DeviceContext>: AsRef<Device<Ctx>> { #[doc(hidden)] #[macro_export] macro_rules! __impl_device_context_deref { + (unsafe { $device:ident, <$lt:lifetime> $src:ty => $dst:ty }) => { + impl<$lt> ::core::ops::Deref for $device<$src> { + type Target = $device<$dst>; + + fn deref(&self) -> &Self::Target { + let ptr: *const Self = self; + + // CAST: `$device<$src>` and `$device<$dst>` transparently wrap the same type by the + // safety requirement of the macro. + let ptr = ptr.cast::<Self::Target>(); + + // SAFETY: `ptr` was derived from `&self`. + unsafe { &*ptr } + } + } + }; (unsafe { $device:ident, $src:ty => $dst:ty }) => { impl ::core::ops::Deref for $device<$src> { type Target = $device<$dst>; @@ -691,14 +651,14 @@ macro_rules! impl_device_context_deref { // `__impl_device_context_deref!`. ::kernel::__impl_device_context_deref!(unsafe { $device, - $crate::device::CoreInternal => $crate::device::Core + <'a> $crate::device::CoreInternal<'a> => $crate::device::Core<'a> }); // SAFETY: This macro has the exact same safety requirement as // `__impl_device_context_deref!`. ::kernel::__impl_device_context_deref!(unsafe { $device, - $crate::device::Core => $crate::device::Bound + <'a> $crate::device::Core<'a> => $crate::device::Bound }); // SAFETY: This macro has the exact same safety requirement as @@ -713,6 +673,13 @@ macro_rules! impl_device_context_deref { #[doc(hidden)] #[macro_export] macro_rules! __impl_device_context_into_aref { + (<$lt:lifetime> $src:ty, $device:tt) => { + impl<$lt> ::core::convert::From<&$device<$src>> for $crate::sync::aref::ARef<$device> { + fn from(dev: &$device<$src>) -> Self { + (&**dev).into() + } + } + }; ($src:ty, $device:tt) => { impl ::core::convert::From<&$device<$src>> for $crate::sync::aref::ARef<$device> { fn from(dev: &$device<$src>) -> Self { @@ -727,8 +694,12 @@ macro_rules! __impl_device_context_into_aref { #[macro_export] macro_rules! impl_device_context_into_aref { ($device:tt) => { - ::kernel::__impl_device_context_into_aref!($crate::device::CoreInternal, $device); - ::kernel::__impl_device_context_into_aref!($crate::device::Core, $device); + ::kernel::__impl_device_context_into_aref!( + <'a> $crate::device::CoreInternal<'a>, $device + ); + ::kernel::__impl_device_context_into_aref!( + <'a> $crate::device::Core<'a>, $device + ); ::kernel::__impl_device_context_into_aref!($crate::device::Bound, $device); }; } diff --git a/rust/kernel/devres.rs b/rust/kernel/devres.rs index 9e5f93aed20c..11ce500e9b76 100644 --- a/rust/kernel/devres.rs +++ b/rust/kernel/devres.rs @@ -122,7 +122,7 @@ struct Inner<T> { /// # Ok(()) /// # } /// ``` -pub struct Devres<T: Send> { +pub struct Devres<T: Send + 'static> { dev: ARef<Device>, inner: Arc<Inner<T>>, } @@ -184,7 +184,7 @@ mod base { } } -impl<T: Send> Devres<T> { +impl<T: Send + 'static> Devres<T> { /// Creates a new [`Devres`] instance of the given `data`. /// /// The `data` encapsulated within the returned `Devres` instance' `data` will be @@ -304,7 +304,7 @@ impl<T: Send> Devres<T> { /// pci, // /// }; /// - /// fn from_core(dev: &pci::Device<Core>, devres: Devres<pci::Bar<0x4>>) -> Result { + /// fn from_core(dev: &pci::Device<Core<'_>>, devres: Devres<pci::Bar<'_, 0x4>>) -> Result { /// let bar = devres.access(dev.as_ref())?; /// /// let _ = bar.read32(0x0); @@ -349,7 +349,7 @@ unsafe impl<T: Send> Send for Devres<T> {} // SAFETY: `Devres` can be shared with any task, if `T: Sync`. unsafe impl<T: Send + Sync> Sync for Devres<T> {} -impl<T: Send> Drop for Devres<T> { +impl<T: Send + 'static> Drop for Devres<T> { fn drop(&mut self) { // SAFETY: When `drop` runs, it is guaranteed that nobody is accessing the revocable data // anymore, hence it is safe not to wait for the grace period to finish. diff --git a/rust/kernel/dma.rs b/rust/kernel/dma.rs index 4995ee5dc689..200def84fb69 100644 --- a/rust/kernel/dma.rs +++ b/rust/kernel/dma.rs @@ -47,7 +47,7 @@ pub type DmaAddress = bindings::dma_addr_t; /// where the underlying bus is DMA capable, such as: #[cfg_attr(CONFIG_PCI, doc = "* [`pci::Device`](kernel::pci::Device)")] /// * [`platform::Device`](::kernel::platform::Device) -pub trait Device: AsRef<device::Device<Core>> { +pub trait Device<'a>: AsRef<device::Device<Core<'a>>> { /// Set up the device's DMA streaming addressing capabilities. /// /// This method is usually called once from `probe()` as soon as the device capabilities are @@ -1152,8 +1152,8 @@ unsafe impl Sync for CoherentHandle {} /// unsafe impl kernel::transmute::AsBytes for MyStruct{}; /// /// # fn test(alloc: &kernel::dma::Coherent<[MyStruct]>) -> Result { -/// let whole = kernel::dma_read!(alloc, [2]?); -/// let field = kernel::dma_read!(alloc, [1]?.field); +/// let whole = kernel::dma_read!(alloc, [try: 2]); +/// let field = kernel::dma_read!(alloc, [panic: 1].field); /// # Ok::<(), Error>(()) } /// ``` #[macro_export] @@ -1189,8 +1189,8 @@ macro_rules! dma_read { /// unsafe impl kernel::transmute::AsBytes for MyStruct{}; /// /// # fn test(alloc: &kernel::dma::Coherent<[MyStruct]>) -> Result { -/// kernel::dma_write!(alloc, [2]?.member, 0xf); -/// kernel::dma_write!(alloc, [1]?, MyStruct { member: 0xf }); +/// kernel::dma_write!(alloc, [try: 2].member, 0xf); +/// kernel::dma_write!(alloc, [panic: 1], MyStruct { member: 0xf }); /// # Ok::<(), Error>(()) } /// ``` #[macro_export] @@ -1207,11 +1207,8 @@ macro_rules! dma_write { (@parse [$dma:expr] [$($proj:tt)*] [.$field:tt $($rest:tt)*]) => { $crate::dma_write!(@parse [$dma] [$($proj)* .$field] [$($rest)*]) }; - (@parse [$dma:expr] [$($proj:tt)*] [[$index:expr]? $($rest:tt)*]) => { - $crate::dma_write!(@parse [$dma] [$($proj)* [$index]?] [$($rest)*]) - }; - (@parse [$dma:expr] [$($proj:tt)*] [[$index:expr] $($rest:tt)*]) => { - $crate::dma_write!(@parse [$dma] [$($proj)* [$index]] [$($rest)*]) + (@parse [$dma:expr] [$($proj:tt)*] [[$flavor:ident: $index:expr] $($rest:tt)*]) => { + $crate::dma_write!(@parse [$dma] [$($proj)* [$flavor: $index]] [$($rest)*]) }; ($dma:expr, $($rest:tt)*) => { $crate::dma_write!(@parse [$dma] [] [$($rest)*]) diff --git a/rust/kernel/driver.rs b/rust/kernel/driver.rs index 36de8098754d..bf5ba0d27553 100644 --- a/rust/kernel/driver.rs +++ b/rust/kernel/driver.rs @@ -13,10 +13,13 @@ //! The main driver interface is defined by a bus specific driver trait. For instance: //! //! ```ignore -//! pub trait Driver: Send { +//! pub trait Driver { //! /// The type holding information about each device ID supported by the driver. //! type IdInfo: 'static; //! +//! /// The type of the driver's bus device private data. +//! type Data<'bound>: Send + 'bound; +//! //! /// The table of OF device ids supported by the driver. //! const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = None; //! @@ -24,10 +27,16 @@ //! const ACPI_ID_TABLE: Option<acpi::IdTable<Self::IdInfo>> = None; //! //! /// Driver probe. -//! fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> impl PinInit<Self, Error>; +//! fn probe<'bound>( +//! dev: &'bound Device<device::Core<'_>>, +//! id_info: &'bound Self::IdInfo, +//! ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound; //! //! /// Driver unbind (optional). -//! fn unbind(dev: &Device<device::Core>, this: Pin<&Self>) { +//! fn unbind<'bound>( +//! dev: &'bound Device<device::Core<'_>>, +//! this: Pin<&Self::Data<'bound>>, +//! ) { //! let _ = (dev, this); //! } //! } @@ -42,8 +51,9 @@ )] #")] //! -//! The `probe()` callback should return a `impl PinInit<Self, Error>`, i.e. the driver's private -//! data. The bus abstraction should store the pointer in the corresponding bus device. The generic +//! The `probe()` callback should return a +//! `impl PinInit<Self::Data<'bound>, Error>`, i.e. the driver's private data. The bus +//! abstraction should store the pointer in the corresponding bus device. The generic //! [`Device`] infrastructure provides common helpers for this purpose on its //! [`Device<CoreInternal>`] implementation. //! @@ -118,8 +128,8 @@ pub unsafe trait DriverLayout { /// The specific driver type embedding a `struct device_driver`. type DriverType: Default; - /// The type of the driver's device private data. - type DriverData; + /// The type of the driver's bus device private data. + type DriverData<'bound>; /// Byte offset of the embedded `struct device_driver` within `DriverType`. /// @@ -181,20 +191,20 @@ unsafe impl<T: RegistrationOps> Sync for Registration<T> {} // any thread, so `Registration` is `Send`. unsafe impl<T: RegistrationOps> Send for Registration<T> {} -impl<T: RegistrationOps + 'static> Registration<T> { +impl<T: RegistrationOps> Registration<T> { extern "C" fn post_unbind_callback(dev: *mut bindings::device) { // SAFETY: The driver core only ever calls the post unbind callback with a valid pointer to // a `struct device`. // // INVARIANT: `dev` is valid for the duration of the `post_unbind_callback()`. - let dev = unsafe { &*dev.cast::<device::Device<device::CoreInternal>>() }; + let dev = unsafe { &*dev.cast::<device::Device<device::CoreInternal<'_>>>() }; - // `remove()` and all devres callbacks have been completed at this point, hence drop the - // driver's device private data. + // `remove()` has been completed at this point; devres resources are still valid and will + // be released after the driver's bus device private data is dropped. // // SAFETY: By the safety requirements of the `Driver` trait, `T::DriverData` is the - // driver's device private data type. - drop(unsafe { dev.drvdata_obtain::<T::DriverData>() }); + // driver's bus device private data type. + drop(unsafe { dev.drvdata_obtain::<T::DriverData<'_>>() }); } /// Attach generic `struct device_driver` callbacks. @@ -215,7 +225,10 @@ impl<T: RegistrationOps + 'static> Registration<T> { } /// Creates a new instance of the registration object. - pub fn new(name: &'static CStr, module: &'static ThisModule) -> impl PinInit<Self, Error> { + pub fn new(name: &'static CStr, module: &'static ThisModule) -> impl PinInit<Self, Error> + where + T: 'static, + { try_pin_init!(Self { reg <- Opaque::try_ffi_init(|ptr: *mut T::DriverType| { // SAFETY: `try_ffi_init` guarantees that `ptr` is valid for write. @@ -278,6 +291,26 @@ macro_rules! module_driver { } } +// Calling the FFI function directly from the `Adapter` impl may result in it being called +// directly from driver modules. This happens since the Rust compiler will use monomorphisation, so +// it might happen that functions are instantiated within the calling driver module. For now, work +// around this with `#[inline(never)]` helpers. +// +// TODO: Remove once a more generic solution has been implemented. For instance, we may be able to +// leverage `bindgen` to take care of this depending on whether a symbol is (already) exported. +#[inline(never)] +#[allow(clippy::missing_safety_doc)] +#[allow(dead_code)] +#[must_use] +unsafe fn acpi_of_match_device( + adev: *const bindings::acpi_device, + of_match_table: *const bindings::of_device_id, + of_id: *mut *const bindings::of_device_id, +) -> bool { + // SAFETY: Safety requirements are the same as `bindings::acpi_of_match_device`. + unsafe { bindings::acpi_of_match_device(adev, of_match_table, of_id) } +} + /// The bus independent adapter to match a drivers and a devices. /// /// This trait should be implemented by the bus specific adapter, which represents the connection @@ -329,35 +362,63 @@ pub trait Adapter { /// /// If this returns `None`, it means there is no match with an entry in the [`of::IdTable`]. fn of_id_info(dev: &device::Device) -> Option<&'static Self::IdInfo> { - #[cfg(not(CONFIG_OF))] + let table = Self::of_id_table()?; + + #[cfg(not(any(CONFIG_OF, CONFIG_ACPI)))] { - let _ = dev; - None + let _ = (dev, table); } #[cfg(CONFIG_OF)] { - let table = Self::of_id_table()?; - // SAFETY: // - `table` has static lifetime, hence it's valid for read, // - `dev` is guaranteed to be valid while it's alive, and so is `dev.as_raw()`. let raw_id = unsafe { bindings::of_match_device(table.as_ptr(), dev.as_raw()) }; - if raw_id.is_null() { - None - } else { + if !raw_id.is_null() { // SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct of_device_id` // and does not add additional invariants, so it's safe to transmute. let id = unsafe { &*raw_id.cast::<of::DeviceId>() }; - Some( - table.info(<of::DeviceId as crate::device_id::RawDeviceIdIndex>::index( - id, - )), - ) + return Some(table.info( + <of::DeviceId as crate::device_id::RawDeviceIdIndex>::index(id), + )); + } + } + + #[cfg(CONFIG_ACPI)] + { + use core::ptr; + use device::property::FwNode; + + let mut raw_id = ptr::null(); + + let fwnode = dev.fwnode().map_or(ptr::null_mut(), FwNode::as_raw); + + // SAFETY: `fwnode` is a pointer to a valid `fwnode_handle`. A null pointer will be + // passed through the function. + let adev = unsafe { bindings::to_acpi_device_node(fwnode) }; + + // SAFETY: + // - `adev` is a valid pointer to `acpi_device` or is null. It is guaranteed to be + // valid as long as `dev` is alive. + // - `table` has static lifetime, hence it's valid for read. + if unsafe { acpi_of_match_device(adev, table.as_ptr(), &raw mut raw_id) } { + // SAFETY: + // - the function returns true, therefore `raw_id` has been set to a pointer to a + // valid `of_device_id`. + // - `DeviceId` is a `#[repr(transparent)]` wrapper of `struct of_device_id` + // and does not add additional invariants, so it's safe to transmute. + let id = unsafe { &*raw_id.cast::<of::DeviceId>() }; + + return Some(table.info( + <of::DeviceId as crate::device_id::RawDeviceIdIndex>::index(id), + )); } } + + None } /// Returns the driver's private data from the matching entry of any of the ID tables, if any. diff --git a/rust/kernel/drm/device.rs b/rust/kernel/drm/device.rs index 403fc35353c7..477cf771fb10 100644 --- a/rust/kernel/drm/device.rs +++ b/rust/kernel/drm/device.rs @@ -6,10 +6,12 @@ use crate::{ alloc::allocator::Kmalloc, - bindings, device, + bindings, + device, drm::{ self, - driver::AllocImpl, // + driver::AllocImpl, + private::Sealed, // }, error::from_err_ptr, prelude::*, @@ -17,16 +19,20 @@ use crate::{ ARef, AlwaysRefCounted, // }, - types::Opaque, + types::{ + NotThreadSafe, + Opaque, // + }, workqueue::{ HasDelayedWork, HasWork, Work, WorkItem, // - }, + }, // }; use core::{ alloc::Layout, + marker::PhantomData, mem, ops::Deref, ptr::{ @@ -66,36 +72,122 @@ macro_rules! drm_legacy_fields { } } -/// A typed DRM device with a specific `drm::Driver` implementation. +/// A trait implemented by all possible contexts a [`Device`] can be used in. +/// +/// Setting up a new [`Device`] is a multi-stage process. Each step of the process that a user +/// interacts with in Rust has a respective [`DeviceContext`] typestate. For example, +/// `Device<T, Registered>` would be a [`Device`] that reached the [`Registered`] [`DeviceContext`]. +/// +/// Each stage of this process is described below: +/// +/// ```text +/// 1 2 3 +/// +--------------+ +------------------+ +-----------------------+ +/// |Device created| → |Device initialized| → |Registered w/ userspace| +/// +--------------+ +------------------+ +-----------------------+ +/// (Uninit) (Registered) +/// ``` +/// +/// 1. The [`Device`] is in the [`Uninit`] context and is not guaranteed to be initialized or +/// registered with userspace. Only a limited subset of DRM core functionality is available. +/// 2. The [`Device`] is guaranteed to be fully initialized, but is not guaranteed to be registered +/// with userspace. All DRM core functionality which doesn't interact with userspace is +/// available. We currently don't have a context for representing this. +/// 3. The [`Device`] is guaranteed to be fully initialized, and is guaranteed to have been +/// registered with userspace at some point - thus putting it in the [`Registered`] context. +/// +/// An important caveat of [`DeviceContext`] which must be kept in mind: when used as a typestate +/// for a reference type, it can only guarantee that a [`Device`] reached a particular stage in the +/// initialization process _at the time the reference was taken_. No guarantee is made in regards to +/// what stage of the process the [`Device`] is currently in. This means for instance that a +/// `&Device<T, Uninit>` may actually be registered with userspace, it just wasn't known to be +/// registered at the time the reference was taken. +pub trait DeviceContext: Sealed + Send + Sync {} + +/// The [`DeviceContext`] of a [`Device`] that was registered with userspace at some point. /// -/// The device is always reference-counted. +/// This represents a [`Device`] which is guaranteed to have been registered with userspace at +/// some point in time. Such a DRM device is guaranteed to have been fully-initialized. +/// +/// Note: A device in this context is not guaranteed to remain registered with userspace for its +/// entire lifetime, as this is impossible to guarantee at compile-time. /// /// # Invariants /// -/// `self.dev` is a valid instance of a `struct device`. -#[repr(C)] -pub struct Device<T: drm::Driver> { - dev: Opaque<bindings::drm_device>, - data: T::Data, +/// A [`Device`] in this [`DeviceContext`] is guaranteed to have been registered with userspace +/// at some point in time. +pub struct Registered; + +impl Sealed for Registered {} +impl DeviceContext for Registered {} + +/// The [`DeviceContext`] of a [`Device`] that may be unregistered and partly uninitialized. +/// +/// A [`Device`] in this context is only guaranteed to be partly initialized, and may or may not +/// be registered with userspace. Thus operations which depend on the [`Device`] being fully +/// initialized, or which depend on the [`Device`] being registered with userspace are not +/// available through this [`DeviceContext`]. +/// +/// A [`Device`] in this context can be used to create a +/// [`Registration`](drm::driver::Registration). +pub struct Uninit; + +impl Sealed for Uninit {} +impl DeviceContext for Uninit {} + +/// A [`Device`] which is known at compile-time to be unregistered with userspace. +/// +/// This type allows performing operations which are only safe to do before userspace registration, +/// and can be used to create a [`Registration`](drm::driver::Registration) once the driver is ready +/// to register the device with userspace. +/// +/// Since DRM device initialization must be single-threaded, this object is not thread-safe. +/// +/// # Invariants +/// +/// The device in `self.0` is guaranteed to be a newly created [`Device`] that has not yet been +/// registered with userspace until this type is dropped. +pub struct UnregisteredDevice<T: drm::Driver>(ARef<Device<T, Uninit>>, NotThreadSafe); + +impl<T: drm::Driver> Deref for UnregisteredDevice<T> { + type Target = Device<T, Uninit>; + + fn deref(&self) -> &Self::Target { + &self.0 + } } -impl<T: drm::Driver> Device<T> { +impl<T: drm::Driver> UnregisteredDevice<T> { + const fn compute_features() -> u32 { + let mut features = drm::driver::FEAT_GEM; + + if T::FEAT_RENDER { + features |= drm::driver::FEAT_RENDER; + } + + features + } + const VTABLE: bindings::drm_driver = drm_legacy_fields! { load: None, open: Some(drm::File::<T::File>::open_callback), postclose: Some(drm::File::<T::File>::postclose_callback), unload: None, - release: Some(Self::release), + release: Some(Device::<T>::release), master_set: None, master_drop: None, debugfs_init: None, - gem_create_object: T::Object::ALLOC_OPS.gem_create_object, - prime_handle_to_fd: T::Object::ALLOC_OPS.prime_handle_to_fd, - prime_fd_to_handle: T::Object::ALLOC_OPS.prime_fd_to_handle, - gem_prime_import: T::Object::ALLOC_OPS.gem_prime_import, - gem_prime_import_sg_table: T::Object::ALLOC_OPS.gem_prime_import_sg_table, - dumb_create: T::Object::ALLOC_OPS.dumb_create, - dumb_map_offset: T::Object::ALLOC_OPS.dumb_map_offset, + + // Ignore the Uninit DeviceContext below. It is only provided because it is required by the + // compiler, and it is not actually used by these functions. + gem_create_object: T::Object::<Uninit>::ALLOC_OPS.gem_create_object, + prime_handle_to_fd: T::Object::<Uninit>::ALLOC_OPS.prime_handle_to_fd, + prime_fd_to_handle: T::Object::<Uninit>::ALLOC_OPS.prime_fd_to_handle, + gem_prime_import: T::Object::<Uninit>::ALLOC_OPS.gem_prime_import, + gem_prime_import_sg_table: T::Object::<Uninit>::ALLOC_OPS.gem_prime_import_sg_table, + dumb_create: T::Object::<Uninit>::ALLOC_OPS.dumb_create, + dumb_map_offset: T::Object::<Uninit>::ALLOC_OPS.dumb_map_offset, + show_fdinfo: None, fbdev_probe: None, @@ -105,7 +197,7 @@ impl<T: drm::Driver> Device<T> { name: crate::str::as_char_ptr_in_const_context(T::INFO.name).cast_mut(), desc: crate::str::as_char_ptr_in_const_context(T::INFO.desc).cast_mut(), - driver_features: drm::driver::FEAT_GEM, + driver_features: Self::compute_features(), ioctls: T::IOCTLS.as_ptr(), num_ioctls: T::IOCTLS.len() as i32, fops: &Self::GEM_FOPS, @@ -113,11 +205,13 @@ impl<T: drm::Driver> Device<T> { const GEM_FOPS: bindings::file_operations = drm::gem::create_fops(); - /// Create a new `drm::Device` for a `drm::Driver`. - pub fn new(dev: &device::Device, data: impl PinInit<T::Data, Error>) -> Result<ARef<Self>> { + /// Create a new `UnregisteredDevice` for a `drm::Driver`. + /// + /// This can be used to create a [`Registration`](kernel::drm::Registration). + pub fn new(dev: &device::Device, data: impl PinInit<T::Data, Error>) -> Result<Self> { // `__drm_dev_alloc` uses `kmalloc()` to allocate memory, hence ensure a `kmalloc()` // compatible `Layout`. - let layout = Kmalloc::aligned_layout(Layout::new::<Self>()); + let layout = Kmalloc::aligned_layout(Layout::new::<Device<T, Uninit>>()); // Use a temporary vtable without a `release` callback until `data` is initialized, so // init failure can release the DRM device without dropping uninitialized fields. @@ -129,12 +223,12 @@ impl<T: drm::Driver> Device<T> { // SAFETY: // - `alloc_vtable` reference remains valid until no longer used, // - `dev` is valid by its type invarants, - let raw_drm: *mut Self = unsafe { + let raw_drm: *mut Device<T, Uninit> = unsafe { bindings::__drm_dev_alloc( dev.as_raw(), &alloc_vtable, layout.size(), - mem::offset_of!(Self, dev), + mem::offset_of!(Device<T, Uninit>, dev), ) } .cast(); @@ -142,7 +236,7 @@ impl<T: drm::Driver> Device<T> { // SAFETY: `raw_drm` is a valid pointer to `Self`, given that `__drm_dev_alloc` was // successful. - let drm_dev = unsafe { Self::into_drm_device(raw_drm) }; + let drm_dev = unsafe { Device::into_drm_device(raw_drm) }; // SAFETY: `raw_drm` is a valid pointer to `Self`. let raw_data = unsafe { ptr::addr_of_mut!((*raw_drm.as_ptr()).data) }; @@ -161,9 +255,39 @@ impl<T: drm::Driver> Device<T> { // SAFETY: The reference count is one, and now we take ownership of that reference as a // `drm::Device`. - Ok(unsafe { ARef::from_raw(raw_drm) }) + // INVARIANT: We just created the device above, but have yet to call `drm_dev_register`. + // `Self` cannot be copied or sent to another thread - ensuring that `drm_dev_register` + // won't be called during its lifetime and that the device is unregistered. + Ok(Self(unsafe { ARef::from_raw(raw_drm) }, NotThreadSafe)) } +} +/// A typed DRM device with a specific [`drm::Driver`] implementation and [`DeviceContext`]. +/// +/// Since DRM devices can be used before being fully initialized and registered with userspace, `C` +/// represents the furthest [`DeviceContext`] we can guarantee that this [`Device`] has reached. +/// +/// Keep in mind: this means that an unregistered device can still have the registration state +/// [`Registered`] as long as it was registered with userspace once in the past, and that the +/// behavior of such a device is still well-defined. Additionally, a device with the registration +/// state [`Uninit`] simply does not have a guaranteed registration state at compile time, and could +/// be either registered or unregistered. Since there is no way to guarantee a long-lived reference +/// to an unregistered device would remain unregistered, we do not provide a [`DeviceContext`] for +/// this. +/// +/// # Invariants +/// +/// * `self.dev` is a valid instance of a `struct device`. +/// * The data layout of `Self` remains the same across all implementations of `C`. +/// * Any invariants for `C` also apply. +#[repr(C)] +pub struct Device<T: drm::Driver, C: DeviceContext = Registered> { + dev: Opaque<bindings::drm_device>, + data: T::Data, + _ctx: PhantomData<C>, +} + +impl<T: drm::Driver, C: DeviceContext> Device<T, C> { pub(crate) fn as_raw(&self) -> *mut bindings::drm_device { self.dev.get() } @@ -189,13 +313,13 @@ impl<T: drm::Driver> Device<T> { /// /// # Safety /// - /// Callers must ensure that `ptr` is valid, non-null, and has a non-zero reference count, - /// i.e. it must be ensured that the reference count of the C `struct drm_device` `ptr` points - /// to can't drop to zero, for the duration of this function call and the entire duration when - /// the returned reference exists. - /// - /// Additionally, callers must ensure that the `struct device`, `ptr` is pointing to, is - /// embedded in `Self`. + /// * Callers must ensure that `ptr` is valid, non-null, and has a non-zero reference count, + /// i.e. it must be ensured that the reference count of the C `struct drm_device` `ptr` points + /// to can't drop to zero, for the duration of this function call and the entire duration when + /// the returned reference exists. + /// * Additionally, callers must ensure that the `struct device`, `ptr` is pointing to, is + /// embedded in `Self`. + /// * Callers promise that any type invariants of `C` will be upheld. #[doc(hidden)] pub unsafe fn from_raw<'a>(ptr: *const bindings::drm_device) -> &'a Self { // SAFETY: By the safety requirements of this function `ptr` is a valid pointer to a @@ -215,9 +339,20 @@ impl<T: drm::Driver> Device<T> { // - `this` is valid for dropping. unsafe { core::ptr::drop_in_place(this) }; } + + /// Change the [`DeviceContext`] for a [`Device`]. + /// + /// # Safety + /// + /// The caller promises that `self` fulfills all of the guarantees provided by the given + /// [`DeviceContext`]. + pub(crate) unsafe fn assume_ctx<NewCtx: DeviceContext>(&self) -> &Device<T, NewCtx> { + // SAFETY: The data layout is identical via our type invariants. + unsafe { mem::transmute(self) } + } } -impl<T: drm::Driver> Deref for Device<T> { +impl<T: drm::Driver, C: DeviceContext> Deref for Device<T, C> { type Target = T::Data; fn deref(&self) -> &Self::Target { @@ -227,7 +362,7 @@ impl<T: drm::Driver> Deref for Device<T> { // SAFETY: DRM device objects are always reference counted and the get/put functions // satisfy the requirements. -unsafe impl<T: drm::Driver> AlwaysRefCounted for Device<T> { +unsafe impl<T: drm::Driver, C: DeviceContext> AlwaysRefCounted for Device<T, C> { fn inc_ref(&self) { // SAFETY: The existence of a shared reference guarantees that the refcount is non-zero. unsafe { bindings::drm_dev_get(self.as_raw()) }; @@ -242,7 +377,7 @@ unsafe impl<T: drm::Driver> AlwaysRefCounted for Device<T> { } } -impl<T: drm::Driver> AsRef<device::Device> for Device<T> { +impl<T: drm::Driver, C: DeviceContext> AsRef<device::Device> for Device<T, C> { fn as_ref(&self) -> &device::Device { // SAFETY: `bindings::drm_device::dev` is valid as long as the DRM device itself is valid, // which is guaranteed by the type invariant. @@ -251,21 +386,22 @@ impl<T: drm::Driver> AsRef<device::Device> for Device<T> { } // SAFETY: A `drm::Device` can be released from any thread. -unsafe impl<T: drm::Driver> Send for Device<T> {} +unsafe impl<T: drm::Driver, C: DeviceContext> Send for Device<T, C> {} // SAFETY: A `drm::Device` can be shared among threads because all immutable methods are protected // by the synchronization in `struct drm_device`. -unsafe impl<T: drm::Driver> Sync for Device<T> {} +unsafe impl<T: drm::Driver, C: DeviceContext> Sync for Device<T, C> {} -impl<T, const ID: u64> WorkItem<ID> for Device<T> +impl<T, C, const ID: u64> WorkItem<ID> for Device<T, C> where T: drm::Driver, - T::Data: WorkItem<ID, Pointer = ARef<Device<T>>>, - T::Data: HasWork<Device<T>, ID>, + T::Data: WorkItem<ID, Pointer = ARef<Self>>, + T::Data: HasWork<Self, ID>, + C: DeviceContext, { - type Pointer = ARef<Device<T>>; + type Pointer = ARef<Self>; - fn run(ptr: ARef<Device<T>>) { + fn run(ptr: ARef<Self>) { T::Data::run(ptr); } } @@ -277,40 +413,42 @@ where // stored inline in `drm::Device`, so the `container_of` call is valid. // // - The two methods are true inverses of each other: given `ptr: *mut -// Device<T>`, `raw_get_work` will return a `*mut Work<Device<T>, ID>` through -// `T::Data::raw_get_work` and given a `ptr: *mut Work<Device<T>, ID>`, -// `work_container_of` will return a `*mut Device<T>` through `container_of`. -unsafe impl<T, const ID: u64> HasWork<Device<T>, ID> for Device<T> +// Device<T, C>`, `raw_get_work` will return a `*mut Work<Device<T, C>, ID>` through +// `T::Data::raw_get_work` and given a `ptr: *mut Work<Device<T, C>, ID>`, +// `work_container_of` will return a `*mut Device<T, C>` through `container_of`. +unsafe impl<T, C, const ID: u64> HasWork<Self, ID> for Device<T, C> where T: drm::Driver, - T::Data: HasWork<Device<T>, ID>, + T::Data: HasWork<Self, ID>, + C: DeviceContext, { - unsafe fn raw_get_work(ptr: *mut Self) -> *mut Work<Device<T>, ID> { - // SAFETY: The caller promises that `ptr` points to a valid `Device<T>`. + unsafe fn raw_get_work(ptr: *mut Self) -> *mut Work<Self, ID> { + // SAFETY: The caller promises that `ptr` points to a valid `Device<T, C>`. let data_ptr = unsafe { &raw mut (*ptr).data }; // SAFETY: `data_ptr` is a valid pointer to `T::Data`. unsafe { T::Data::raw_get_work(data_ptr) } } - unsafe fn work_container_of(ptr: *mut Work<Device<T>, ID>) -> *mut Self { + unsafe fn work_container_of(ptr: *mut Work<Self, ID>) -> *mut Self { // SAFETY: The caller promises that `ptr` points at a `Work` field in // `T::Data`. let data_ptr = unsafe { T::Data::work_container_of(ptr) }; - // SAFETY: `T::Data` is stored as the `data` field in `Device<T>`. + // SAFETY: `T::Data` is stored as the `data` field in `Device<T, C>`. unsafe { crate::container_of!(data_ptr, Self, data) } } } // SAFETY: Our `HasWork<T, ID>` implementation returns a `work_struct` that is // stored in the `work` field of a `delayed_work` with the same access rules as -// the `work_struct` owing to the bound on `T::Data: HasDelayedWork<Device<T>, +// the `work_struct` owing to the bound on `T::Data: HasDelayedWork<Device<T, C>, // ID>`, which requires that `T::Data::raw_get_work` return a `work_struct` that // is inside a `delayed_work`. -unsafe impl<T, const ID: u64> HasDelayedWork<Device<T>, ID> for Device<T> +unsafe impl<T, C, const ID: u64> HasDelayedWork<Self, ID> for Device<T, C> where T: drm::Driver, - T::Data: HasDelayedWork<Device<T>, ID>, + T::Data: HasDelayedWork<Self, ID>, + C: DeviceContext, { } diff --git a/rust/kernel/drm/driver.rs b/rust/kernel/drm/driver.rs index 5233bdebc9fc..25f7e233884d 100644 --- a/rust/kernel/drm/driver.rs +++ b/rust/kernel/drm/driver.rs @@ -13,9 +13,15 @@ use crate::{ prelude::*, sync::aref::ARef, // }; +use core::{ + mem, + ptr::NonNull, // +}; /// Driver use the GEM memory manager. This should be set for all modern drivers. pub(crate) const FEAT_GEM: u32 = bindings::drm_driver_feature_DRIVER_GEM; +/// Driver supports render nodes, i.e.: /dev/dri/renderDXX devices. +pub(crate) const FEAT_RENDER: u32 = bindings::drm_driver_feature_DRIVER_RENDER; /// Information data for a DRM Driver. pub struct DriverInfo { @@ -105,7 +111,7 @@ pub trait Driver { type Data: Sync + Send; /// The type used to manage memory for this driver. - type Object: AllocImpl; + type Object<Ctx: drm::DeviceContext>: AllocImpl; /// The type used to represent a DRM File (client) type File: drm::file::DriverFile; @@ -115,6 +121,16 @@ pub trait Driver { /// IOCTL list. See `kernel::drm::ioctl::declare_drm_ioctls!{}`. const IOCTLS: &'static [drm::ioctl::DrmIoctlDescriptor]; + + /// Sets the `DRIVER_RENDER` feature for this driver. + /// + /// When enabled, the driver exposes `/dev/dri/renderDXX` render nodes to + /// userspace. The render node is an alternate low-priviledge way to access + /// the driver, which is enforced on a per-ioctl level. Userspace processes + /// that open the render node can only invoke ioctls explicitly listed as + /// usable from the render node (i.e. marked DRM_RENDER_ALLOW), whereas + /// userspace processes using the master node can invoke any ioctl. + const FEAT_RENDER: bool = false; } /// The registration type of a `drm::Device`. @@ -123,21 +139,31 @@ pub trait Driver { pub struct Registration<T: Driver>(ARef<drm::Device<T>>); impl<T: Driver> Registration<T> { - fn new(drm: &drm::Device<T>, flags: usize) -> Result<Self> { + fn new(drm: drm::UnregisteredDevice<T>, flags: usize) -> Result<Self> { // SAFETY: `drm.as_raw()` is valid by the invariants of `drm::Device`. to_result(unsafe { bindings::drm_dev_register(drm.as_raw(), flags) })?; - Ok(Self(drm.into())) + // SAFETY: We just called `drm_dev_register` above + let new = NonNull::from(unsafe { drm.assume_ctx() }); + + // Leak the ARef from UnregisteredDevice in preparation for transferring its ownership. + mem::forget(drm); + + // SAFETY: `drm`'s `Drop` constructor was never called, ensuring that there remains at least + // one reference to the device - which we take ownership over here. + let new = unsafe { ARef::from_raw(new) }; + + Ok(Self(new)) } - /// Registers a new [`Device`](drm::Device) with userspace. + /// Registers a new [`UnregisteredDevice`](drm::UnregisteredDevice) with userspace. /// /// Ownership of the [`Registration`] object is passed to [`devres::register`]. - pub fn new_foreign_owned( - drm: &drm::Device<T>, - dev: &device::Device<device::Bound>, + pub fn new_foreign_owned<'a>( + drm: drm::UnregisteredDevice<T>, + dev: &'a device::Device<device::Bound>, flags: usize, - ) -> Result + ) -> Result<&'a drm::Device<T>> where T: 'static, { @@ -146,8 +172,13 @@ impl<T: Driver> Registration<T> { } let reg = Registration::<T>::new(drm, flags)?; + let drm = NonNull::from(reg.device()); + + devres::register(dev, reg, GFP_KERNEL)?; - devres::register(dev, reg, GFP_KERNEL) + // SAFETY: Since `reg` was passed to devres::register(), the device now owns the lifetime + // of the DRM registration - ensuring that this references lives for at least as long as 'a. + Ok(unsafe { drm.as_ref() }) } /// Returns a reference to the `Device` instance for this registration. diff --git a/rust/kernel/drm/gem/mod.rs b/rust/kernel/drm/gem/mod.rs index 01b5bd47a333..c8b66d816871 100644 --- a/rust/kernel/drm/gem/mod.rs +++ b/rust/kernel/drm/gem/mod.rs @@ -8,6 +8,10 @@ use crate::{ bindings, drm::{ self, + device::{ + DeviceContext, + Registered, // + }, driver::{ AllocImpl, AllocOps, // @@ -22,6 +26,7 @@ use crate::{ types::Opaque, }; use core::{ + marker::PhantomData, ops::Deref, ptr::NonNull, // }; @@ -73,6 +78,12 @@ pub(crate) use impl_aref_for_gem_obj; /// [`DriverFile`]: drm::file::DriverFile pub type DriverFile<T> = drm::File<<<T as DriverObject>::Driver as drm::Driver>::File>; +/// A type alias for retrieving the current [`AllocImpl`] for a given [`DriverObject`]. +/// +/// [`Driver`]: drm::Driver +pub type DriverAllocImpl<T, Ctx = Registered> = + <<T as DriverObject>::Driver as drm::Driver>::Object<Ctx>; + /// GEM object functions, which must be implemented by drivers. pub trait DriverObject: Sync + Send + Sized { /// Parent `Driver` for this object. @@ -82,19 +93,19 @@ pub trait DriverObject: Sync + Send + Sized { type Args; /// Create a new driver data object for a GEM object of a given size. - fn new( - dev: &drm::Device<Self::Driver>, + fn new<Ctx: DeviceContext>( + dev: &drm::Device<Self::Driver, Ctx>, size: usize, args: Self::Args, ) -> impl PinInit<Self, Error>; /// Open a new handle to an existing object, associated with a File. - fn open(_obj: &<Self::Driver as drm::Driver>::Object, _file: &DriverFile<Self>) -> Result { + fn open(_obj: &DriverAllocImpl<Self>, _file: &DriverFile<Self>) -> Result { Ok(()) } /// Close a handle to an existing object, associated with a File. - fn close(_obj: &<Self::Driver as drm::Driver>::Object, _file: &DriverFile<Self>) {} + fn close(_obj: &DriverAllocImpl<Self>, _file: &DriverFile<Self>) {} } /// Trait that represents a GEM object subtype @@ -120,9 +131,12 @@ extern "C" fn open_callback<T: DriverObject>( // SAFETY: `open_callback` is only ever called with a valid pointer to a `struct drm_file`. let file = unsafe { DriverFile::<T>::from_raw(raw_file) }; - // SAFETY: `open_callback` is specified in the AllocOps structure for `DriverObject<T>`, - // ensuring that `raw_obj` is contained within a `DriverObject<T>` - let obj = unsafe { <<T::Driver as drm::Driver>::Object as IntoGEMObject>::from_raw(raw_obj) }; + // SAFETY: + // * `open_callback` is specified in the AllocOps structure for `DriverObject`, ensuring that + // `raw_obj` is contained within a `DriverAllocImpl<T>` + // * It is only possible for `open_callback` to be called after device registration, ensuring + // that the object's device is in the `Registered` state. + let obj: &DriverAllocImpl<T> = unsafe { IntoGEMObject::from_raw(raw_obj) }; match T::open(obj, file) { Err(e) => e.to_errno(), @@ -139,12 +153,12 @@ extern "C" fn close_callback<T: DriverObject>( // SAFETY: `close_callback` is specified in the AllocOps structure for `Object<T>`, ensuring // that `raw_obj` is indeed contained within a `Object<T>`. - let obj = unsafe { <<T::Driver as drm::Driver>::Object as IntoGEMObject>::from_raw(raw_obj) }; + let obj: &DriverAllocImpl<T> = unsafe { IntoGEMObject::from_raw(raw_obj) }; T::close(obj, file); } -impl<T: DriverObject> IntoGEMObject for Object<T> { +impl<T: DriverObject, Ctx: DeviceContext> IntoGEMObject for Object<T, Ctx> { fn as_raw(&self) -> *mut bindings::drm_gem_object { self.obj.get() } @@ -152,7 +166,7 @@ impl<T: DriverObject> IntoGEMObject for Object<T> { unsafe fn from_raw<'a>(self_ptr: *mut bindings::drm_gem_object) -> &'a Self { // SAFETY: `obj` is guaranteed to be in an `Object<T>` via the safety contract of this // function - unsafe { &*crate::container_of!(Opaque::cast_from(self_ptr), Object<T>, obj) } + unsafe { &*crate::container_of!(Opaque::cast_from(self_ptr), Object<T, Ctx>, obj) } } } @@ -169,7 +183,7 @@ pub trait BaseObject: IntoGEMObject { fn create_handle<D, F>(&self, file: &drm::File<F>) -> Result<u32> where Self: AllocImpl<Driver = D>, - D: drm::Driver<Object = Self, File = F>, + D: drm::Driver<Object<Registered> = Self, File = F>, F: drm::file::DriverFile<Driver = D>, { let mut handle: u32 = 0; @@ -184,7 +198,7 @@ pub trait BaseObject: IntoGEMObject { fn lookup_handle<D, F>(file: &drm::File<F>, handle: u32) -> Result<ARef<Self>> where Self: AllocImpl<Driver = D>, - D: drm::Driver<Object = Self, File = F>, + D: drm::Driver<Object<Registered> = Self, File = F>, F: drm::file::DriverFile<Driver = D>, { // SAFETY: The arguments are all valid per the type invariants. @@ -236,16 +250,18 @@ impl<T: IntoGEMObject> BaseObjectPrivate for T {} /// /// # Invariants /// -/// - `self.obj` is a valid instance of a `struct drm_gem_object`. +/// * `self.obj` is a valid instance of a `struct drm_gem_object`. +/// * Any type invariants of `Ctx` apply to the parent DRM device for this GEM object. #[repr(C)] #[pin_data] -pub struct Object<T: DriverObject + Send + Sync> { +pub struct Object<T: DriverObject + Send + Sync, Ctx: DeviceContext = Registered> { obj: Opaque<bindings::drm_gem_object>, #[pin] data: T, + _ctx: PhantomData<Ctx>, } -impl<T: DriverObject> Object<T> { +impl<T: DriverObject, Ctx: DeviceContext> Object<T, Ctx> { const OBJECT_FUNCS: bindings::drm_gem_object_funcs = bindings::drm_gem_object_funcs { free: Some(Self::free_callback), open: Some(open_callback::<T>), @@ -265,11 +281,16 @@ impl<T: DriverObject> Object<T> { }; /// Create a new GEM object. - pub fn new(dev: &drm::Device<T::Driver>, size: usize, args: T::Args) -> Result<ARef<Self>> { + pub fn new( + dev: &drm::Device<T::Driver, Ctx>, + size: usize, + args: T::Args, + ) -> Result<ARef<Self>> { let obj: Pin<KBox<Self>> = KBox::pin_init( try_pin_init!(Self { obj: Opaque::new(bindings::drm_gem_object::default()), data <- T::new(dev, size, args), + _ctx: PhantomData, }), GFP_KERNEL, )?; @@ -277,6 +298,8 @@ impl<T: DriverObject> Object<T> { // SAFETY: `obj.as_raw()` is guaranteed to be valid by the initialization above. unsafe { (*obj.as_raw()).funcs = &Self::OBJECT_FUNCS }; + // INVARIANT: `dev` and the GEM object are in the same state at the moment, and upgrading + // the typestate in `dev` will not carry over to the GEM object. if let Err(err) = // SAFETY: The arguments are all valid per the type invariants. to_result(unsafe { @@ -300,13 +323,15 @@ impl<T: DriverObject> Object<T> { } /// Returns the `Device` that owns this GEM object. - pub fn dev(&self) -> &drm::Device<T::Driver> { + pub fn dev(&self) -> &drm::Device<T::Driver, Ctx> { // SAFETY: // - `struct drm_gem_object.dev` is initialized and valid for as long as the GEM // object lives. // - The device we used for creating the gem object is passed as &drm::Device<T::Driver> to // Object::<T>::new(), so we know that `T::Driver` is the right generic parameter to use // here. + // - Any type invariants of `Ctx` are upheld by using the same `Ctx` for the `Device` we + // return. unsafe { drm::Device::from_raw((*self.as_raw()).dev) } } @@ -331,11 +356,16 @@ impl<T: DriverObject> Object<T> { } } -impl_aref_for_gem_obj!(impl<T> for Object<T> where T: DriverObject); +impl_aref_for_gem_obj! { + impl<T, C> for Object<T, C> + where + T: DriverObject, + C: DeviceContext +} -impl<T: DriverObject> super::private::Sealed for Object<T> {} +impl<T: DriverObject, Ctx: DeviceContext> super::private::Sealed for Object<T, Ctx> {} -impl<T: DriverObject> Deref for Object<T> { +impl<T: DriverObject, Ctx: DeviceContext> Deref for Object<T, Ctx> { type Target = T; fn deref(&self) -> &Self::Target { @@ -343,7 +373,7 @@ impl<T: DriverObject> Deref for Object<T> { } } -impl<T: DriverObject> AllocImpl for Object<T> { +impl<T: DriverObject, Ctx: DeviceContext> AllocImpl for Object<T, Ctx> { type Driver = T::Driver; const ALLOC_OPS: AllocOps = AllocOps { diff --git a/rust/kernel/drm/gem/shmem.rs b/rust/kernel/drm/gem/shmem.rs index e1b648920d2f..34af402899a0 100644 --- a/rust/kernel/drm/gem/shmem.rs +++ b/rust/kernel/drm/gem/shmem.rs @@ -12,10 +12,12 @@ use crate::{ container_of, drm::{ - device, driver, gem, - private::Sealed, // + private::Sealed, + Device, + DeviceContext, + Registered, // }, error::to_result, prelude::*, @@ -23,11 +25,12 @@ use crate::{ types::Opaque, // }; use core::{ + marker::PhantomData, ops::{ Deref, DerefMut, // }, - ptr::NonNull, + ptr::NonNull, // }; use gem::{ BaseObjectPrivate, @@ -40,42 +43,49 @@ use gem::{ /// This is used with [`Object::new()`] to control various properties that can only be set when /// initially creating a shmem-backed GEM object. #[derive(Default)] -pub struct ObjectConfig<'a, T: DriverObject> { +pub struct ObjectConfig<'a, T: DriverObject, C: DeviceContext = Registered> { /// Whether to set the write-combine map flag. pub map_wc: bool, /// Reuse the DMA reservation from another GEM object. /// /// The newly created [`Object`] will hold an owned refcount to `parent_resv_obj` if specified. - pub parent_resv_obj: Option<&'a Object<T>>, + pub parent_resv_obj: Option<&'a Object<T, C>>, } /// A shmem-backed GEM object. /// /// # Invariants /// -/// `obj` contains a valid initialized `struct drm_gem_shmem_object` for the lifetime of this -/// object. +/// - `obj` contains a valid initialized `struct drm_gem_shmem_object` for the lifetime of this +/// object. +/// - Any type invariants of `C` apply to the parent DRM device for this GEM object. #[repr(C)] #[pin_data] -pub struct Object<T: DriverObject> { +pub struct Object<T: DriverObject, C: DeviceContext = Registered> { #[pin] obj: Opaque<bindings::drm_gem_shmem_object>, /// Parent object that owns this object's DMA reservation object. - parent_resv_obj: Option<ARef<Object<T>>>, + parent_resv_obj: Option<ARef<Object<T, C>>>, #[pin] inner: T, + _ctx: PhantomData<C>, } -super::impl_aref_for_gem_obj!(impl<T> for Object<T> where T: DriverObject); +super::impl_aref_for_gem_obj! { + impl<T, C> for Object<T, C> + where + T: DriverObject, + C: DeviceContext +} // SAFETY: All GEM objects are thread-safe. -unsafe impl<T: DriverObject> Send for Object<T> {} +unsafe impl<T: DriverObject, C: DeviceContext> Send for Object<T, C> {} // SAFETY: All GEM objects are thread-safe. -unsafe impl<T: DriverObject> Sync for Object<T> {} +unsafe impl<T: DriverObject, C: DeviceContext> Sync for Object<T, C> {} -impl<T: DriverObject> Object<T> { +impl<T: DriverObject, C: DeviceContext> Object<T, C> { /// `drm_gem_object_funcs` vtable suitable for GEM shmem objects. const VTABLE: bindings::drm_gem_object_funcs = bindings::drm_gem_object_funcs { free: Some(Self::free_callback), @@ -106,9 +116,9 @@ impl<T: DriverObject> Object<T> { /// /// Additional config options can be specified using `config`. pub fn new( - dev: &device::Device<T::Driver>, + dev: &Device<T::Driver, C>, size: usize, - config: ObjectConfig<'_, T>, + config: ObjectConfig<'_, T, C>, args: T::Args, ) -> Result<ARef<Self>> { let new: Pin<KBox<Self>> = KBox::try_pin_init( @@ -116,6 +126,7 @@ impl<T: DriverObject> Object<T> { obj <- Opaque::init_zeroed(), parent_resv_obj: config.parent_resv_obj.map(|p| p.into()), inner <- T::new(dev, size, args), + _ctx: PhantomData::<C>, }), GFP_KERNEL, )?; @@ -148,9 +159,9 @@ impl<T: DriverObject> Object<T> { } /// Returns the `Device` that owns this GEM object. - pub fn dev(&self) -> &device::Device<T::Driver> { + pub fn dev(&self) -> &Device<T::Driver, C> { // SAFETY: `dev` will have been initialized in `Self::new()` by `drm_gem_shmem_init()`. - unsafe { device::Device::from_raw((*self.as_raw()).dev) } + unsafe { Device::from_raw((*self.as_raw()).dev) } } extern "C" fn free_callback(obj: *mut bindings::drm_gem_object) { @@ -168,7 +179,7 @@ impl<T: DriverObject> Object<T> { // SAFETY: // - We verified above that `obj` is valid, which makes `this` valid // - This function is set in AllocOps, so we know that `this` is contained within a - // `Object<T>` + // `Object<T, C>` let this = unsafe { container_of!(Opaque::cast_from(this), Self, obj) }.cast_mut(); // SAFETY: We're recovering the Kbox<> we created in gem_create_object() @@ -176,7 +187,7 @@ impl<T: DriverObject> Object<T> { } } -impl<T: DriverObject> Deref for Object<T> { +impl<T: DriverObject, C: DeviceContext> Deref for Object<T, C> { type Target = T; fn deref(&self) -> &Self::Target { @@ -184,15 +195,15 @@ impl<T: DriverObject> Deref for Object<T> { } } -impl<T: DriverObject> DerefMut for Object<T> { +impl<T: DriverObject, C: DeviceContext> DerefMut for Object<T, C> { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.inner } } -impl<T: DriverObject> Sealed for Object<T> {} +impl<T: DriverObject, C: DeviceContext> Sealed for Object<T, C> {} -impl<T: DriverObject> gem::IntoGEMObject for Object<T> { +impl<T: DriverObject, C: DeviceContext> gem::IntoGEMObject for Object<T, C> { fn as_raw(&self) -> *mut bindings::drm_gem_object { // SAFETY: // - Our immutable reference is proof that this is safe to dereference. @@ -200,18 +211,18 @@ impl<T: DriverObject> gem::IntoGEMObject for Object<T> { unsafe { &raw mut (*self.obj.get()).base } } - unsafe fn from_raw<'a>(obj: *mut bindings::drm_gem_object) -> &'a Object<T> { + unsafe fn from_raw<'a>(obj: *mut bindings::drm_gem_object) -> &'a Self { // SAFETY: The safety contract of from_gem_obj() guarantees that `obj` is contained within // `Self` unsafe { let obj = Opaque::cast_from(container_of!(obj, bindings::drm_gem_shmem_object, base)); - &*container_of!(obj, Object<T>, obj) + &*container_of!(obj, Self, obj) } } } -impl<T: DriverObject> driver::AllocImpl for Object<T> { +impl<T: DriverObject, C: DeviceContext> driver::AllocImpl for Object<T, C> { type Driver = T::Driver; const ALLOC_OPS: driver::AllocOps = driver::AllocOps { diff --git a/rust/kernel/drm/gpuvm/mod.rs b/rust/kernel/drm/gpuvm/mod.rs new file mode 100644 index 000000000000..ae58f6f667c1 --- /dev/null +++ b/rust/kernel/drm/gpuvm/mod.rs @@ -0,0 +1,328 @@ +// SPDX-License-Identifier: GPL-2.0 OR MIT + +#![cfg(CONFIG_RUST_DRM_GPUVM)] + +//! DRM GPUVM in immediate mode +//! +//! Rust abstractions for using GPUVM in immediate mode. This is when the GPUVM state is updated +//! during `run_job()`, i.e., in the DMA fence signalling critical path, to ensure that the GPUVM +//! and the GPU's virtual address space has the same state at all times. +//! +//! C header: [`include/drm/drm_gpuvm.h`](srctree/include/drm/drm_gpuvm.h) + +use kernel::{ + alloc::{ + AllocError, + Flags as AllocFlags, // + }, + bindings, + drm, + drm::gem::IntoGEMObject, + error::to_result, + prelude::*, + sync::aref::{ + ARef, + AlwaysRefCounted, // + }, + types::Opaque, // +}; + +use core::{ + cell::UnsafeCell, + marker::PhantomData, + mem::{ + ManuallyDrop, + MaybeUninit, // + }, + ops::{ + Deref, + DerefMut, + Range, // + }, + ptr::{ + self, + NonNull, // + }, // +}; + +mod sm_ops; +pub use self::sm_ops::*; + +mod vm_bo; +pub use self::vm_bo::*; + +mod va; +pub use self::va::*; + +/// A DRM GPU VA manager. +/// +/// This object is refcounted, but the locations of mapped ranges may only be accessed or changed +/// via the special unique handle [`UniqueRefGpuVm`]. +/// +/// # Invariants +/// +/// * Stored in an allocation managed by the refcount in `self.vm`. +/// * Access to `data` and the gpuvm interval tree is controlled via the [`UniqueRefGpuVm`] type. +/// * Does not contain any sparse [`GpuVa<T>`] instances. +#[pin_data] +pub struct GpuVm<T: DriverGpuVm> { + #[pin] + vm: Opaque<bindings::drm_gpuvm>, + /// Accessed only through the [`UniqueRefGpuVm`] reference. + data: UnsafeCell<T>, +} + +// SAFETY: The GPUVM api does not assume that it is tied to a specific thread. The destructor will +// drop the `data` field, which is okay because it is guaranteed `Send` by the `DriverGpuVm` trait. +unsafe impl<T: DriverGpuVm> Send for GpuVm<T> {} +// SAFETY: The GPUVM api is designed to allow &self methods to be called in parallel. +unsafe impl<T: DriverGpuVm> Sync for GpuVm<T> {} + +// SAFETY: By type invariants, the allocation is managed by the refcount in `self.vm`. +unsafe impl<T: DriverGpuVm> AlwaysRefCounted for GpuVm<T> { + fn inc_ref(&self) { + // SAFETY: By type invariants, the allocation is managed by the refcount in `self.vm`. + unsafe { bindings::drm_gpuvm_get(self.vm.get()) }; + } + + unsafe fn dec_ref(obj: NonNull<Self>) { + // SAFETY: By type invariants, the allocation is managed by the refcount in `self.vm`. + unsafe { bindings::drm_gpuvm_put((*obj.as_ptr()).vm.get()) }; + } +} + +impl<T: DriverGpuVm> PartialEq for GpuVm<T> { + #[inline] + fn eq(&self, other: &Self) -> bool { + core::ptr::eq(self.as_raw(), other.as_raw()) + } +} +impl<T: DriverGpuVm> Eq for GpuVm<T> {} + +impl<T: DriverGpuVm> GpuVm<T> { + const fn vtable() -> &'static bindings::drm_gpuvm_ops { + &bindings::drm_gpuvm_ops { + vm_free: Some(Self::vm_free), + op_alloc: None, + op_free: None, + vm_bo_alloc: GpuVmBo::<T>::ALLOC_FN, + vm_bo_free: GpuVmBo::<T>::FREE_FN, + vm_bo_validate: None, + sm_step_map: Some(Self::sm_step_map), + sm_step_unmap: Some(Self::sm_step_unmap), + sm_step_remap: Some(Self::sm_step_remap), + } + } + + /// Creates a GPUVM instance. + #[expect(clippy::new_ret_no_self)] + pub fn new<E>( + name: &'static CStr, + dev: &drm::Device<T::Driver>, + r_obj: &T::Object, + range: Range<u64>, + reserve_range: Range<u64>, + data: T, + ) -> Result<UniqueRefGpuVm<T>, E> + where + E: From<AllocError>, + E: From<core::convert::Infallible>, + { + let obj = KBox::try_pin_init::<E>( + try_pin_init!(Self { + data: UnsafeCell::new(data), + vm <- Opaque::ffi_init(|vm| { + // SAFETY: These arguments are valid. `vm` is valid until refcount drops to + // zero. The `vm` is zeroed before calling this method by `__GFP_ZERO` flag + // below. + unsafe { + bindings::drm_gpuvm_init( + vm, + name.as_char_ptr(), + bindings::drm_gpuvm_flags_DRM_GPUVM_IMMEDIATE_MODE + | bindings::drm_gpuvm_flags_DRM_GPUVM_RESV_PROTECTED, + dev.as_raw(), + r_obj.as_raw(), + range.start, + range.end - range.start, + reserve_range.start, + reserve_range.end - reserve_range.start, + const { Self::vtable() }, + ) + } + }), + }? E), + GFP_KERNEL | __GFP_ZERO, + )?; + // SAFETY: This transfers the initial refcount to the ARef. + let aref = unsafe { + ARef::from_raw(NonNull::new_unchecked(KBox::into_raw( + Pin::into_inner_unchecked(obj), + ))) + }; + // INVARIANT: This reference is unique. + Ok(UniqueRefGpuVm(aref)) + } + + /// Access this [`GpuVm`] from a raw pointer. + /// + /// # Safety + /// + /// The pointer must reference the `struct drm_gpuvm` in a valid [`GpuVm<T>`] that remains + /// valid for at least `'a`. + #[inline] + pub unsafe fn from_raw<'a>(ptr: *mut bindings::drm_gpuvm) -> &'a Self { + // SAFETY: Caller passes a pointer to the `drm_gpuvm` in a `GpuVm<T>`. Caller ensures the + // pointer is valid for 'a. + unsafe { &*kernel::container_of!(Opaque::cast_from(ptr), Self, vm) } + } + + /// Returns a raw pointer to the embedded `struct drm_gpuvm`. + #[inline] + pub fn as_raw(&self) -> *mut bindings::drm_gpuvm { + self.vm.get() + } + + /// The start of the VA space. + #[inline] + pub fn va_start(&self) -> u64 { + // SAFETY: The `mm_start` field is immutable. + unsafe { (*self.as_raw()).mm_start } + } + + /// The length of the GPU's virtual address space. + #[inline] + pub fn va_length(&self) -> u64 { + // SAFETY: The `mm_range` field is immutable. + unsafe { (*self.as_raw()).mm_range } + } + + /// Returns the range of the GPU virtual address space. + #[inline] + pub fn va_range(&self) -> Range<u64> { + let start = self.va_start(); + // OVERFLOW: This reconstructs the Range<u64> passed to the constructor, so it won't fail. + let end = start + self.va_length(); + Range { start, end } + } + + /// Get or create the [`GpuVmBo`] for this gem object. + #[inline] + pub fn obtain( + &self, + obj: &T::Object, + data: impl PinInit<T::VmBoData>, + ) -> Result<ARef<GpuVmBo<T>>, AllocError> { + Ok(GpuVmBoAlloc::new(self, obj, data)?.obtain()) + } + + /// Clean up buffer objects that are no longer used. + #[inline] + pub fn deferred_cleanup(&self) { + // SAFETY: This GPUVM uses immediate mode. + unsafe { bindings::drm_gpuvm_bo_deferred_cleanup(self.as_raw()) } + } + + /// Check if this GEM object is an external object for this GPUVM. + #[inline] + pub fn is_extobj(&self, obj: &T::Object) -> bool { + // SAFETY: We may call this with any GPUVM and GEM object. + unsafe { bindings::drm_gpuvm_is_extobj(self.as_raw(), obj.as_raw()) } + } + + /// Free this GPUVM. + /// + /// # Safety + /// + /// Called when refcount hits zero. + unsafe extern "C" fn vm_free(me: *mut bindings::drm_gpuvm) { + // SAFETY: Caller passes a pointer to the `drm_gpuvm` in a `GpuVm<T>`. + let me = unsafe { kernel::container_of!(Opaque::cast_from(me), Self, vm).cast_mut() }; + // SAFETY: By type invariants we can free it when refcount hits zero. + drop(unsafe { KBox::from_raw(me) }) + } + + #[inline] + fn raw_resv(&self) -> *mut bindings::dma_resv { + // SAFETY: `r_obj` is immutable and valid for duration of GPUVM. + unsafe { (*(*self.as_raw()).r_obj).resv } + } +} + +/// The manager for a GPUVM. +pub trait DriverGpuVm: Sized + Send { + /// Parent `Driver` for this object. + type Driver: drm::Driver<Object = Self::Object>; + + /// The kind of GEM object stored in this GPUVM. + type Object: IntoGEMObject; + + /// Data stored with each [`struct drm_gpuva`](struct@GpuVa). + type VaData; + + /// Data stored with each [`struct drm_gpuvm_bo`](struct@GpuVmBo). + type VmBoData; + + /// The private data passed to callbacks. + type SmContext<'ctx>; + + /// Indicates that a new mapping should be created. + fn sm_step_map<'op, 'ctx>( + &mut self, + op: OpMap<'op, Self>, + context: &mut Self::SmContext<'ctx>, + ) -> Result<OpMapped<'op, Self>, Error>; + + /// Indicates that an existing mapping should be removed. + fn sm_step_unmap<'op, 'ctx>( + &mut self, + op: OpUnmap<'op, Self>, + context: &mut Self::SmContext<'ctx>, + ) -> Result<OpUnmapped<'op, Self>, Error>; + + /// Indicates that an existing mapping should be split up. + fn sm_step_remap<'op, 'ctx>( + &mut self, + op: OpRemap<'op, Self>, + context: &mut Self::SmContext<'ctx>, + ) -> Result<OpRemapped<'op, Self>, Error>; +} + +/// The core of the DRM GPU VA manager. +/// +/// This object is a unique reference to the VM that can access the interval tree and the Rust +/// `data` field. +/// +/// # Invariants +/// +/// Each `GpuVm` instance has at most one `UniqueRefGpuVm` reference. +pub struct UniqueRefGpuVm<T: DriverGpuVm>(ARef<GpuVm<T>>); + +// SAFETY: The GPUVM api is designed to allow &self methods to be called in parallel, and +// concurrent access to `data` is safe due to the `T: Sync` requirement. +unsafe impl<T: DriverGpuVm + Sync> Sync for UniqueRefGpuVm<T> {} + +impl<T: DriverGpuVm> UniqueRefGpuVm<T> { + /// Access the data owned by this `UniqueRefGpuVm` immutably. + #[inline] + pub fn data_ref(&self) -> &T { + // SAFETY: By the type invariants we may access `data`. + unsafe { &*self.0.data.get() } + } + + /// Access the data owned by this `UniqueRefGpuVm` mutably. + #[inline] + pub fn data(&mut self) -> &mut T { + // SAFETY: By the type invariants we may access `data`. + unsafe { &mut *self.0.data.get() } + } +} + +impl<T: DriverGpuVm> Deref for UniqueRefGpuVm<T> { + type Target = GpuVm<T>; + + #[inline] + fn deref(&self) -> &GpuVm<T> { + &self.0 + } +} diff --git a/rust/kernel/drm/gpuvm/sm_ops.rs b/rust/kernel/drm/gpuvm/sm_ops.rs new file mode 100644 index 000000000000..69a8e5ab2821 --- /dev/null +++ b/rust/kernel/drm/gpuvm/sm_ops.rs @@ -0,0 +1,429 @@ +// SPDX-License-Identifier: GPL-2.0 OR MIT + +use super::*; + +/// The actual data that gets threaded through the callbacks. +struct SmData<'a, 'ctx, T: DriverGpuVm> { + gpuvm: &'a mut UniqueRefGpuVm<T>, + user_context: &'a mut T::SmContext<'ctx>, +} + +/// Adds an extra field to `SmData` for `sm_map()` callbacks. +/// +/// # Invariants +/// +/// `self.vm_bo.gpuvm() == self.sm_data.gpuvm`. +#[repr(C)] +struct SmMapData<'a, 'ctx, T: DriverGpuVm> { + sm_data: SmData<'a, 'ctx, T>, + vm_bo: &'a GpuVmBo<T>, +} + +/// The argument for [`UniqueRefGpuVm::sm_map`]. +pub struct OpMapRequest<'a, 'ctx, T: DriverGpuVm> { + /// Address in GPU virtual address space. + pub addr: u64, + /// Length of mapping to create. + pub range: u64, + /// Offset in GEM object. + pub gem_offset: u64, + /// The GEM object to map. + pub vm_bo: &'a GpuVmBo<T>, + /// The user-provided context type. + pub context: &'a mut T::SmContext<'ctx>, +} + +impl<'a, 'ctx, T: DriverGpuVm> OpMapRequest<'a, 'ctx, T> { + fn raw_request(&self) -> bindings::drm_gpuvm_map_req { + bindings::drm_gpuvm_map_req { + map: bindings::drm_gpuva_op_map { + va: bindings::drm_gpuva_op_map__bindgen_ty_1 { + addr: self.addr, + range: self.range, + }, + gem: bindings::drm_gpuva_op_map__bindgen_ty_2 { + offset: self.gem_offset, + obj: self.vm_bo.obj().as_raw(), + }, + }, + } + } +} + +/// Represents an `sm_step_map` operation that has not yet been completed. +pub struct OpMap<'op, T: DriverGpuVm> { + op: &'op bindings::drm_gpuva_op_map, + // Since these abstractions are designed for immediate mode, the VM BO needs to be + // pre-allocated, so we always have it available when we reach this point. + vm_bo: &'op GpuVmBo<T>, + // This ensures that 'op is invariant, so that `OpMap<'long, T>` does not + // coerce to `OpMap<'short, T>`. This ensures that the user can't return + // the wrong `OpMapped` value. + _invariant: PhantomData<*mut &'op mut T>, +} + +impl<'op, T: DriverGpuVm> OpMap<'op, T> { + /// The base address of the new mapping. + pub fn addr(&self) -> u64 { + self.op.va.addr + } + + /// The length of the new mapping. + pub fn length(&self) -> u64 { + self.op.va.range + } + + /// The offset within the [`drm_gem_object`](DriverGpuVm::Object). + pub fn gem_offset(&self) -> u64 { + self.op.gem.offset + } + + /// The [`drm_gem_object`](DriverGpuVm::Object) to map. + pub fn obj(&self) -> &T::Object { + // SAFETY: The `obj` pointer is guaranteed to be valid. + unsafe { <T::Object as IntoGEMObject>::from_raw(self.op.gem.obj) } + } + + /// The [`GpuVmBo`] that the new VA will be associated with. + pub fn vm_bo(&self) -> &GpuVmBo<T> { + self.vm_bo + } + + /// Use the pre-allocated VA to carry out this map operation. + pub fn insert(self, va: GpuVaAlloc<T>, va_data: impl PinInit<T::VaData>) -> OpMapped<'op, T> { + let va = va.prepare(va_data); + // SAFETY: By the type invariants we may access the interval tree. + unsafe { bindings::drm_gpuva_map(self.vm_bo.gpuvm().as_raw(), va, self.op) }; + + let _gpuva_guard = self.vm_bo().lock_gpuva(); + // SAFETY: The va is prepared for insertion, and we hold the GEM lock. + unsafe { bindings::drm_gpuva_link(va, self.vm_bo.as_raw()) }; + + OpMapped { + _invariant: self._invariant, + } + } +} + +/// Represents a completed [`OpMap`] operation. +pub struct OpMapped<'op, T> { + _invariant: PhantomData<*mut &'op mut T>, +} + +/// Represents an `sm_step_unmap` operation that has not yet been completed. +pub struct OpUnmap<'op, T: DriverGpuVm> { + op: &'op bindings::drm_gpuva_op_unmap, + // This ensures that 'op is invariant, so that `OpUnmap<'long, T>` does not + // coerce to `OpUnmap<'short, T>`. This ensures that the user can't return the + // wrong`OpUnmapped` value. + _invariant: PhantomData<*mut &'op mut T>, +} + +impl<'op, T: DriverGpuVm> OpUnmap<'op, T> { + /// Indicates whether this [`GpuVa`] is physically contiguous with the + /// original mapping request. + /// + /// Optionally, if `keep` is set, drivers may keep the actual page table + /// mappings for this `drm_gpuva`, adding the missing page table entries + /// only and update the `drm_gpuvm` accordingly. + pub fn keep(&self) -> bool { + self.op.keep + } + + /// The range being unmapped. + pub fn va(&self) -> &GpuVa<T> { + // SAFETY: This is a valid va. It's not the `kernel_alloc_node` because you can't unmap it, + // and it's not sparse by the `GpuVm<T>` type invariants. + unsafe { GpuVa::<T>::from_raw(self.op.va) } + } + + /// Remove the VA. + pub fn remove(self) -> (OpUnmapped<'op, T>, GpuVaRemoved<T>) { + // SAFETY: The op references a valid drm_gpuva in the GPUVM. + unsafe { bindings::drm_gpuva_unmap(self.op) }; + // SAFETY: The va is no longer in the interval tree so we may unlink it. + unsafe { bindings::drm_gpuva_unlink_defer(self.op.va) }; + + // SAFETY: We just removed this va from the `GpuVm<T>`. + let va = unsafe { GpuVaRemoved::from_raw(self.op.va) }; + + ( + OpUnmapped { + _invariant: self._invariant, + }, + va, + ) + } +} + +/// Represents a completed [`OpUnmap`] operation. +pub struct OpUnmapped<'op, T> { + _invariant: PhantomData<*mut &'op mut T>, +} + +/// Represents an `sm_step_remap` operation that has not yet been completed. +pub struct OpRemap<'op, T: DriverGpuVm> { + op: &'op bindings::drm_gpuva_op_remap, + // This ensures that 'op is invariant, so that `OpRemap<'long, T>` does not + // coerce to `OpRemap<'short, T>`. This ensures that the user can't return the + // wrong`OpRemapped` value. + _invariant: PhantomData<*mut &'op mut T>, +} + +impl<'op, T: DriverGpuVm> OpRemap<'op, T> { + /// The preceding part of a split mapping. + #[inline] + pub fn prev(&self) -> Option<&OpRemapMapData> { + // SAFETY: We checked for null, so the pointer must be valid. + NonNull::new(self.op.prev).map(|ptr| unsafe { OpRemapMapData::from_raw(ptr) }) + } + + /// The subsequent part of a split mapping. + #[inline] + pub fn next(&self) -> Option<&OpRemapMapData> { + // SAFETY: We checked for null, so the pointer must be valid. + NonNull::new(self.op.next).map(|ptr| unsafe { OpRemapMapData::from_raw(ptr) }) + } + + /// Indicates whether the `drm_gpuva` being removed is physically contiguous with the original + /// mapping request. + /// + /// Optionally, if `keep` is set, drivers may keep the actual page table mappings for this + /// `drm_gpuva`, adding the missing page table entries only and update the `drm_gpuvm` + /// accordingly. + #[inline] + pub fn keep(&self) -> bool { + // SAFETY: The unmap pointer is always valid. + unsafe { (*self.op.unmap).keep } + } + + /// The range being unmapped. + #[inline] + pub fn va_to_unmap(&self) -> &GpuVa<T> { + // SAFETY: This is a valid va. It's not the `kernel_alloc_node` because you can't unmap it, + // and it's not sparse by the `GpuVm<T>` type invariants. + unsafe { GpuVa::<T>::from_raw((*self.op.unmap).va) } + } + + /// The [`drm_gem_object`](DriverGpuVm::Object) whose VA is being remapped. + #[inline] + pub fn obj(&self) -> &T::Object { + self.va_to_unmap().obj() + } + + /// The [`GpuVmBo`] that is being remapped. + #[inline] + pub fn vm_bo(&self) -> &GpuVmBo<T> { + self.va_to_unmap().vm_bo() + } + + /// Update the GPUVM to perform the remapping. + pub fn remap( + self, + va_alloc: [GpuVaAlloc<T>; 2], + prev_data: impl PinInit<T::VaData>, + next_data: impl PinInit<T::VaData>, + ) -> (OpRemapped<'op, T>, OpRemapRet<T>) { + let [va1, va2] = va_alloc; + + let mut unused_va = None; + let mut prev_ptr = ptr::null_mut(); + let mut next_ptr = ptr::null_mut(); + if self.prev().is_some() { + prev_ptr = va1.prepare(prev_data); + } else { + unused_va = Some(va1); + } + if self.next().is_some() { + next_ptr = va2.prepare(next_data); + } else { + unused_va = Some(va2); + } + + // SAFETY: the pointers are non-null when required + unsafe { bindings::drm_gpuva_remap(prev_ptr, next_ptr, self.op) }; + + let gpuva_guard = self.vm_bo().lock_gpuva(); + if !prev_ptr.is_null() { + // SAFETY: The prev_ptr is a valid drm_gpuva prepared for insertion. The vm_bo is still + // valid as the not-yet-unlinked gpuva holds a refcount on the vm_bo. + unsafe { bindings::drm_gpuva_link(prev_ptr, self.vm_bo().as_raw()) }; + } + if !next_ptr.is_null() { + // SAFETY: The next_ptr is a valid drm_gpuva prepared for insertion. The vm_bo is still + // valid as the not-yet-unlinked gpuva holds a refcount on the vm_bo. + unsafe { bindings::drm_gpuva_link(next_ptr, self.vm_bo().as_raw()) }; + } + drop(gpuva_guard); + + // SAFETY: The va is no longer in the interval tree so we may unlink it. + unsafe { bindings::drm_gpuva_unlink_defer((*self.op.unmap).va) }; + + ( + OpRemapped { + _invariant: self._invariant, + }, + OpRemapRet { + // SAFETY: We just removed this va from the `GpuVm<T>`. + unmapped_va: unsafe { GpuVaRemoved::from_raw((*self.op.unmap).va) }, + unused_va, + }, + ) + } +} + +/// Part of an [`OpRemap`] that represents a new mapping. +#[repr(transparent)] +pub struct OpRemapMapData(bindings::drm_gpuva_op_map); + +impl OpRemapMapData { + /// # Safety + /// Must reference a valid `drm_gpuva_op_map` for duration of `'a`. + unsafe fn from_raw<'a>(ptr: NonNull<bindings::drm_gpuva_op_map>) -> &'a Self { + // SAFETY: ok per safety requirements + unsafe { ptr.cast().as_ref() } + } + + /// The base address of the new mapping. + pub fn addr(&self) -> u64 { + self.0.va.addr + } + + /// The length of the new mapping. + pub fn length(&self) -> u64 { + self.0.va.range + } + + /// The offset within the [`drm_gem_object`](DriverGpuVm::Object). + pub fn gem_offset(&self) -> u64 { + self.0.gem.offset + } +} + +/// Struct containing objects removed or not used by [`OpRemap::remap`]. +pub struct OpRemapRet<T: DriverGpuVm> { + /// The `drm_gpuva` that was removed. + pub unmapped_va: GpuVaRemoved<T>, + /// If the remap did not split the region into two pieces, then the unused `drm_gpuva` is + /// returned here. + pub unused_va: Option<GpuVaAlloc<T>>, +} + +/// Represents a completed [`OpRemap`] operation. +pub struct OpRemapped<'op, T> { + _invariant: PhantomData<*mut &'op mut T>, +} + +impl<T: DriverGpuVm> UniqueRefGpuVm<T> { + /// Create a mapping, removing or remapping anything that overlaps. + /// + /// Internally calls the [`DriverGpuVm`] callbacks similar to [`Self::sm_unmap`], except that + /// the [`DriverGpuVm::sm_step_map`] is called once to create the requested mapping. + #[inline] + pub fn sm_map(&mut self, req: OpMapRequest<'_, '_, T>) -> Result { + if req.vm_bo.gpuvm() != &**self { + return Err(EINVAL); + } + + let gpuvm = self.as_raw(); + let raw_req = req.raw_request(); + // INVARIANT: Checked above that `vm_bo.gpuvm() == self`. + let mut p = SmMapData { + sm_data: SmData { + gpuvm: self, + user_context: req.context, + }, + vm_bo: req.vm_bo, + }; + // SAFETY: + // * raw_request() creates a valid request. + // * The private data is valid to be interpreted as both SmData and SmMapData since the + // first field of SmMapData is SmData. + to_result(unsafe { + bindings::drm_gpuvm_sm_map(gpuvm, (&raw mut p).cast(), &raw const raw_req) + }) + } + + /// Remove any mappings in the given region. + /// + /// Internally calls [`DriverGpuVm::sm_step_unmap`] for ranges entirely contained within the + /// given range, and [`DriverGpuVm::sm_step_remap`] for ranges that overlap with the range. + #[inline] + pub fn sm_unmap(&mut self, addr: u64, length: u64, context: &mut T::SmContext<'_>) -> Result { + let gpuvm = self.as_raw(); + let mut p = SmData { + gpuvm: self, + user_context: context, + }; + // SAFETY: + // * raw_request() creates a valid request. + // * The private data is a valid SmData. + to_result(unsafe { bindings::drm_gpuvm_sm_unmap(gpuvm, (&raw mut p).cast(), addr, length) }) + } +} + +impl<T: DriverGpuVm> GpuVm<T> { + /// # Safety + /// Must be called from `sm_map` with a pointer to `SmMapData`. + pub(super) unsafe extern "C" fn sm_step_map( + op: *mut bindings::drm_gpuva_op, + p: *mut c_void, + ) -> c_int { + // SAFETY: If we reach `sm_step_map` then we were called from `sm_map` which always passes + // an `SmMapData` as private data. + let p = unsafe { &mut *p.cast::<SmMapData<'_, '_, T>>() }; + let op = OpMap { + // SAFETY: sm_step_map is called with a map operation. + op: unsafe { &(*op).__bindgen_anon_1.map }, + vm_bo: p.vm_bo, + _invariant: PhantomData, + }; + match p + .sm_data + .gpuvm + .data() + .sm_step_map(op, p.sm_data.user_context) + { + Ok(OpMapped { .. }) => 0, + Err(err) => err.to_errno(), + } + } + + /// # Safety + /// Must be called from `sm_map` or `sm_unmap` with a pointer to `SmMapData` or `SmData`. + pub(super) unsafe extern "C" fn sm_step_unmap( + op: *mut bindings::drm_gpuva_op, + p: *mut c_void, + ) -> c_int { + // SAFETY: The caller provides a pointer that can be treated as `SmData`. + let p = unsafe { &mut *p.cast::<SmData<'_, '_, T>>() }; + let op = OpUnmap { + // SAFETY: sm_step_unmap is called with an unmap operation. + op: unsafe { &(*op).__bindgen_anon_1.unmap }, + _invariant: PhantomData, + }; + match p.gpuvm.data().sm_step_unmap(op, p.user_context) { + Ok(OpUnmapped { .. }) => 0, + Err(err) => err.to_errno(), + } + } + + /// # Safety + /// Must be called from `sm_map` or `sm_unmap` with a pointer to `SmMapData` or `SmData`. + pub(super) unsafe extern "C" fn sm_step_remap( + op: *mut bindings::drm_gpuva_op, + p: *mut c_void, + ) -> c_int { + // SAFETY: The caller provides a pointer that can be treated as `SmData`. + let p = unsafe { &mut *p.cast::<SmData<'_, '_, T>>() }; + let op = OpRemap { + // SAFETY: sm_step_remap is called with a remap operation. + op: unsafe { &(*op).__bindgen_anon_1.remap }, + _invariant: PhantomData, + }; + match p.gpuvm.data().sm_step_remap(op, p.user_context) { + Ok(OpRemapped { .. }) => 0, + Err(err) => err.to_errno(), + } + } +} diff --git a/rust/kernel/drm/gpuvm/va.rs b/rust/kernel/drm/gpuvm/va.rs new file mode 100644 index 000000000000..0b09fe44ab39 --- /dev/null +++ b/rust/kernel/drm/gpuvm/va.rs @@ -0,0 +1,168 @@ +// SPDX-License-Identifier: GPL-2.0 OR MIT + +use super::*; + +/// Represents that a range of a GEM object is mapped in this [`GpuVm`] instance. +/// +/// Does not assume that GEM lock is held. +/// +/// # Invariants +/// +/// * This is a valid `drm_gpuva` object that is resident in a [`GpuVm<T>`] instance. +/// * It is associated with a [`GpuVmBo<T>`]. Or in other words, it's not an +/// `gpuvm->kernel_alloc_node` and `DRM_GPUVA_SPARSE` is not set. +/// * The associated [`GpuVmBo<T>`] is part of the GEM list. +#[repr(C)] +#[pin_data] +pub struct GpuVa<T: DriverGpuVm> { + #[pin] + inner: Opaque<bindings::drm_gpuva>, + #[pin] + data: T::VaData, +} + +impl<T: DriverGpuVm> PartialEq for GpuVa<T> { + #[inline] + fn eq(&self, other: &Self) -> bool { + core::ptr::eq(self.as_raw(), other.as_raw()) + } +} +impl<T: DriverGpuVm> Eq for GpuVa<T> {} + +impl<T: DriverGpuVm> GpuVa<T> { + /// Access this [`GpuVa`] from a raw pointer. + /// + /// # Safety + /// + /// * For the duration of `'a`, the pointer must reference a valid `drm_gpuva` associated with + /// a [`GpuVm<T>`]. + /// * It must be associated with a [`GpuVmBo<T>`]. + /// * The associated [`GpuVmBo<T>`] is part of the GEM list. + #[inline] + pub unsafe fn from_raw<'a>(ptr: *mut bindings::drm_gpuva) -> &'a Self { + // CAST: `drm_gpuva` is first field and `repr(C)`. + // SAFETY: The safety requirements match the invariants of `GpuVa`. + unsafe { &*ptr.cast() } + } + + /// Returns a raw pointer to underlying C value. + #[inline] + pub fn as_raw(&self) -> *mut bindings::drm_gpuva { + self.inner.get() + } + + /// Returns the address of this mapping in the GPU virtual address space. + #[inline] + pub fn addr(&self) -> u64 { + // SAFETY: The `va.addr` field of `drm_gpuva` is immutable. + unsafe { (*self.as_raw()).va.addr } + } + + /// Returns the length of this mapping. + #[inline] + pub fn length(&self) -> u64 { + // SAFETY: The `va.range` field of `drm_gpuva` is immutable. + unsafe { (*self.as_raw()).va.range } + } + + /// Returns `addr..addr+length`. + #[inline] + pub fn range(&self) -> Range<u64> { + let addr = self.addr(); + addr..addr + self.length() + } + + /// Returns the offset within the GEM object. + #[inline] + pub fn gem_offset(&self) -> u64 { + // SAFETY: The `gem.offset` field of `drm_gpuva` is immutable. + unsafe { (*self.as_raw()).gem.offset } + } + + /// Returns the GEM object. + #[inline] + pub fn obj(&self) -> &T::Object { + // SAFETY: The `gem.obj` field of `drm_gpuva` is immutable. We know that it's not null + // because this VA is associated with a `GpuVmBo<T>`. + unsafe { <T::Object as IntoGEMObject>::from_raw((*self.as_raw()).gem.obj) } + } + + /// Returns the underlying [`GpuVmBo`] object that backs this [`GpuVa`]. + #[inline] + pub fn vm_bo(&self) -> &GpuVmBo<T> { + // SAFETY: The `vm_bo` field of `drm_gpuva` is immutable. We know that it's not null + // because this VA is associated with a `GpuVmBo<T>`. The BO is in the GEM list by the type + // invariants. + unsafe { GpuVmBo::from_raw((*self.as_raw()).vm_bo) } + } +} + +/// A pre-allocated [`GpuVa`] object. +/// +/// # Invariants +/// +/// The memory is zeroed. +pub struct GpuVaAlloc<T: DriverGpuVm>(KBox<MaybeUninit<GpuVa<T>>>); + +impl<T: DriverGpuVm> GpuVaAlloc<T> { + /// Pre-allocate a [`GpuVa`] object. + pub fn new(flags: AllocFlags) -> Result<GpuVaAlloc<T>, AllocError> { + // INVARIANTS: Memory allocated with __GFP_ZERO. + Ok(GpuVaAlloc(KBox::new_uninit(flags | __GFP_ZERO)?)) + } + + /// Prepare this `drm_gpuva` for insertion into the GPUVM. + #[must_use] + pub(super) fn prepare(mut self, va_data: impl PinInit<T::VaData>) -> *mut bindings::drm_gpuva { + let va_ptr = MaybeUninit::as_mut_ptr(&mut self.0); + // SAFETY: The `data` field is pinned. + let Ok(()) = unsafe { va_data.__pinned_init(&raw mut (*va_ptr).data) }; + KBox::into_raw(self.0).cast() + } +} + +/// A [`GpuVa`] object that has been removed. +/// +/// # Invariants +/// +/// The `drm_gpuva` is not resident in the [`GpuVm`]. +pub struct GpuVaRemoved<T: DriverGpuVm>(KBox<GpuVa<T>>); + +impl<T: DriverGpuVm> GpuVaRemoved<T> { + /// Convert a raw pointer into a [`GpuVaRemoved`]. + /// + /// # Safety + /// + /// * Must have been removed from a [`GpuVm<T>`]. + /// * It must not be a `gpuvm->kernel_alloc_node` va. + pub(super) unsafe fn from_raw(ptr: *mut bindings::drm_gpuva) -> Self { + // SAFETY: Since it used to be a VA in a `GpuVm<T>` and it's not a kernel_alloc_node, this + // pointer references a `GpuVa<T>` with a valid `T::VaData`. Since it has been removed, we + // can take ownership of the allocation. + GpuVaRemoved(unsafe { KBox::from_raw(ptr.cast()) }) + } + + /// Take ownership of the VA data. + pub fn into_inner(self) -> T::VaData + where + T::VaData: Unpin, + { + KBox::into_inner(self.0).data + } +} + +impl<T: DriverGpuVm> Deref for GpuVaRemoved<T> { + type Target = T::VaData; + fn deref(&self) -> &T::VaData { + &self.0.data + } +} + +impl<T: DriverGpuVm> DerefMut for GpuVaRemoved<T> +where + T::VaData: Unpin, +{ + fn deref_mut(&mut self) -> &mut T::VaData { + &mut self.0.data + } +} diff --git a/rust/kernel/drm/gpuvm/vm_bo.rs b/rust/kernel/drm/gpuvm/vm_bo.rs new file mode 100644 index 000000000000..c064ac63897b --- /dev/null +++ b/rust/kernel/drm/gpuvm/vm_bo.rs @@ -0,0 +1,249 @@ +// SPDX-License-Identifier: GPL-2.0 OR MIT + +use super::*; + +/// Represents that a given GEM object has at least one mapping on this [`GpuVm`] instance. +/// +/// Does not assume that GEM lock is held. +/// +/// # Invariants +/// +/// * Allocated with `kmalloc` and refcounted via `inner`. +/// * Is present in the gem list. +#[repr(C)] +#[pin_data] +pub struct GpuVmBo<T: DriverGpuVm> { + #[pin] + inner: Opaque<bindings::drm_gpuvm_bo>, + #[pin] + data: T::VmBoData, +} + +// SAFETY: By type invariants, the allocation is managed by the refcount in `self.inner`. +unsafe impl<T: DriverGpuVm> AlwaysRefCounted for GpuVmBo<T> { + fn inc_ref(&self) { + // SAFETY: By type invariants, the allocation is managed by the refcount in `self.inner`. + unsafe { bindings::drm_gpuvm_bo_get(self.inner.get()) }; + } + + unsafe fn dec_ref(obj: NonNull<Self>) { + // CAST: `drm_gpuvm_bo` is first field of repr(C) struct. + // SAFETY: By type invariants, the allocation is managed by the refcount in `self.inner`. + // This GPUVM instance uses immediate mode, so we may put the refcount using the deferred + // mechanism. + unsafe { bindings::drm_gpuvm_bo_put_deferred(obj.as_ptr().cast()) }; + } +} + +impl<T: DriverGpuVm> PartialEq for GpuVmBo<T> { + #[inline] + fn eq(&self, other: &Self) -> bool { + core::ptr::eq(self.as_raw(), other.as_raw()) + } +} +impl<T: DriverGpuVm> Eq for GpuVmBo<T> {} + +impl<T: DriverGpuVm> GpuVmBo<T> { + /// The function pointer for allocating a GpuVmBo stored in the gpuvm vtable. + /// + /// Allocation is always implemented according to [`Self::vm_bo_alloc`], but it is set to + /// `None` if the default gpuvm behavior is the same as `vm_bo_alloc`. + /// + /// This may be `Some` even if `FREE_FN` is `None`, or vice-versa. + pub(super) const ALLOC_FN: Option<unsafe extern "C" fn() -> *mut bindings::drm_gpuvm_bo> = { + use core::alloc::Layout; + let base = Layout::new::<bindings::drm_gpuvm_bo>(); + let rust = Layout::new::<Self>(); + assert!(base.size() <= rust.size()); + if base.size() != rust.size() || base.align() != rust.align() { + Some(Self::vm_bo_alloc) + } else { + // This causes GPUVM to allocate a `GpuVmBo<T>` with `kzalloc(sizeof(drm_gpuvm_bo))`. + None + } + }; + + /// The function pointer for freeing a GpuVmBo stored in the gpuvm vtable. + /// + /// Freeing is always implemented according to [`Self::vm_bo_free`], but it is set to `None` if + /// the default gpuvm behavior is the same as `vm_bo_free`. + /// + /// This may be `Some` even if `ALLOC_FN` is `None`, or vice-versa. + pub(super) const FREE_FN: Option<unsafe extern "C" fn(*mut bindings::drm_gpuvm_bo)> = { + if core::mem::needs_drop::<Self>() { + Some(Self::vm_bo_free) + } else { + // This causes GPUVM to free a `GpuVmBo<T>` with `kfree`. + None + } + }; + + /// Custom function for allocating a `drm_gpuvm_bo`. + /// + /// # Safety + /// + /// Always safe to call. + unsafe extern "C" fn vm_bo_alloc() -> *mut bindings::drm_gpuvm_bo { + let raw_ptr = KBox::<Self>::new_uninit(GFP_KERNEL | __GFP_ZERO) + .map(KBox::into_raw) + .unwrap_or(ptr::null_mut()); + + // CAST: `drm_gpuvm_bo` is first field of `Self`. + raw_ptr.cast() + } + + /// Custom function for freeing a `drm_gpuvm_bo`. + /// + /// # Safety + /// + /// The pointer must have been allocated with [`GpuVmBo::ALLOC_FN`], and must not be used after + /// this call. + unsafe extern "C" fn vm_bo_free(ptr: *mut bindings::drm_gpuvm_bo) { + // CAST: `drm_gpuvm_bo` is first field of `Self`. + // SAFETY: + // * The ptr was allocated from kmalloc with the layout of `GpuVmBo<T>`. + // * `ptr->inner` has no destructor. + // * `ptr->data` contains a valid `T::VmBoData` that we can drop. + drop(unsafe { KBox::<Self>::from_raw(ptr.cast()) }); + } + + /// Access this [`GpuVmBo`] from a raw pointer. + /// + /// # Safety + /// + /// For the duration of `'a`, the pointer must reference a valid `drm_gpuvm_bo` associated with + /// a [`GpuVm<T>`]. The BO must also be present in the GEM list. + #[inline] + pub(crate) unsafe fn from_raw<'a>(ptr: *mut bindings::drm_gpuvm_bo) -> &'a Self { + // SAFETY: `drm_gpuvm_bo` is first field and `repr(C)`. + unsafe { &*ptr.cast() } + } + + /// Returns a raw pointer to underlying C value. + #[inline] + pub fn as_raw(&self) -> *mut bindings::drm_gpuvm_bo { + self.inner.get() + } + + /// The [`GpuVm`] that this GEM object is mapped in. + #[inline] + pub fn gpuvm(&self) -> &GpuVm<T> { + // SAFETY: The `obj` pointer is guaranteed to be valid. + unsafe { GpuVm::<T>::from_raw((*self.inner.get()).vm) } + } + + /// The [`drm_gem_object`](DriverGpuVm::Object) for these mappings. + #[inline] + pub fn obj(&self) -> &T::Object { + // SAFETY: The `obj` pointer is guaranteed to be valid. + unsafe { <T::Object as IntoGEMObject>::from_raw((*self.inner.get()).obj) } + } + + /// The driver data with this buffer object. + #[inline] + pub fn data(&self) -> &T::VmBoData { + &self.data + } + + pub(super) fn lock_gpuva(&self) -> crate::sync::MutexGuard<'_, ()> { + // SAFETY: The GEM object is valid. + let ptr = unsafe { &raw mut (*self.obj().as_raw()).gpuva.lock }; + // SAFETY: The GEM object is valid, so the mutex is properly initialized. + let mutex = unsafe { crate::sync::Mutex::from_raw(ptr) }; + mutex.lock() + } +} + +/// A pre-allocated [`GpuVmBo`] object. +/// +/// # Invariants +/// +/// Points at a `drm_gpuvm_bo` that contains a valid `T::VmBoData`, has a refcount of one, and is +/// absent from any gem, extobj, or evict lists. +pub(super) struct GpuVmBoAlloc<T: DriverGpuVm>(NonNull<GpuVmBo<T>>); + +impl<T: DriverGpuVm> GpuVmBoAlloc<T> { + /// Create a new pre-allocated [`GpuVmBo`]. + /// + /// It's intentional that the initializer is infallible because `drm_gpuvm_bo_put` will call + /// drop on the data, so we don't have a way to free it when the data is missing. + #[inline] + pub(super) fn new( + gpuvm: &GpuVm<T>, + gem: &T::Object, + value: impl PinInit<T::VmBoData>, + ) -> Result<GpuVmBoAlloc<T>, AllocError> { + // CAST: `GpuVmBoAlloc::vm_bo_alloc` ensures that this memory was allocated with the layout + // of `GpuVmBo<T>`. The type is repr(C), so `container_of` is not required. + // SAFETY: The provided gpuvm and gem ptrs are valid for the duration of this call. + let raw_ptr = unsafe { + bindings::drm_gpuvm_bo_create(gpuvm.as_raw(), gem.as_raw()).cast::<GpuVmBo<T>>() + }; + let ptr = NonNull::new(raw_ptr).ok_or(AllocError)?; + // SAFETY: `ptr->data` is a valid pinned location. + let Ok(()) = unsafe { value.__pinned_init(&raw mut (*raw_ptr).data) }; + // INVARIANTS: We just created the vm_bo so it's absent from lists, and the data is valid + // as we just initialized it. + Ok(GpuVmBoAlloc(ptr)) + } + + /// Returns a raw pointer to underlying C value. + #[inline] + pub(super) fn as_raw(&self) -> *mut bindings::drm_gpuvm_bo { + // SAFETY: The pointer references a valid `drm_gpuvm_bo`. + unsafe { (*self.0.as_ptr()).inner.get() } + } + + /// Look up whether there is an existing [`GpuVmBo`] for this gem object. + /// + /// The caller should not hold the GEM mutex or DMA resv lock. + #[inline] + pub(super) fn obtain(self) -> ARef<GpuVmBo<T>> { + let me = ManuallyDrop::new(self); + // SAFETY: Valid `drm_gpuvm_bo` not already in the lists. We do not access `me` after this + // call. + let ptr = unsafe { bindings::drm_gpuvm_bo_obtain_prealloc(me.as_raw()) }; + + // SAFETY: `drm_gpuvm_bo_obtain_prealloc` always returns a non-null ptr + let nonnull = unsafe { NonNull::new_unchecked(ptr.cast()) }; + + // INVARIANTS: `drm_gpuvm_bo_obtain_prealloc` ensures that the bo is in the GEM list. + // SAFETY: We received one refcount from `drm_gpuvm_bo_obtain_prealloc`. + let ret = unsafe { ARef::<GpuVmBo<T>>::from_raw(nonnull) }; + + // Ensure that external objects are in the extobj list. + // + // Note that we must call `extobj_add` even if `ptr != me` to avoid a race condition where + // we could end up using the extobj before the thread with `ptr == me` calls extobj_add. + if ret.gpuvm().is_extobj(ret.obj()) { + let resv_lock = ret.gpuvm().raw_resv(); + // TODO: Use a proper lock guard here once a dma_resv lock abstraction exists. + // SAFETY: The GPUVM is still alive, so its resv lock is too. + unsafe { bindings::dma_resv_lock(resv_lock, ptr::null_mut()) }; + // SAFETY: We hold the GPUVMs resv lock. + unsafe { bindings::drm_gpuvm_bo_extobj_add(ptr) }; + // SAFETY: We took the lock, so we can unlock it. + unsafe { bindings::dma_resv_unlock(resv_lock) }; + } + + ret + } +} + +impl<T: DriverGpuVm> Deref for GpuVmBoAlloc<T> { + type Target = GpuVmBo<T>; + #[inline] + fn deref(&self) -> &GpuVmBo<T> { + // SAFETY: By the type invariants we may deref while `Self` exists. + unsafe { self.0.as_ref() } + } +} + +impl<T: DriverGpuVm> Drop for GpuVmBoAlloc<T> { + #[inline] + fn drop(&mut self) { + // TODO: Call drm_gpuvm_bo_destroy_not_in_lists() directly. + // SAFETY: It's safe to perform a deferred put in any context. + unsafe { bindings::drm_gpuvm_bo_put_deferred(self.as_raw()) }; + } +} diff --git a/rust/kernel/drm/mod.rs b/rust/kernel/drm/mod.rs index 1b82b6945edf..a66e7166f66b 100644 --- a/rust/kernel/drm/mod.rs +++ b/rust/kernel/drm/mod.rs @@ -6,9 +6,14 @@ pub mod device; pub mod driver; pub mod file; pub mod gem; +pub mod gpuvm; pub mod ioctl; pub use self::device::Device; +pub use self::device::DeviceContext; +pub use self::device::Registered; +pub use self::device::Uninit; +pub use self::device::UnregisteredDevice; pub use self::driver::Driver; pub use self::driver::DriverInfo; pub use self::driver::Registration; diff --git a/rust/kernel/error.rs b/rust/kernel/error.rs index 05cf869ac090..a56ba6309594 100644 --- a/rust/kernel/error.rs +++ b/rust/kernel/error.rs @@ -25,10 +25,8 @@ pub mod code { #[doc = $doc] )* pub const $err: super::Error = - match super::Error::try_from_errno(-(crate::bindings::$err as i32)) { - Some(err) => err, - None => panic!("Invalid errno in `declare_err!`"), - }; + super::Error::try_from_errno(-(crate::bindings::$err as i32)) + .expect("Invalid errno in `declare_err!`"); }; } diff --git a/rust/kernel/fmt.rs b/rust/kernel/fmt.rs index 1e8725eb44ed..73afbc51ba33 100644 --- a/rust/kernel/fmt.rs +++ b/rust/kernel/fmt.rs @@ -4,7 +4,14 @@ //! //! This module is intended to be used in place of `core::fmt` in kernel code. -pub use core::fmt::{Arguments, Debug, Error, Formatter, Result, Write}; +pub use core::fmt::{ + Arguments, + Debug, + Error, + Formatter, + Result, + Write, // +}; /// Internal adapter used to route and allow implementations of formatting traits for foreign types. /// @@ -27,7 +34,15 @@ macro_rules! impl_fmt_adapter_forward { }; } -use core::fmt::{Binary, LowerExp, LowerHex, Octal, Pointer, UpperExp, UpperHex}; +use core::fmt::{ + Binary, + LowerExp, + LowerHex, + Octal, + Pointer, + UpperExp, + UpperHex, // +}; impl_fmt_adapter_forward!(Debug, LowerHex, UpperHex, Octal, Binary, Pointer, LowerExp, UpperExp); /// A copy of [`core::fmt::Display`] that allows us to implement it for foreign types. diff --git a/rust/kernel/i2c.rs b/rust/kernel/i2c.rs index 7b908f0c5a58..624b971ca8b0 100644 --- a/rust/kernel/i2c.rs +++ b/rust/kernel/i2c.rs @@ -93,18 +93,18 @@ pub struct Adapter<T: Driver>(T); // SAFETY: // - `bindings::i2c_driver` is a C type declared as `repr(C)`. -// - `T` is the type of the driver's device private data. +// - `T::Data` is the type of the driver's device private data. // - `struct i2c_driver` embeds a `struct device_driver`. // - `DEVICE_DRIVER_OFFSET` is the correct byte offset to the embedded `struct device_driver`. -unsafe impl<T: Driver + 'static> driver::DriverLayout for Adapter<T> { +unsafe impl<T: Driver> driver::DriverLayout for Adapter<T> { type DriverType = bindings::i2c_driver; - type DriverData = T; + type DriverData<'bound> = T::Data<'bound>; const DEVICE_DRIVER_OFFSET: usize = core::mem::offset_of!(Self::DriverType, driver); } // SAFETY: A call to `unregister` for a given instance of `DriverType` is guaranteed to be valid if // a preceding call to `register` has been successful. -unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { +unsafe impl<T: Driver> driver::RegistrationOps for Adapter<T> { unsafe fn register( idrv: &Opaque<Self::DriverType>, name: &'static CStr, @@ -151,13 +151,13 @@ unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { } } -impl<T: Driver + 'static> Adapter<T> { +impl<T: Driver> Adapter<T> { extern "C" fn probe_callback(idev: *mut bindings::i2c_client) -> kernel::ffi::c_int { // SAFETY: The I2C bus only ever calls the probe callback with a valid pointer to a // `struct i2c_client`. // // INVARIANT: `idev` is valid for the duration of `probe_callback()`. - let idev = unsafe { &*idev.cast::<I2cClient<device::CoreInternal>>() }; + let idev = unsafe { &*idev.cast::<I2cClient<device::CoreInternal<'_>>>() }; let info = Self::i2c_id_info(idev).or_else(|| <Self as driver::Adapter>::id_info(idev.as_ref())); @@ -172,24 +172,24 @@ impl<T: Driver + 'static> Adapter<T> { extern "C" fn remove_callback(idev: *mut bindings::i2c_client) { // SAFETY: `idev` is a valid pointer to a `struct i2c_client`. - let idev = unsafe { &*idev.cast::<I2cClient<device::CoreInternal>>() }; + let idev = unsafe { &*idev.cast::<I2cClient<device::CoreInternal<'_>>>() }; // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `I2cClient::set_drvdata()` has been called - // and stored a `Pin<KBox<T>>`. - let data = unsafe { idev.as_ref().drvdata_borrow::<T>() }; + // and stored a `Pin<KBox<T::Data<'_>>>`. + let data = unsafe { idev.as_ref().drvdata_borrow::<T::Data<'_>>() }; T::unbind(idev, data); } extern "C" fn shutdown_callback(idev: *mut bindings::i2c_client) { // SAFETY: `shutdown_callback` is only ever called for a valid `idev` - let idev = unsafe { &*idev.cast::<I2cClient<device::CoreInternal>>() }; + let idev = unsafe { &*idev.cast::<I2cClient<device::CoreInternal<'_>>>() }; // SAFETY: `shutdown_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called - // and stored a `Pin<KBox<T>>`. - let data = unsafe { idev.as_ref().drvdata_borrow::<T>() }; + // and stored a `Pin<KBox<T::Data<'_>>>`. + let data = unsafe { idev.as_ref().drvdata_borrow::<T::Data<'_>>() }; T::shutdown(idev, data); } @@ -222,7 +222,7 @@ impl<T: Driver + 'static> Adapter<T> { } } -impl<T: Driver + 'static> driver::Adapter for Adapter<T> { +impl<T: Driver> driver::Adapter for Adapter<T> { type IdInfo = T::IdInfo; fn of_id_table() -> Option<of::IdTable<Self::IdInfo>> { @@ -294,22 +294,26 @@ macro_rules! module_i2c_driver { /// /// impl i2c::Driver for MyDriver { /// type IdInfo = (); +/// type Data<'bound> = Self; /// const I2C_ID_TABLE: Option<i2c::IdTable<Self::IdInfo>> = Some(&I2C_TABLE); /// const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = Some(&OF_TABLE); /// const ACPI_ID_TABLE: Option<acpi::IdTable<Self::IdInfo>> = Some(&ACPI_TABLE); /// -/// fn probe( -/// _idev: &i2c::I2cClient<Core>, -/// _id_info: Option<&Self::IdInfo>, -/// ) -> impl PinInit<Self, Error> { +/// fn probe<'bound>( +/// _idev: &'bound i2c::I2cClient<Core<'_>>, +/// _id_info: Option<&'bound Self::IdInfo>, +/// ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound { /// Err(ENODEV) /// } /// -/// fn shutdown(_idev: &i2c::I2cClient<Core>, this: Pin<&Self>) { +/// fn shutdown<'bound>( +/// _idev: &'bound i2c::I2cClient<Core<'_>>, +/// this: Pin<&Self::Data<'bound>>, +/// ) { /// } /// } ///``` -pub trait Driver: Send { +pub trait Driver { /// The type holding information about each device id supported by the driver. // TODO: Use `associated_type_defaults` once stabilized: // @@ -318,6 +322,9 @@ pub trait Driver: Send { // ``` type IdInfo: 'static; + /// The type of the driver's bus device private data. + type Data<'bound>: Send + 'bound; + /// The table of device ids supported by the driver. const I2C_ID_TABLE: Option<IdTable<Self::IdInfo>> = None; @@ -331,10 +338,10 @@ pub trait Driver: Send { /// /// Called when a new i2c client is added or discovered. /// Implementers should attempt to initialize the client here. - fn probe( - dev: &I2cClient<device::Core>, - id_info: Option<&Self::IdInfo>, - ) -> impl PinInit<Self, Error>; + fn probe<'bound>( + dev: &'bound I2cClient<device::Core<'_>>, + id_info: Option<&'bound Self::IdInfo>, + ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound; /// I2C driver shutdown. /// @@ -346,8 +353,8 @@ pub trait Driver: Send { /// /// This callback is distinct from final resource cleanup, as the driver instance remains valid /// after it returns. Any deallocation or teardown of driver-owned resources should instead be - /// handled in `Self::drop`. - fn shutdown(dev: &I2cClient<device::Core>, this: Pin<&Self>) { + /// handled in `Drop`. + fn shutdown<'bound>(dev: &'bound I2cClient<device::Core<'_>>, this: Pin<&Self::Data<'bound>>) { let _ = (dev, this); } @@ -360,8 +367,8 @@ pub trait Driver: Send { /// `&Device<Core>` or `&Device<Bound>` reference. For instance, drivers may try to perform I/O /// operations to gracefully tear down the device. /// - /// Otherwise, release operations for driver resources should be performed in `Self::drop`. - fn unbind(dev: &I2cClient<device::Core>, this: Pin<&Self>) { + /// Otherwise, release operations for driver resources should be performed in `Drop`. + fn unbind<'bound>(dev: &'bound I2cClient<device::Core<'_>>, this: Pin<&Self::Data<'bound>>) { let _ = (dev, this); } } @@ -405,7 +412,9 @@ impl I2cAdapter { // SAFETY: `adapter` is non-null and points to a live `i2c_adapter`. // `I2cAdapter` is #[repr(transparent)], so this cast is valid. - Ok(unsafe { (&*adapter.as_ptr().cast::<I2cAdapter<device::Normal>>()).into() }) + // `i2c_get_adapter` returned the adapter with an incremented refcount, which we pass to + // the `ARef`. + Ok(unsafe { ARef::from_raw(adapter.cast::<I2cAdapter<device::Normal>>()) }) } } diff --git a/rust/kernel/init.rs b/rust/kernel/init.rs index 7a0d4559d7b5..05a12e869a57 100644 --- a/rust/kernel/init.rs +++ b/rust/kernel/init.rs @@ -151,6 +151,7 @@ pub trait InPlaceInit<T>: Sized { /// type. /// /// If `T: !Unpin` it will not be able to move afterwards. + #[inline] fn pin_init<E>(init: impl PinInit<T, E>, flags: Flags) -> error::Result<Self::PinnedSelf> where Error: From<E>, @@ -168,6 +169,7 @@ pub trait InPlaceInit<T>: Sized { E: From<AllocError>; /// Use the given initializer to in-place initialize a `T`. + #[inline] fn init<E>(init: impl Init<T, E>, flags: Flags) -> error::Result<Self> where Error: From<E>, diff --git a/rust/kernel/io/mem.rs b/rust/kernel/io/mem.rs index 7dc78d547f7a..fc2a3e24f8d5 100644 --- a/rust/kernel/io/mem.rs +++ b/rust/kernel/io/mem.rs @@ -62,33 +62,31 @@ impl<'a> IoRequest<'a> { /// /// impl platform::Driver for SampleDriver { /// # type IdInfo = (); + /// # type Data<'bound> = Self; /// - /// fn probe( - /// pdev: &platform::Device<Core>, - /// info: Option<&Self::IdInfo>, - /// ) -> impl PinInit<Self, Error> { + /// fn probe<'bound>( + /// pdev: &'bound platform::Device<Core<'_>>, + /// info: Option<&'bound Self::IdInfo>, + /// ) -> impl PinInit<Self, Error> + 'bound { /// let offset = 0; // Some offset. /// /// // If the size is known at compile time, use [`Self::iomap_sized`]. /// // /// // No runtime checks will apply when reading and writing. /// let request = pdev.io_request_by_index(0).ok_or(ENODEV)?; - /// let iomem = request.iomap_sized::<42>(); - /// let iomem = KBox::pin_init(iomem, GFP_KERNEL)?; - /// - /// let io = iomem.access(pdev.as_ref())?; + /// let iomem = request.iomap_sized::<42>()?; /// /// // Read and write a 32-bit value at `offset`. - /// let data = io.read32(offset); + /// let data = iomem.read32(offset); /// - /// io.write32(data, offset); + /// iomem.write32(data, offset); /// /// # Ok(SampleDriver) /// } /// } /// ``` - pub fn iomap_sized<const SIZE: usize>(self) -> impl PinInit<Devres<IoMem<SIZE>>, Error> + 'a { - IoMem::new(self) + pub fn iomap_sized<const SIZE: usize>(self) -> Result<IoMem<'a, SIZE>> { + IoMem::ioremap(self.device, self.resource) } /// Same as [`Self::iomap_sized`] but with exclusive access to the @@ -97,10 +95,8 @@ impl<'a> IoRequest<'a> { /// This uses the [`ioremap()`] C API. /// /// [`ioremap()`]: https://docs.kernel.org/driver-api/device-io.html#getting-access-to-the-device - pub fn iomap_exclusive_sized<const SIZE: usize>( - self, - ) -> impl PinInit<Devres<ExclusiveIoMem<SIZE>>, Error> + 'a { - ExclusiveIoMem::new(self) + pub fn iomap_exclusive_sized<const SIZE: usize>(self) -> Result<ExclusiveIoMem<'a, SIZE>> { + ExclusiveIoMem::ioremap(self.device, self.resource) } /// Maps an [`IoRequest`] where the size is not known at compile time, @@ -126,11 +122,12 @@ impl<'a> IoRequest<'a> { /// /// impl platform::Driver for SampleDriver { /// # type IdInfo = (); + /// # type Data<'bound> = Self; /// - /// fn probe( - /// pdev: &platform::Device<Core>, - /// info: Option<&Self::IdInfo>, - /// ) -> impl PinInit<Self, Error> { + /// fn probe<'bound>( + /// pdev: &'bound platform::Device<Core<'_>>, + /// info: Option<&'bound Self::IdInfo>, + /// ) -> impl PinInit<Self, Error> + 'bound { /// let offset = 0; // Some offset. /// /// // Unlike [`Self::iomap_sized`], here the size of the memory region @@ -138,27 +135,24 @@ impl<'a> IoRequest<'a> { /// // family of functions should be used, leading to runtime checks on every /// // access. /// let request = pdev.io_request_by_index(0).ok_or(ENODEV)?; - /// let iomem = request.iomap(); - /// let iomem = KBox::pin_init(iomem, GFP_KERNEL)?; - /// - /// let io = iomem.access(pdev.as_ref())?; + /// let iomem = request.iomap()?; /// - /// let data = io.try_read32(offset)?; + /// let data = iomem.try_read32(offset)?; /// - /// io.try_write32(data, offset)?; + /// iomem.try_write32(data, offset)?; /// /// # Ok(SampleDriver) /// } /// } /// ``` - pub fn iomap(self) -> impl PinInit<Devres<IoMem<0>>, Error> + 'a { - Self::iomap_sized::<0>(self) + pub fn iomap(self) -> Result<IoMem<'a>> { + self.iomap_sized::<0>() } /// Same as [`Self::iomap`] but with exclusive access to the underlying /// region. - pub fn iomap_exclusive(self) -> impl PinInit<Devres<ExclusiveIoMem<0>>, Error> + 'a { - Self::iomap_exclusive_sized::<0>(self) + pub fn iomap_exclusive(self) -> Result<ExclusiveIoMem<'a, 0>> { + self.iomap_exclusive_sized::<0>() } } @@ -167,9 +161,9 @@ impl<'a> IoRequest<'a> { /// # Invariants /// /// - [`ExclusiveIoMem`] has exclusive access to the underlying [`IoMem`]. -pub struct ExclusiveIoMem<const SIZE: usize> { +pub struct ExclusiveIoMem<'a, const SIZE: usize> { /// The underlying `IoMem` instance. - iomem: IoMem<SIZE>, + iomem: IoMem<'a, SIZE>, /// The region abstraction. This represents exclusive access to the /// range represented by the underlying `iomem`. @@ -178,9 +172,9 @@ pub struct ExclusiveIoMem<const SIZE: usize> { _region: Region, } -impl<const SIZE: usize> ExclusiveIoMem<SIZE> { +impl<'a, const SIZE: usize> ExclusiveIoMem<'a, SIZE> { /// Creates a new `ExclusiveIoMem` instance. - fn ioremap(resource: &Resource) -> Result<Self> { + fn ioremap(dev: &'a Device<Bound>, resource: &Resource) -> Result<Self> { let start = resource.start(); let size = resource.size(); let name = resource.name().unwrap_or_default(); @@ -194,26 +188,29 @@ impl<const SIZE: usize> ExclusiveIoMem<SIZE> { ) .ok_or(EBUSY)?; - let iomem = IoMem::ioremap(resource)?; + let iomem = IoMem::ioremap(dev, resource)?; - let iomem = ExclusiveIoMem { + Ok(ExclusiveIoMem { iomem, _region: region, - }; - - Ok(iomem) + }) } - /// Creates a new `ExclusiveIoMem` instance from a previously acquired [`IoRequest`]. - pub fn new<'a>(io_request: IoRequest<'a>) -> impl PinInit<Devres<Self>, Error> + 'a { - let dev = io_request.device; - let res = io_request.resource; - - Devres::new(dev, Self::ioremap(res)) + /// Consume the `ExclusiveIoMem` and register it as a device-managed resource. + /// + /// The returned `Devres<ExclusiveIoMem<'static, SIZE>>` can outlive the original lifetime + /// `'a`. Access to the I/O memory is revoked when the device is unbound. + pub fn into_devres(self) -> Result<Devres<ExclusiveIoMem<'static, SIZE>>> { + // SAFETY: Casting to `'static` is sound because `Devres` guarantees the + // `ExclusiveIoMem` does not actually outlive the device -- access is revoked and the + // resource is released when the device is unbound. + let iomem: ExclusiveIoMem<'static, SIZE> = unsafe { core::mem::transmute(self) }; + let dev = iomem.iomem.dev; + Devres::new(dev, iomem) } } -impl<const SIZE: usize> Deref for ExclusiveIoMem<SIZE> { +impl<const SIZE: usize> Deref for ExclusiveIoMem<'_, SIZE> { type Target = Mmio<SIZE>; fn deref(&self) -> &Self::Target { @@ -230,12 +227,13 @@ impl<const SIZE: usize> Deref for ExclusiveIoMem<SIZE> { /// /// [`IoMem`] always holds an [`MmioRaw`] instance that holds a valid pointer to the /// start of the I/O memory mapped region. -pub struct IoMem<const SIZE: usize = 0> { +pub struct IoMem<'a, const SIZE: usize = 0> { + dev: &'a Device<Bound>, io: MmioRaw<SIZE>, } -impl<const SIZE: usize> IoMem<SIZE> { - fn ioremap(resource: &Resource) -> Result<Self> { +impl<'a, const SIZE: usize> IoMem<'a, SIZE> { + fn ioremap(dev: &'a Device<Bound>, resource: &Resource) -> Result<Self> { // Note: Some ioremap() implementations use types that depend on the CPU // word width rather than the bus address width. // @@ -267,28 +265,33 @@ impl<const SIZE: usize> IoMem<SIZE> { } let io = MmioRaw::new(addr as usize, size)?; - let io = IoMem { io }; - Ok(io) + Ok(IoMem { dev, io }) } - /// Creates a new `IoMem` instance from a previously acquired [`IoRequest`]. - pub fn new<'a>(io_request: IoRequest<'a>) -> impl PinInit<Devres<Self>, Error> + 'a { - let dev = io_request.device; - let res = io_request.resource; - - Devres::new(dev, Self::ioremap(res)) + /// Consume the `IoMem` and register it as a device-managed resource. + /// + /// The returned `Devres<IoMem<'static, SIZE>>` can outlive the original + /// lifetime `'a`. Access to the I/O memory is revoked when the device + /// is unbound. + pub fn into_devres(self) -> Result<Devres<IoMem<'static, SIZE>>> { + // SAFETY: Casting to `'static` is sound because `Devres` guarantees the `IoMem` does not + // actually outlive the device -- access is revoked and the resource is released when the + // device is unbound. + let iomem: IoMem<'static, SIZE> = unsafe { core::mem::transmute(self) }; + let dev = iomem.dev; + Devres::new(dev, iomem) } } -impl<const SIZE: usize> Drop for IoMem<SIZE> { +impl<const SIZE: usize> Drop for IoMem<'_, SIZE> { fn drop(&mut self) { // SAFETY: Safe as by the invariant of `Io`. unsafe { bindings::iounmap(self.io.addr() as *mut c_void) } } } -impl<const SIZE: usize> Deref for IoMem<SIZE> { +impl<const SIZE: usize> Deref for IoMem<'_, SIZE> { type Target = Mmio<SIZE>; fn deref(&self) -> &Self::Target { diff --git a/rust/kernel/io/register.rs b/rust/kernel/io/register.rs index abc49926abfe..f924c7c7c1db 100644 --- a/rust/kernel/io/register.rs +++ b/rust/kernel/io/register.rs @@ -108,9 +108,10 @@ use core::marker::PhantomData; -use crate::io::IoLoc; - -use kernel::build_assert; +use crate::{ + build_assert::build_assert, + io::IoLoc, // +}; /// Trait implemented by all registers. pub trait Register: Sized { @@ -872,7 +873,7 @@ macro_rules! register { @reg $(#[$attr:meta])* $vis:vis $name:ident ($storage:ty) [ $size:expr, stride = $stride:expr ] @ $offset:literal { $($fields:tt)* } ) => { - ::kernel::static_assert!(::core::mem::size_of::<$storage>() <= $stride); + $crate::build_assert::static_assert!(::core::mem::size_of::<$storage>() <= $stride); $crate::register!(@bitfield $(#[$attr])* $vis struct $name($storage) { $($fields)* }); $crate::register!(@io_base $name($storage) @ $offset); @@ -895,7 +896,9 @@ macro_rules! register { @reg $(#[$attr:meta])* $vis:vis $name:ident ($storage:ty) => $alias:ident [ $idx:expr ] { $($fields:tt)* } ) => { - ::kernel::static_assert!($idx < <$alias as $crate::io::register::RegisterArray>::SIZE); + $crate::build_assert::static_assert!( + $idx < <$alias as $crate::io::register::RegisterArray>::SIZE + ); $crate::register!(@bitfield $(#[$attr])* $vis struct $name($storage) { $($fields)* }); $crate::register!( @@ -912,7 +915,7 @@ macro_rules! register { [ $size:expr, stride = $stride:expr ] @ $base:ident + $offset:literal { $($fields:tt)* } ) => { - ::kernel::static_assert!(::core::mem::size_of::<$storage>() <= $stride); + $crate::build_assert::static_assert!(::core::mem::size_of::<$storage>() <= $stride); $crate::register!(@bitfield $(#[$attr])* $vis struct $name($storage) { $($fields)* }); $crate::register!(@io_base $name($storage) @ $offset); @@ -938,7 +941,9 @@ macro_rules! register { @reg $(#[$attr:meta])* $vis:vis $name:ident ($storage:ty) => $base:ident + $alias:ident [ $idx:expr ] { $($fields:tt)* } ) => { - ::kernel::static_assert!($idx < <$alias as $crate::io::register::RegisterArray>::SIZE); + $crate::build_assert::static_assert!( + $idx < <$alias as $crate::io::register::RegisterArray>::SIZE + ); $crate::register!(@bitfield $(#[$attr])* $vis struct $name($storage) { $($fields)* }); $crate::register!( @@ -956,11 +961,10 @@ macro_rules! register { ( @bitfield $(#[$attr:meta])* $vis:vis struct $name:ident($storage:ty) { $($fields:tt)* } ) => { - $crate::register!(@bitfield_core + $crate::bitfield!( #[allow(non_camel_case_types)] - $(#[$attr])* $vis $name $storage + $(#[$attr])* $vis struct $name($storage) { $($fields)* } ); - $crate::register!(@bitfield_fields $vis $name $storage { $($fields)* }); }; // Implementations shared by all registers types. @@ -1016,245 +1020,4 @@ macro_rules! register { impl $crate::io::register::RelativeRegisterArray for $name {} }; - - // Defines the wrapper `$name` type and its conversions from/to the storage type. - (@bitfield_core $(#[$attr:meta])* $vis:vis $name:ident $storage:ty) => { - $(#[$attr])* - #[repr(transparent)] - #[derive(Clone, Copy, PartialEq, Eq)] - $vis struct $name { - inner: $storage, - } - - #[allow(dead_code)] - impl $name { - /// Creates a bitfield from a raw value. - #[inline(always)] - $vis const fn from_raw(value: $storage) -> Self { - Self{ inner: value } - } - - /// Turns this bitfield into its raw value. - /// - /// This is similar to the [`From`] implementation, but is shorter to invoke in - /// most cases. - #[inline(always)] - $vis const fn into_raw(self) -> $storage { - self.inner - } - } - - // SAFETY: `$storage` is `Zeroable` and `$name` is transparent. - unsafe impl ::pin_init::Zeroable for $name {} - - impl ::core::convert::From<$name> for $storage { - #[inline(always)] - fn from(val: $name) -> $storage { - val.into_raw() - } - } - - impl ::core::convert::From<$storage> for $name { - #[inline(always)] - fn from(val: $storage) -> $name { - Self::from_raw(val) - } - } - }; - - // Definitions requiring knowledge of individual fields: private and public field accessors, - // and `Debug` implementation. - (@bitfield_fields $vis:vis $name:ident $storage:ty { - $($(#[doc = $doc:expr])* $hi:literal:$lo:literal $field:ident - $(?=> $try_into_type:ty)? - $(=> $into_type:ty)? - ; - )* - } - ) => { - #[allow(dead_code)] - impl $name { - $( - $crate::register!(@private_field_accessors $vis $name $storage : $hi:$lo $field); - $crate::register!( - @public_field_accessors $(#[doc = $doc])* $vis $name $storage : $hi:$lo $field - $(?=> $try_into_type)? - $(=> $into_type)? - ); - )* - } - - $crate::register!(@debug $name { $($field;)* }); - }; - - // Private field accessors working with the exact `Bounded` type for the field. - ( - @private_field_accessors $vis:vis $name:ident $storage:ty : $hi:tt:$lo:tt $field:ident - ) => { - ::kernel::macros::paste!( - $vis const [<$field:upper _RANGE>]: ::core::ops::RangeInclusive<u8> = $lo..=$hi; - $vis const [<$field:upper _MASK>]: $storage = - ((((1 << $hi) - 1) << 1) + 1) - ((1 << $lo) - 1); - $vis const [<$field:upper _SHIFT>]: u32 = $lo; - ); - - ::kernel::macros::paste!( - fn [<__ $field>](self) -> - ::kernel::num::Bounded<$storage, { $hi + 1 - $lo }> { - // Left shift to align the field's MSB with the storage MSB. - const ALIGN_TOP: u32 = $storage::BITS - ($hi + 1); - // Right shift to move the top-aligned field to bit 0 of the storage. - const ALIGN_BOTTOM: u32 = ALIGN_TOP + $lo; - - // Extract the field using two shifts. `Bounded::shr` produces the correctly-sized - // output type. - let val = ::kernel::num::Bounded::<$storage, { $storage::BITS }>::from( - self.inner << ALIGN_TOP - ); - val.shr::<ALIGN_BOTTOM, { $hi + 1 - $lo } >() - } - - const fn [<__with_ $field>]( - mut self, - value: ::kernel::num::Bounded<$storage, { $hi + 1 - $lo }>, - ) -> Self - { - const MASK: $storage = <$name>::[<$field:upper _MASK>]; - const SHIFT: u32 = <$name>::[<$field:upper _SHIFT>]; - - let value = value.get() << SHIFT; - self.inner = (self.inner & !MASK) | value; - - self - } - ); - }; - - // Public accessors for fields infallibly (`=>`) converted to a type. - ( - @public_field_accessors $(#[doc = $doc:expr])* $vis:vis $name:ident $storage:ty : - $hi:literal:$lo:literal $field:ident => $into_type:ty - ) => { - ::kernel::macros::paste!( - - $(#[doc = $doc])* - #[doc = "Returns the value of this field."] - #[inline(always)] - $vis fn $field(self) -> $into_type - { - self.[<__ $field>]().into() - } - - $(#[doc = $doc])* - #[doc = "Sets this field to the given `value`."] - #[inline(always)] - $vis fn [<with_ $field>](self, value: $into_type) -> Self - { - self.[<__with_ $field>](value.into()) - } - - ); - }; - - // Public accessors for fields fallibly (`?=>`) converted to a type. - ( - @public_field_accessors $(#[doc = $doc:expr])* $vis:vis $name:ident $storage:ty : - $hi:tt:$lo:tt $field:ident ?=> $try_into_type:ty - ) => { - ::kernel::macros::paste!( - - $(#[doc = $doc])* - #[doc = "Returns the value of this field."] - #[inline(always)] - $vis fn $field(self) -> - Result< - $try_into_type, - <$try_into_type as ::core::convert::TryFrom< - ::kernel::num::Bounded<$storage, { $hi + 1 - $lo }> - >>::Error - > - { - self.[<__ $field>]().try_into() - } - - $(#[doc = $doc])* - #[doc = "Sets this field to the given `value`."] - #[inline(always)] - $vis fn [<with_ $field>](self, value: $try_into_type) -> Self - { - self.[<__with_ $field>](value.into()) - } - - ); - }; - - // Public accessors for fields not converted to a type. - ( - @public_field_accessors $(#[doc = $doc:expr])* $vis:vis $name:ident $storage:ty : - $hi:tt:$lo:tt $field:ident - ) => { - ::kernel::macros::paste!( - - $(#[doc = $doc])* - #[doc = "Returns the value of this field."] - #[inline(always)] - $vis fn $field(self) -> - ::kernel::num::Bounded<$storage, { $hi + 1 - $lo }> - { - self.[<__ $field>]() - } - - $(#[doc = $doc])* - #[doc = "Sets this field to the compile-time constant `VALUE`."] - #[inline(always)] - $vis const fn [<with_const_ $field>]<const VALUE: $storage>(self) -> Self { - self.[<__with_ $field>]( - ::kernel::num::Bounded::<$storage, { $hi + 1 - $lo }>::new::<VALUE>() - ) - } - - $(#[doc = $doc])* - #[doc = "Sets this field to the given `value`."] - #[inline(always)] - $vis fn [<with_ $field>]<T>( - self, - value: T, - ) -> Self - where T: Into<::kernel::num::Bounded<$storage, { $hi + 1 - $lo }>>, - { - self.[<__with_ $field>](value.into()) - } - - $(#[doc = $doc])* - #[doc = "Tries to set this field to `value`, returning an error if it is out of range."] - #[inline(always)] - $vis fn [<try_with_ $field>]<T>( - self, - value: T, - ) -> ::kernel::error::Result<Self> - where T: ::kernel::num::TryIntoBounded<$storage, { $hi + 1 - $lo }>, - { - Ok( - self.[<__with_ $field>]( - value.try_into_bounded().ok_or(::kernel::error::code::EOVERFLOW)? - ) - ) - } - - ); - }; - - // `Debug` implementation. - (@debug $name:ident { $($field:ident;)* }) => { - impl ::kernel::fmt::Debug for $name { - fn fmt(&self, f: &mut ::kernel::fmt::Formatter<'_>) -> ::kernel::fmt::Result { - f.debug_struct(stringify!($name)) - .field("<raw>", &::kernel::prelude::fmt!("{:#x}", self.inner)) - $( - .field(stringify!($field), &self.$field()) - )* - .finish() - } - } - }; } diff --git a/rust/kernel/io/resource.rs b/rust/kernel/io/resource.rs index b7ac9faf141d..17b0c174cfc5 100644 --- a/rust/kernel/io/resource.rs +++ b/rust/kernel/io/resource.rs @@ -229,7 +229,7 @@ impl Flags { // Always inline to optimize out error path of `build_assert`. #[inline(always)] const fn new(value: u32) -> Self { - crate::build_assert!(value as u64 <= c_ulong::MAX as u64); + build_assert!(value as u64 <= c_ulong::MAX as u64); Flags(value as c_ulong) } } diff --git a/rust/kernel/ioctl.rs b/rust/kernel/ioctl.rs index 2fc7662339e5..5bb5b48cf949 100644 --- a/rust/kernel/ioctl.rs +++ b/rust/kernel/ioctl.rs @@ -6,7 +6,7 @@ #![expect(non_snake_case)] -use crate::build_assert; +use crate::build_assert::build_assert; /// Build an ioctl number, analogous to the C macro of the same name. #[inline(always)] diff --git a/rust/kernel/kunit.rs b/rust/kernel/kunit.rs index a1edf7491579..cdee5f27bd7f 100644 --- a/rust/kernel/kunit.rs +++ b/rust/kernel/kunit.rs @@ -329,6 +329,7 @@ pub fn in_kunit_test() -> bool { !unsafe { bindings::kunit_get_current_test() }.is_null() } +#[cfg(CONFIG_RUST_KUNIT_SELFTEST)] #[kunit_tests(rust_kernel_kunit)] mod tests { use super::*; diff --git a/rust/kernel/lib.rs b/rust/kernel/lib.rs index b72b2fbe046d..9512af7156df 100644 --- a/rust/kernel/lib.rs +++ b/rust/kernel/lib.rs @@ -44,6 +44,7 @@ pub mod acpi; pub mod alloc; #[cfg(CONFIG_AUXILIARY_BUS)] pub mod auxiliary; +pub mod bitfield; pub mod bitmap; pub mod bits; #[cfg(CONFIG_BLOCK)] diff --git a/rust/kernel/miscdevice.rs b/rust/kernel/miscdevice.rs index c3c2052c9206..83ce50def5ac 100644 --- a/rust/kernel/miscdevice.rs +++ b/rust/kernel/miscdevice.rs @@ -11,16 +11,27 @@ use crate::{ bindings, device::Device, - error::{to_result, Error, Result, VTABLE_DEFAULT_ERROR}, - ffi::{c_int, c_long, c_uint, c_ulong}, - fs::{File, Kiocb}, - iov::{IovIterDest, IovIterSource}, + error::{ + to_result, + VTABLE_DEFAULT_ERROR, // + }, + fs::{ + File, + Kiocb, // + }, + iov::{ + IovIterDest, + IovIterSource, // + }, mm::virt::VmaNew, prelude::*, seq_file::SeqFile, - types::{ForeignOwnable, Opaque}, + types::{ + ForeignOwnable, + Opaque, // + }, }; -use core::{marker::PhantomData, pin::Pin}; +use core::marker::PhantomData; /// Options for creating a misc device. #[derive(Copy, Clone)] diff --git a/rust/kernel/module_param.rs b/rust/kernel/module_param.rs index 6a8a7a875643..6541af218390 100644 --- a/rust/kernel/module_param.rs +++ b/rust/kernel/module_param.rs @@ -62,8 +62,7 @@ where // NOTE: If we start supporting arguments without values, val _is_ allowed // to be null here. if val.is_null() { - // TODO: Use pr_warn_once available. - crate::pr_warn!("Null pointer passed to `module_param::set_param`"); + crate::pr_warn_once!("Null pointer passed to `module_param::set_param`\n"); return EINVAL.to_errno(); } diff --git a/rust/kernel/net/phy/reg.rs b/rust/kernel/net/phy/reg.rs index a7db0064cb7d..80e22c264ea8 100644 --- a/rust/kernel/net/phy/reg.rs +++ b/rust/kernel/net/phy/reg.rs @@ -9,9 +9,11 @@ //! defined in IEEE 802.3. use super::Device; -use crate::build_assert; -use crate::error::*; -use crate::uapi; +use crate::{ + build_assert::build_assert, + error::*, + uapi, // +}; mod private { /// Marker that a trait cannot be implemented outside of this crate diff --git a/rust/kernel/num/bounded.rs b/rust/kernel/num/bounded.rs index f9f90d6ec482..dafe77782d79 100644 --- a/rust/kernel/num/bounded.rs +++ b/rust/kernel/num/bounded.rs @@ -364,7 +364,7 @@ where // Always inline to optimize out error path of `build_assert`. #[inline(always)] pub fn from_expr(expr: T) -> Self { - crate::build_assert!( + crate::build_assert::build_assert!( fits_within(expr, N), "Requested value larger than maximal representable value." ); diff --git a/rust/kernel/opp.rs b/rust/kernel/opp.rs index a760fac28765..62e44676125d 100644 --- a/rust/kernel/opp.rs +++ b/rust/kernel/opp.rs @@ -1042,11 +1042,13 @@ unsafe impl Sync for OPP {} /// SAFETY: The type invariants guarantee that [`OPP`] is always refcounted. unsafe impl AlwaysRefCounted for OPP { + #[inline] fn inc_ref(&self) { // SAFETY: The existence of a shared reference means that the refcount is nonzero. unsafe { bindings::dev_pm_opp_get(self.0.get()) }; } + #[inline] unsafe fn dec_ref(obj: ptr::NonNull<Self>) { // SAFETY: The safety requirements guarantee that the refcount is nonzero. unsafe { bindings::dev_pm_opp_put(obj.cast().as_ptr()) } @@ -1095,6 +1097,7 @@ impl OPP { } /// Returns the frequency of an [`OPP`]. + #[inline] pub fn freq(&self, index: Option<u32>) -> Hertz { let index = index.unwrap_or(0); diff --git a/rust/kernel/page.rs b/rust/kernel/page.rs index adecb200c654..1c0796ea229f 100644 --- a/rust/kernel/page.rs +++ b/rust/kernel/page.rs @@ -3,17 +3,25 @@ //! Kernel page allocation and management. use crate::{ - alloc::{AllocError, Flags}, + alloc::{ + AllocError, + Flags, // + }, bindings, - error::code::*, - error::Result, - uaccess::UserSliceReader, + error::{ + code::*, + Result, // + }, + uaccess::UserSliceReader, // }; use core::{ marker::PhantomData, mem::ManuallyDrop, ops::Deref, - ptr::{self, NonNull}, + ptr::{ + self, + NonNull, // + }, // }; /// A bitwise shift for the page size. @@ -193,6 +201,7 @@ impl Page { } /// Get the node id containing this page. + #[inline] pub fn nid(&self) -> i32 { // SAFETY: Always safe to call with a valid page. unsafe { bindings::page_to_nid(self.as_ptr()) } diff --git a/rust/kernel/pci.rs b/rust/kernel/pci.rs index af74ddff6114..5071cae6543f 100644 --- a/rust/kernel/pci.rs +++ b/rust/kernel/pci.rs @@ -59,18 +59,18 @@ pub struct Adapter<T: Driver>(T); // SAFETY: // - `bindings::pci_driver` is a C type declared as `repr(C)`. -// - `T` is the type of the driver's device private data. +// - `T::Data` is the type of the driver's device private data. // - `struct pci_driver` embeds a `struct device_driver`. // - `DEVICE_DRIVER_OFFSET` is the correct byte offset to the embedded `struct device_driver`. -unsafe impl<T: Driver + 'static> driver::DriverLayout for Adapter<T> { +unsafe impl<T: Driver> driver::DriverLayout for Adapter<T> { type DriverType = bindings::pci_driver; - type DriverData = T; + type DriverData<'bound> = T::Data<'bound>; const DEVICE_DRIVER_OFFSET: usize = core::mem::offset_of!(Self::DriverType, driver); } // SAFETY: A call to `unregister` for a given instance of `DriverType` is guaranteed to be valid if // a preceding call to `register` has been successful. -unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { +unsafe impl<T: Driver> driver::RegistrationOps for Adapter<T> { unsafe fn register( pdrv: &Opaque<Self::DriverType>, name: &'static CStr, @@ -96,7 +96,7 @@ unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { } } -impl<T: Driver + 'static> Adapter<T> { +impl<T: Driver> Adapter<T> { extern "C" fn probe_callback( pdev: *mut bindings::pci_dev, id: *const bindings::pci_device_id, @@ -105,7 +105,7 @@ impl<T: Driver + 'static> Adapter<T> { // `struct pci_dev`. // // INVARIANT: `pdev` is valid for the duration of `probe_callback()`. - let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal>>() }; + let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal<'_>>>() }; // SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct pci_device_id` and // does not add additional invariants, so it's safe to transmute. @@ -125,12 +125,12 @@ impl<T: Driver + 'static> Adapter<T> { // `struct pci_dev`. // // INVARIANT: `pdev` is valid for the duration of `remove_callback()`. - let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal>>() }; + let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal<'_>>>() }; // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called - // and stored a `Pin<KBox<T>>`. - let data = unsafe { pdev.as_ref().drvdata_borrow::<T>() }; + // and stored a `Pin<KBox<T::Data<'_>>>`. + let data = unsafe { pdev.as_ref().drvdata_borrow::<T::Data<'_>>() }; T::unbind(pdev, data); } @@ -279,19 +279,20 @@ macro_rules! pci_device_table { /// /// impl pci::Driver for MyDriver { /// type IdInfo = (); +/// type Data<'bound> = Self; /// const ID_TABLE: pci::IdTable<Self::IdInfo> = &PCI_TABLE; /// -/// fn probe( -/// _pdev: &pci::Device<Core>, -/// _id_info: &Self::IdInfo, -/// ) -> impl PinInit<Self, Error> { +/// fn probe<'bound>( +/// _pdev: &'bound pci::Device<Core<'_>>, +/// _id_info: &'bound Self::IdInfo, +/// ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound { /// Err(ENODEV) /// } /// } ///``` /// Drivers must implement this trait in order to get a PCI driver registered. Please refer to the /// `Adapter` documentation for an example. -pub trait Driver: Send { +pub trait Driver { /// The type holding information about each device id supported by the driver. // TODO: Use `associated_type_defaults` once stabilized: // @@ -300,6 +301,9 @@ pub trait Driver: Send { // ``` type IdInfo: 'static; + /// The type of the driver's bus device private data. + type Data<'bound>: Send + 'bound; + /// The table of device ids supported by the driver. const ID_TABLE: IdTable<Self::IdInfo>; @@ -307,7 +311,10 @@ pub trait Driver: Send { /// /// Called when a new pci device is added or discovered. Implementers should /// attempt to initialize the device here. - fn probe(dev: &Device<device::Core>, id_info: &Self::IdInfo) -> impl PinInit<Self, Error>; + fn probe<'bound>( + dev: &'bound Device<device::Core<'_>>, + id_info: &'bound Self::IdInfo, + ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound; /// PCI driver unbind. /// @@ -318,8 +325,8 @@ pub trait Driver: Send { /// `&Device<Core>` or `&Device<Bound>` reference. For instance, drivers may try to perform I/O /// operations to gracefully tear down the device. /// - /// Otherwise, release operations for driver resources should be performed in `Self::drop`. - fn unbind(dev: &Device<device::Core>, this: Pin<&Self>) { + /// Otherwise, release operations for driver resources should be performed in `Drop`. + fn unbind<'bound>(dev: &'bound Device<device::Core<'_>>, this: Pin<&Self::Data<'bound>>) { let _ = (dev, this); } } @@ -354,7 +361,7 @@ impl Device { /// /// ``` /// # use kernel::{device::Core, pci::{self, Vendor}, prelude::*}; - /// fn log_device_info(pdev: &pci::Device<Core>) -> Result { + /// fn log_device_info(pdev: &pci::Device<Core<'_>>) -> Result { /// // Get an instance of `Vendor`. /// let vendor = pdev.vendor_id(); /// dev_info!( @@ -445,7 +452,7 @@ impl Device { } } -impl Device<device::Core> { +impl<'a> Device<device::Core<'a>> { /// Enable memory resources for this device. pub fn enable_device_mem(&self) -> Result { // SAFETY: `self.as_raw` is guaranteed to be a pointer to a valid `struct pci_dev`. @@ -471,7 +478,7 @@ unsafe impl<Ctx: device::DeviceContext> device::AsBusDevice<Ctx> for Device<Ctx> kernel::impl_device_context_deref!(unsafe { Device }); kernel::impl_device_context_into_aref!(Device); -impl crate::dma::Device for Device<device::Core> {} +impl<'a> crate::dma::Device<'a> for Device<device::Core<'a>> {} // SAFETY: Instances of `Device` are always reference-counted. unsafe impl crate::sync::aref::AlwaysRefCounted for Device { @@ -523,3 +530,7 @@ unsafe impl Send for Device {} // SAFETY: `Device` can be shared among threads because all methods of `Device` // (i.e. `Device<Normal>) are thread safe. unsafe impl Sync for Device {} + +// SAFETY: Same as `Device<Normal>` -- the underlying `struct pci_dev` is the same; +// `Bound` is a zero-sized type-state marker that does not affect thread safety. +unsafe impl Sync for Device<device::Bound> {} diff --git a/rust/kernel/pci/id.rs b/rust/kernel/pci/id.rs index 50005d176561..dbaf301666e7 100644 --- a/rust/kernel/pci/id.rs +++ b/rust/kernel/pci/id.rs @@ -19,7 +19,7 @@ use crate::{ /// /// ``` /// # use kernel::{device::Core, pci::{self, Class}, prelude::*}; -/// fn probe_device(pdev: &pci::Device<Core>) -> Result { +/// fn probe_device(pdev: &pci::Device<Core<'_>>) -> Result { /// let pci_class = pdev.pci_class(); /// dev_info!( /// pdev, diff --git a/rust/kernel/pci/io.rs b/rust/kernel/pci/io.rs index ae78676c927f..0461e01aaa20 100644 --- a/rust/kernel/pci/io.rs +++ b/rust/kernel/pci/io.rs @@ -14,8 +14,7 @@ use crate::{ Mmio, MmioRaw, // }, - prelude::*, - sync::aref::ARef, // + prelude::*, // }; use core::{ marker::PhantomData, @@ -146,14 +145,18 @@ impl<'a, S: ConfigSpaceKind> IoKnownSize for ConfigSpace<'a, S> { /// /// `Bar` always holds an `IoRaw` instance that holds a valid pointer to the start of the I/O /// memory mapped PCI BAR and its size. -pub struct Bar<const SIZE: usize = 0> { - pdev: ARef<Device>, +pub struct Bar<'a, const SIZE: usize = 0> { + pdev: &'a Device<device::Bound>, io: MmioRaw<SIZE>, num: i32, } -impl<const SIZE: usize> Bar<SIZE> { - pub(super) fn new(pdev: &Device, num: u32, name: &CStr) -> Result<Self> { +impl<'a, const SIZE: usize> Bar<'a, SIZE> { + pub(super) fn new( + pdev: &'a Device<device::Bound>, + num: u32, + name: &'static CStr, + ) -> Result<Self> { let len = pdev.resource_len(num)?; if len == 0 { return Err(ENOMEM); @@ -196,11 +199,7 @@ impl<const SIZE: usize> Bar<SIZE> { } }; - Ok(Bar { - pdev: pdev.into(), - io, - num, - }) + Ok(Bar { pdev, io, num }) } /// # Safety @@ -219,11 +218,24 @@ impl<const SIZE: usize> Bar<SIZE> { fn release(&self) { // SAFETY: The safety requirements are guaranteed by the type invariant of `self.pdev`. - unsafe { Self::do_release(&self.pdev, self.io.addr(), self.num) }; + unsafe { Self::do_release(self.pdev, self.io.addr(), self.num) }; + } + + /// Consume the `Bar` and register it as a device-managed resource. + /// + /// The returned `Devres<Bar<'static, SIZE>>` can outlive the original lifetime `'a`. Access + /// to the BAR is revoked when the device is unbound. + pub fn into_devres(self) -> Result<Devres<Bar<'static, SIZE>>> { + // SAFETY: Casting to `'static` is sound because `Devres` guarantees the `Bar` does not + // actually outlive the device -- access is revoked and the resource is released when the + // device is unbound. + let bar: Bar<'static, SIZE> = unsafe { core::mem::transmute(self) }; + let pdev = bar.pdev; + Devres::new(pdev.as_ref(), bar) } } -impl Bar { +impl Bar<'_> { #[inline] pub(super) fn index_is_valid(index: u32) -> bool { // A `struct pci_dev` owns an array of resources with at most `PCI_NUM_RESOURCES` entries. @@ -231,13 +243,13 @@ impl Bar { } } -impl<const SIZE: usize> Drop for Bar<SIZE> { +impl<const SIZE: usize> Drop for Bar<'_, SIZE> { fn drop(&mut self) { self.release(); } } -impl<const SIZE: usize> Deref for Bar<SIZE> { +impl<const SIZE: usize> Deref for Bar<'_, SIZE> { type Target = Mmio<SIZE>; fn deref(&self) -> &Self::Target { @@ -252,17 +264,13 @@ impl Device<device::Bound> { pub fn iomap_region_sized<'a, const SIZE: usize>( &'a self, bar: u32, - name: &'a CStr, - ) -> impl PinInit<Devres<Bar<SIZE>>, Error> + 'a { - Devres::new(self.as_ref(), Bar::<SIZE>::new(self, bar, name)) + name: &'static CStr, + ) -> Result<Bar<'a, SIZE>> { + Bar::new(self, bar, name) } /// Maps an entire PCI BAR after performing a region-request on it. - pub fn iomap_region<'a>( - &'a self, - bar: u32, - name: &'a CStr, - ) -> impl PinInit<Devres<Bar>, Error> + 'a { + pub fn iomap_region<'a>(&'a self, bar: u32, name: &'static CStr) -> Result<Bar<'a>> { self.iomap_region_sized::<0>(bar, name) } diff --git a/rust/kernel/platform.rs b/rust/kernel/platform.rs index 8917d4ee499f..9b362e0495d3 100644 --- a/rust/kernel/platform.rs +++ b/rust/kernel/platform.rs @@ -45,18 +45,18 @@ pub struct Adapter<T: Driver>(T); // SAFETY: // - `bindings::platform_driver` is a C type declared as `repr(C)`. -// - `T` is the type of the driver's device private data. +// - `T::Data` is the type of the driver's device private data. // - `struct platform_driver` embeds a `struct device_driver`. // - `DEVICE_DRIVER_OFFSET` is the correct byte offset to the embedded `struct device_driver`. -unsafe impl<T: Driver + 'static> driver::DriverLayout for Adapter<T> { +unsafe impl<T: Driver> driver::DriverLayout for Adapter<T> { type DriverType = bindings::platform_driver; - type DriverData = T; + type DriverData<'bound> = T::Data<'bound>; const DEVICE_DRIVER_OFFSET: usize = core::mem::offset_of!(Self::DriverType, driver); } // SAFETY: A call to `unregister` for a given instance of `DriverType` is guaranteed to be valid if // a preceding call to `register` has been successful. -unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { +unsafe impl<T: Driver> driver::RegistrationOps for Adapter<T> { unsafe fn register( pdrv: &Opaque<Self::DriverType>, name: &'static CStr, @@ -82,7 +82,9 @@ unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { } // SAFETY: `pdrv` is guaranteed to be a valid `DriverType`. - to_result(unsafe { bindings::__platform_driver_register(pdrv.get(), module.0) }) + to_result(unsafe { + bindings::__platform_driver_register(pdrv.get(), module.0, name.as_char_ptr()) + }) } unsafe fn unregister(pdrv: &Opaque<Self::DriverType>) { @@ -91,13 +93,13 @@ unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { } } -impl<T: Driver + 'static> Adapter<T> { +impl<T: Driver> Adapter<T> { extern "C" fn probe_callback(pdev: *mut bindings::platform_device) -> kernel::ffi::c_int { // SAFETY: The platform bus only ever calls the probe callback with a valid pointer to a // `struct platform_device`. // // INVARIANT: `pdev` is valid for the duration of `probe_callback()`. - let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal>>() }; + let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal<'_>>>() }; let info = <Self as driver::Adapter>::id_info(pdev.as_ref()); from_result(|| { @@ -113,18 +115,18 @@ impl<T: Driver + 'static> Adapter<T> { // `struct platform_device`. // // INVARIANT: `pdev` is valid for the duration of `remove_callback()`. - let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal>>() }; + let pdev = unsafe { &*pdev.cast::<Device<device::CoreInternal<'_>>>() }; // SAFETY: `remove_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called - // and stored a `Pin<KBox<T>>`. - let data = unsafe { pdev.as_ref().drvdata_borrow::<T>() }; + // and stored a `Pin<KBox<T::Data<'_>>>`. + let data = unsafe { pdev.as_ref().drvdata_borrow::<T::Data<'_>>() }; T::unbind(pdev, data); } } -impl<T: Driver + 'static> driver::Adapter for Adapter<T> { +impl<T: Driver> driver::Adapter for Adapter<T> { type IdInfo = T::IdInfo; fn of_id_table() -> Option<of::IdTable<Self::IdInfo>> { @@ -192,18 +194,19 @@ macro_rules! module_platform_driver { /// /// impl platform::Driver for MyDriver { /// type IdInfo = (); +/// type Data<'bound> = Self; /// const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = Some(&OF_TABLE); /// const ACPI_ID_TABLE: Option<acpi::IdTable<Self::IdInfo>> = Some(&ACPI_TABLE); /// -/// fn probe( -/// _pdev: &platform::Device<Core>, -/// _id_info: Option<&Self::IdInfo>, -/// ) -> impl PinInit<Self, Error> { +/// fn probe<'bound>( +/// _pdev: &'bound platform::Device<Core<'_>>, +/// _id_info: Option<&'bound Self::IdInfo>, +/// ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound { /// Err(ENODEV) /// } /// } ///``` -pub trait Driver: Send { +pub trait Driver { /// The type holding driver private data about each device id supported by the driver. // TODO: Use associated_type_defaults once stabilized: // @@ -212,6 +215,9 @@ pub trait Driver: Send { // ``` type IdInfo: 'static; + /// The type of the driver's bus device private data. + type Data<'bound>: Send + 'bound; + /// The table of OF device ids supported by the driver. const OF_ID_TABLE: Option<of::IdTable<Self::IdInfo>> = None; @@ -222,10 +228,10 @@ pub trait Driver: Send { /// /// Called when a new platform device is added or discovered. /// Implementers should attempt to initialize the device here. - fn probe( - dev: &Device<device::Core>, - id_info: Option<&Self::IdInfo>, - ) -> impl PinInit<Self, Error>; + fn probe<'bound>( + dev: &'bound Device<device::Core<'_>>, + id_info: Option<&'bound Self::IdInfo>, + ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound; /// Platform driver unbind. /// @@ -236,8 +242,8 @@ pub trait Driver: Send { /// `&Device<Core>` or `&Device<Bound>` reference. For instance, drivers may try to perform I/O /// operations to gracefully tear down the device. /// - /// Otherwise, release operations for driver resources should be performed in `Self::drop`. - fn unbind(dev: &Device<device::Core>, this: Pin<&Self>) { + /// Otherwise, release operations for driver resources should be performed in `Drop`. + fn unbind<'bound>(dev: &'bound Device<device::Core<'_>>, this: Pin<&Self::Data<'bound>>) { let _ = (dev, this); } } @@ -509,7 +515,7 @@ impl Device<Bound> { kernel::impl_device_context_deref!(unsafe { Device }); kernel::impl_device_context_into_aref!(Device); -impl crate::dma::Device for Device<device::Core> {} +impl<'a> crate::dma::Device<'a> for Device<device::Core<'a>> {} // SAFETY: Instances of `Device` are always reference-counted. unsafe impl crate::sync::aref::AlwaysRefCounted for Device { @@ -561,3 +567,7 @@ unsafe impl Send for Device {} // SAFETY: `Device` can be shared among threads because all methods of `Device` // (i.e. `Device<Normal>) are thread safe. unsafe impl Sync for Device {} + +// SAFETY: Same as `Device<Normal>` -- the underlying `struct platform_device` is the same; +// `Bound` is a zero-sized type-state marker that does not affect thread safety. +unsafe impl Sync for Device<device::Bound> {} diff --git a/rust/kernel/prelude.rs b/rust/kernel/prelude.rs index 44edf72a4a24..ca396f1f78a6 100644 --- a/rust/kernel/prelude.rs +++ b/rust/kernel/prelude.rs @@ -22,6 +22,7 @@ pub use core::{ pin::Pin, // }; +#[doc(no_inline)] pub use ::ffi::{ c_char, c_int, @@ -47,6 +48,7 @@ pub use macros::{ vtable, // }; +#[doc(no_inline)] pub use pin_init::{ init, pin_data, @@ -58,6 +60,19 @@ pub use pin_init::{ Zeroable, // }; +#[doc(no_inline)] +pub use zerocopy::{ + FromBytes, + IntoBytes, // +}; + +#[doc(no_inline)] +pub use zerocopy_derive::{ + FromBytes, + IntoBytes, // +}; + +#[doc(no_inline)] pub use super::{ alloc::{ flags::*, @@ -70,9 +85,12 @@ pub use super::{ VVec, Vec, // }, - build_assert, - build_error, - const_assert, + build_assert::{ + build_assert, + build_error, + const_assert, + static_assert, // + }, current, dev_alert, dev_crit, @@ -96,7 +114,6 @@ pub use super::{ pr_info, pr_notice, pr_warn, - static_assert, str::CStrExt as _, try_init, try_pin_init, diff --git a/rust/kernel/ptr/projection.rs b/rust/kernel/ptr/projection.rs index 140ea8e21617..af72d3b0e2a3 100644 --- a/rust/kernel/ptr/projection.rs +++ b/rust/kernel/ptr/projection.rs @@ -26,14 +26,14 @@ impl From<OutOfBound> for Error { /// /// # Safety /// -/// The implementation of `index` and `get` (if [`Some`] is returned) must ensure that, if provided -/// input pointer `slice` and returned pointer `output`, then: +/// For a given input pointer `slice` and return value `output`, the implementation of `index`, +/// `build_index` and `get` (if [`Some`] is returned) must ensure that: /// - `output` has the same provenance as `slice`; /// - `output.byte_offset_from(slice)` is between 0 to /// `KnownSize::size(slice) - KnownSize::size(output)`. /// -/// This means that if the input pointer is valid, then pointer returned by `get` or `index` is -/// also valid. +/// This means that if the input pointer is valid, then the pointer returned by `get`, `index` +/// or `build_index` is also valid. #[diagnostic::on_unimplemented(message = "`{Self}` cannot be used to index `{T}`")] #[doc(hidden)] pub unsafe trait ProjectIndex<T: ?Sized>: Sized { @@ -42,10 +42,16 @@ pub unsafe trait ProjectIndex<T: ?Sized>: Sized { /// Returns an index-projected pointer, if in bounds. fn get(self, slice: *mut T) -> Option<*mut Self::Output>; + /// Returns an index-projected pointer; panic if out of bounds. + fn index(self, slice: *mut T) -> *mut Self::Output; + /// Returns an index-projected pointer; fail the build if it cannot be proved to be in bounds. #[inline(always)] - fn index(self, slice: *mut T) -> *mut Self::Output { - Self::get(self, slice).unwrap_or_else(|| build_error!()) + fn build_index(self, slice: *mut T) -> *mut Self::Output { + match Self::get(self, slice) { + Some(v) => v, + None => build_error!(), + } } } @@ -67,6 +73,11 @@ where fn index(self, slice: *mut [T; N]) -> *mut Self::Output { <I as ProjectIndex<[T]>>::index(self, slice) } + + #[inline(always)] + fn build_index(self, slice: *mut [T; N]) -> *mut Self::Output { + <I as ProjectIndex<[T]>>::build_index(self, slice) + } } // SAFETY: `get`-returned pointer has the same provenance as `slice` and the offset is checked to @@ -82,6 +93,16 @@ unsafe impl<T> ProjectIndex<[T]> for usize { Some(slice.cast::<T>().wrapping_add(self)) } } + + #[inline(always)] + fn index(self, slice: *mut [T]) -> *mut T { + // Leverage Rust built-in operators for bounds checking. + // SAFETY: All non-null and aligned pointers are valid for ZST read. + let zst_slice = + unsafe { core::slice::from_raw_parts::<()>(core::ptr::dangling(), slice.len()) }; + let () = zst_slice[self]; + slice.cast::<T>().wrapping_add(self) + } } // SAFETY: `get`-returned pointer has the same provenance as `slice` and the offset is checked to @@ -100,6 +121,18 @@ unsafe impl<T> ProjectIndex<[T]> for core::ops::Range<usize> { new_len, )) } + + #[inline(always)] + fn index(self, slice: *mut [T]) -> *mut [T] { + // Leverage Rust built-in operators for bounds checking. + // SAFETY: All non-null and aligned pointers are valid for ZST read. + let zst_slice = + unsafe { core::slice::from_raw_parts::<()>(core::ptr::dangling(), slice.len()) }; + _ = zst_slice[self.clone()]; + + // SAFETY: Bounds checked. + unsafe { self.get(slice).unwrap_unchecked() } + } } // SAFETY: Safety requirement guaranteed by the forwarded impl. @@ -110,6 +143,11 @@ unsafe impl<T> ProjectIndex<[T]> for core::ops::RangeTo<usize> { fn get(self, slice: *mut [T]) -> Option<*mut [T]> { (0..self.end).get(slice) } + + #[inline(always)] + fn index(self, slice: *mut [T]) -> *mut [T] { + (0..self.end).index(slice) + } } // SAFETY: Safety requirement guaranteed by the forwarded impl. @@ -120,6 +158,11 @@ unsafe impl<T> ProjectIndex<[T]> for core::ops::RangeFrom<usize> { fn get(self, slice: *mut [T]) -> Option<*mut [T]> { (self.start..slice.len()).get(slice) } + + #[inline(always)] + fn index(self, slice: *mut [T]) -> *mut [T] { + (self.start..slice.len()).index(slice) + } } // SAFETY: `get` returned the pointer as is, so it always has the same provenance and offset of 0. @@ -130,6 +173,11 @@ unsafe impl<T> ProjectIndex<[T]> for core::ops::RangeFull { fn get(self, slice: *mut [T]) -> Option<*mut [T]> { Some(slice) } + + #[inline(always)] + fn index(self, slice: *mut [T]) -> *mut [T] { + slice + } } /// A helper trait to perform field projection. @@ -207,10 +255,13 @@ unsafe impl<T: Deref> ProjectField<true> for T { /// If a mutable pointer is needed, the macro input can be prefixed with the `mut` keyword, i.e. /// `kernel::ptr::project!(mut ptr, projection)`. By default, a const pointer is created. /// -/// `ptr::project!` macro can perform both fallible indexing and build-time checked indexing. -/// `[index]` form performs build-time bounds checking; if compiler fails to prove `[index]` is in -/// bounds, compilation will fail. `[index]?` can be used to perform runtime bounds checking; -/// `OutOfBound` error is raised via `?` if the index is out of bounds. +/// The `ptr::project!` macro can perform both fallible indexing and build-time checked indexing. +/// The syntax is of the form `[<flavor>: index]` where `flavor` indicates the way of handling +/// index out-of-bounds errors. +/// - `try` will raise an [`OutOfBound`] error (which is convertible to [`ERANGE`]). +/// - `build` will use the [`build_assert!`] mechanism to have the compiler validate the index is +/// in bounds. +/// - `panic` will cause a Rust [`panic!`] if the index goes out of bounds. /// /// # Examples /// @@ -228,17 +279,21 @@ unsafe impl<T: Deref> ProjectField<true> for T { /// } /// ``` /// -/// Index projections are performed with `[index]`: +/// Index projections are performed with `[<flavor>: index]`, where `flavor` is `try`, `build` or +/// `panic`: /// /// ``` /// fn proj(ptr: *const [u8; 32]) -> Result { -/// let field_ptr: *const u8 = kernel::ptr::project!(ptr, [1]); +/// let field_ptr: *const u8 = kernel::ptr::project!(ptr, [build: 1]); /// // The following invocation, if uncommented, would fail the build. /// // -/// // kernel::ptr::project!(ptr, [128]); +/// // kernel::ptr::project!(ptr, [build: 128]); /// /// // This will raise an `OutOfBound` error (which is convertible to `ERANGE`). -/// kernel::ptr::project!(ptr, [128]?); +/// kernel::ptr::project!(ptr, [try: 128]); +/// +/// // This will panic at runtime if executed. +/// kernel::ptr::project!(ptr, [panic: 128]); /// Ok(()) /// } /// ``` @@ -248,7 +303,7 @@ unsafe impl<T: Deref> ProjectField<true> for T { /// ``` /// let ptr: *const [u8; 32] = core::ptr::dangling(); /// let field_ptr: Result<*const u8> = (|| -> Result<_> { -/// Ok(kernel::ptr::project!(ptr, [128]?)) +/// Ok(kernel::ptr::project!(ptr, [try: 128])) /// })(); /// assert!(field_ptr.is_err()); /// ``` @@ -257,7 +312,7 @@ unsafe impl<T: Deref> ProjectField<true> for T { /// /// ``` /// let ptr: *mut [(u8, u16); 32] = core::ptr::dangling_mut(); -/// let field_ptr: *mut u16 = kernel::ptr::project!(mut ptr, [1].1); +/// let field_ptr: *mut u16 = kernel::ptr::project!(mut ptr, [build: 1].1); /// ``` #[macro_export] macro_rules! project_pointer { @@ -280,16 +335,22 @@ macro_rules! project_pointer { $crate::ptr::project!(@gen $ptr, $($rest)*) }; // Fallible index projection. - (@gen $ptr:ident, [$index:expr]? $($rest:tt)*) => { + (@gen $ptr:ident, [try: $index:expr] $($rest:tt)*) => { let $ptr = $crate::ptr::projection::ProjectIndex::get($index, $ptr) .ok_or($crate::ptr::projection::OutOfBound)?; $crate::ptr::project!(@gen $ptr, $($rest)*) }; - // Build-time checked index projection. - (@gen $ptr:ident, [$index:expr] $($rest:tt)*) => { + // Panicking index projection. + (@gen $ptr:ident, [panic: $index:expr] $($rest:tt)*) => { let $ptr = $crate::ptr::projection::ProjectIndex::index($index, $ptr); $crate::ptr::project!(@gen $ptr, $($rest)*) }; + // Build-time checked index projection. + (@gen $ptr:ident, [build: $index:expr] $($rest:tt)*) => { + let $ptr = $crate::ptr::projection::ProjectIndex::build_index($index, $ptr); + $crate::ptr::project!(@gen $ptr, $($rest)*) + }; + (mut $ptr:expr, $($proj:tt)*) => {{ let ptr: *mut _ = $ptr; $crate::ptr::project!(@gen ptr, $($proj)*); diff --git a/rust/kernel/str.rs b/rust/kernel/str.rs index 8311d91549e1..b3caa9a1c898 100644 --- a/rust/kernel/str.rs +++ b/rust/kernel/str.rs @@ -3,14 +3,27 @@ //! String representations. use crate::{ - alloc::{flags::*, AllocError, KVec}, - error::{to_result, Result}, - fmt::{self, Write}, - prelude::*, + alloc::{ + AllocError, + KVec, // + }, + error::{ + to_result, + Result, // + }, + fmt::{ + self, + Write, // + }, + prelude::*, // }; use core::{ marker::PhantomData, - ops::{Deref, DerefMut, Index}, + ops::{ + Deref, + DerefMut, + Index, // + }, // }; pub use crate::prelude::CStr; @@ -415,6 +428,7 @@ macro_rules! c_str { }}; } +#[cfg(CONFIG_RUST_STR_KUNIT_TEST)] #[kunit_tests(rust_kernel_str)] mod tests { use super::*; diff --git a/rust/kernel/sync/arc.rs b/rust/kernel/sync/arc.rs index 18d6c0d62ce0..5ac4961b7cd2 100644 --- a/rust/kernel/sync/arc.rs +++ b/rust/kernel/sync/arc.rs @@ -712,6 +712,7 @@ impl<T> InPlaceInit<T> for UniqueArc<T> { impl<T> InPlaceWrite<T> for UniqueArc<MaybeUninit<T>> { type Initialized = UniqueArc<T>; + #[inline] fn write_init<E>(mut self, init: impl Init<T, E>) -> Result<Self::Initialized, E> { let slot = self.as_mut_ptr(); // SAFETY: When init errors/panics, slot will get deallocated but not dropped, @@ -721,6 +722,7 @@ impl<T> InPlaceWrite<T> for UniqueArc<MaybeUninit<T>> { Ok(unsafe { self.assume_init() }) } + #[inline] fn write_pin_init<E>(mut self, init: impl PinInit<T, E>) -> Result<Pin<Self::Initialized>, E> { let slot = self.as_mut_ptr(); // SAFETY: When init errors/panics, slot will get deallocated but not dropped, @@ -758,6 +760,14 @@ impl<T> UniqueArc<T> { } } +impl<T: ?Sized> UniqueArc<T> { + /// Return a raw pointer to the data in this [`UniqueArc`]. + #[inline] + pub fn as_ptr(this: &Self) -> *const T { + Arc::as_ptr(&this.inner) + } +} + impl<T> UniqueArc<MaybeUninit<T>> { /// Converts a `UniqueArc<MaybeUninit<T>>` into a `UniqueArc<T>` by writing a value into it. pub fn write(mut self, value: T) -> UniqueArc<T> { @@ -782,6 +792,7 @@ impl<T> UniqueArc<MaybeUninit<T>> { } /// Initialize `self` using the given initializer. + #[inline] pub fn init_with<E>(mut self, init: impl Init<T, E>) -> core::result::Result<UniqueArc<T>, E> { // SAFETY: The supplied pointer is valid for initialization. match unsafe { init.__init(self.as_mut_ptr()) } { @@ -792,6 +803,7 @@ impl<T> UniqueArc<MaybeUninit<T>> { } /// Pin-initialize `self` using the given pin-initializer. + #[inline] pub fn pin_init_with<E>( mut self, init: impl PinInit<T, E>, diff --git a/rust/kernel/sync/aref.rs b/rust/kernel/sync/aref.rs index 9989f56d0605..b721b2e00b98 100644 --- a/rust/kernel/sync/aref.rs +++ b/rust/kernel/sync/aref.rs @@ -17,7 +17,12 @@ //! [`Arc`]: crate::sync::Arc //! [`Arc<T>`]: crate::sync::Arc -use core::{marker::PhantomData, mem::ManuallyDrop, ops::Deref, ptr::NonNull}; +use core::{ + marker::PhantomData, + mem::ManuallyDrop, + ops::Deref, + ptr::NonNull, // +}; /// Types that are _always_ reference counted. /// diff --git a/rust/kernel/sync/atomic/internal.rs b/rust/kernel/sync/atomic/internal.rs index ad810c2172ec..9c8a7a203abd 100644 --- a/rust/kernel/sync/atomic/internal.rs +++ b/rust/kernel/sync/atomic/internal.rs @@ -4,8 +4,11 @@ //! //! Provides 1:1 mapping to the C atomic operations. -use crate::bindings; -use crate::macros::paste; +use crate::{ + bindings, + build_assert::static_assert, + macros::paste, // +}; use core::cell::UnsafeCell; use ffi::c_void; @@ -46,7 +49,7 @@ pub trait AtomicImpl: Sized + Copy + private::Sealed { // In the future when a CONFIG_ARCH_SUPPORTS_ATOMIC_RMW=n architecture plans to support Rust, the // load/store helpers that guarantee atomicity against RmW operations (usually via a lock) need to // be added. -crate::static_assert!( +static_assert!( cfg!(CONFIG_ARCH_SUPPORTS_ATOMIC_RMW), "The current implementation of atomic i8/i16/ptr relies on the architecure being \ ARCH_SUPPORTS_ATOMIC_RMW" diff --git a/rust/kernel/sync/atomic/predefine.rs b/rust/kernel/sync/atomic/predefine.rs index 1d53834fcb12..3d63f40791fa 100644 --- a/rust/kernel/sync/atomic/predefine.rs +++ b/rust/kernel/sync/atomic/predefine.rs @@ -2,9 +2,7 @@ //! Pre-defined atomic types -use crate::static_assert; -use core::mem::{align_of, size_of}; -use ffi::c_void; +use crate::prelude::*; // Ensure size and alignment requirements are checked. static_assert!(size_of::<bool>() == size_of::<i8>()); @@ -154,9 +152,8 @@ unsafe impl super::AtomicAdd<usize> for usize { } } -use crate::macros::kunit_tests; - -#[kunit_tests(rust_atomics)] +#[cfg(CONFIG_RUST_ATOMICS_KUNIT_TEST)] +#[macros::kunit_tests(rust_atomics)] mod tests { use super::super::*; diff --git a/rust/kernel/sync/completion.rs b/rust/kernel/sync/completion.rs index c50012a940a3..35ff049ff078 100644 --- a/rust/kernel/sync/completion.rs +++ b/rust/kernel/sync/completion.rs @@ -94,6 +94,7 @@ impl Completion { /// /// This method wakes up all tasks waiting on this completion; after this operation the /// completion is permanently done, i.e. signals all current and future waiters. + #[inline] pub fn complete_all(&self) { // SAFETY: `self.as_raw()` is a pointer to a valid `struct completion`. unsafe { bindings::complete_all(self.as_raw()) }; @@ -105,6 +106,7 @@ impl Completion { /// timeout. /// /// See also [`Completion::complete_all`]. + #[inline] pub fn wait_for_completion(&self) { // SAFETY: `self.as_raw()` is a pointer to a valid `struct completion`. unsafe { bindings::wait_for_completion(self.as_raw()) }; diff --git a/rust/kernel/sync/lock/global.rs b/rust/kernel/sync/lock/global.rs index aecbdc34738f..ec2dd84316fc 100644 --- a/rust/kernel/sync/lock/global.rs +++ b/rust/kernel/sync/lock/global.rs @@ -85,6 +85,7 @@ impl<B: GlobalLockBackend> GlobalLock<B> { } /// Try to lock this global lock. + #[must_use = "if unused, the lock will be immediately unlocked"] #[inline] pub fn try_lock(&'static self) -> Option<GlobalGuard<B>> { Some(GlobalGuard { @@ -96,6 +97,7 @@ impl<B: GlobalLockBackend> GlobalLock<B> { /// A guard for a [`GlobalLock`]. /// /// See [`global_lock!`] for examples. +#[must_use = "the lock unlocks immediately when the guard is unused"] pub struct GlobalGuard<B: GlobalLockBackend> { inner: Guard<'static, B::Item, B::Backend>, } diff --git a/rust/kernel/sync/locked_by.rs b/rust/kernel/sync/locked_by.rs index 61f100a45b35..fb4a1430b3b4 100644 --- a/rust/kernel/sync/locked_by.rs +++ b/rust/kernel/sync/locked_by.rs @@ -3,7 +3,7 @@ //! A wrapper for data protected by a lock that does not wrap it. use super::{lock::Backend, lock::Lock}; -use crate::build_assert; +use crate::build_assert::build_assert; use core::{cell::UnsafeCell, mem::size_of, ptr}; /// Allows access to some data to be serialised by a lock that does not wrap it. diff --git a/rust/kernel/sync/refcount.rs b/rust/kernel/sync/refcount.rs index 6c7ae8b05a0b..23a5d201f343 100644 --- a/rust/kernel/sync/refcount.rs +++ b/rust/kernel/sync/refcount.rs @@ -4,9 +4,11 @@ //! //! C header: [`include/linux/refcount.h`](srctree/include/linux/refcount.h) -use crate::build_assert; -use crate::sync::atomic::Atomic; -use crate::types::Opaque; +use crate::{ + build_assert::build_assert, + sync::atomic::Atomic, + types::Opaque, // +}; /// Atomic reference counter. /// diff --git a/rust/kernel/types.rs b/rust/kernel/types.rs index 4329d3c2c2e5..ac316fd7b538 100644 --- a/rust/kernel/types.rs +++ b/rust/kernel/types.rs @@ -11,6 +11,10 @@ use core::{ }; use pin_init::{PinInit, Wrapper, Zeroable}; +#[doc(hidden)] +pub mod for_lt; +pub use for_lt::ForLt; + /// Used to transfer ownership to and from foreign (non-Rust) languages. /// /// Ownership is transferred from Rust to a foreign language by calling [`Self::into_foreign`] and @@ -27,10 +31,14 @@ pub unsafe trait ForeignOwnable: Sized { const FOREIGN_ALIGN: usize; /// Type used to immutably borrow a value that is currently foreign-owned. - type Borrowed<'a>; + type Borrowed<'a> + where + Self: 'a; /// Type used to mutably borrow a value that is currently foreign-owned. - type BorrowedMut<'a>; + type BorrowedMut<'a> + where + Self: 'a; /// Converts a Rust-owned object to a foreign-owned one. /// diff --git a/rust/kernel/types/for_lt.rs b/rust/kernel/types/for_lt.rs new file mode 100644 index 000000000000..d44323c28e8d --- /dev/null +++ b/rust/kernel/types/for_lt.rs @@ -0,0 +1,122 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +//! Provide implementation and test of the `ForLt` trait and macro. +//! +//! This module is hidden and user should just use `ForLt!` directly. + +use core::marker::PhantomData; + +/// Representation of types generic over a lifetime. +/// +/// The type must be covariant over the generic lifetime, i.e. the lifetime parameter +/// can be soundly shortened. +/// +/// The lifetime involved must be covariant. +/// +/// # Macro +/// +/// It is not recommended to implement this trait directly. `ForLt!` macro is provided to obtain a +/// type that implements this trait. +/// +/// The full syntax is +/// +/// ``` +/// # use kernel::types::ForLt; +/// # fn expect_lt<F: ForLt>() {} +/// # struct TypeThatUse<'a>(&'a ()); +/// # expect_lt::< +/// ForLt!(for<'a> TypeThatUse<'a>) +/// # >(); +/// ``` +/// +/// which gives a type so that `<ForLt!(for<'a> TypeThatUse<'a>) as ForLt>::Of<'b>` +/// is `TypeThatUse<'b>`. +/// +/// You may also use a short-hand syntax which works similar to lifetime elision. +/// The macro also accepts types that do not involve a lifetime at all. +/// +/// ``` +/// # use kernel::types::ForLt; +/// # fn expect_lt<F: ForLt>() {} +/// # struct TypeThatUse<'a>(&'a ()); +/// # expect_lt::< +/// ForLt!(TypeThatUse<'_>) // Equivalent to `ForLt!(for<'a> TypeThatUse<'a>)`. +/// # >(); +/// # expect_lt::< +/// ForLt!(&u32) // Equivalent to `ForLt!(for<'a> &'a u32)`. +/// # >(); +/// # expect_lt::< +/// ForLt!(u32) // Equivalent to `ForLt!(for<'a> u32)`. +/// # >(); +/// ``` +/// +/// The macro will attempt to prove that the type is indeed covariant over the lifetime supplied. +/// When it cannot be syntactically proven, it will emit checks to ask the Rust compiler to prove +/// it. +/// +/// ```ignore,compile_fail +/// # use kernel::types::ForLt; +/// # fn expect_lt<F: ForLt>() {} +/// # expect_lt::< +/// ForLt!(fn(&u32)) // Contravariant, will fail compilation. +/// # >(); +/// ``` +/// +/// There is a limitation if the type refers to generic parameters; if the macro cannot prove the +/// covariance syntactically, the emitted checks will fail the compilation as it needs to refer to +/// the generic parameter but is in a separate item. +/// +/// ``` +/// # use kernel::types::ForLt; +/// fn expect_lt<F: ForLt>() {} +/// # #[allow(clippy::unnecessary_safety_comment, reason = "false positive")] +/// fn generic_fn<T: 'static>() { +/// // Syntactically proven by the macro +/// expect_lt::<ForLt!(&T)>(); +/// // Syntactically proven by the macro +/// expect_lt::<ForLt!(&KBox<T>)>(); +/// // Cannot be syntactically proven, need to check covariance of `KBox` +/// // expect_lt::<ForLt!(&KBox<&T>)>(); +/// } +/// ``` +/// +/// # Safety +/// +/// `Self::Of<'a>` must be covariant over the lifetime `'a`. +pub unsafe trait ForLt { + /// The type parameterized by the lifetime. + type Of<'a>: 'a; + + /// Cast a reference to a shorter lifetime. + #[inline(always)] + fn cast_ref<'r, 'short: 'r, 'long: 'short>(long: &'r Self::Of<'long>) -> &'r Self::Of<'short> { + // SAFETY: This is sound as this trait guarantees covariance. + unsafe { core::mem::transmute(long) } + } +} +pub use macros::ForLt; + +/// This is intended to be an "unsafe-to-refer-to" type. +/// +/// Must only be used by the `ForLt!` macro. +/// +/// `T` is the magic `dyn for<'a> WithLt<'a, TypeThatUse<'a>>` generated by macro. +/// +/// `WF` is a type that the macro can use to assert some specific type is well-formed. +/// +/// `N` is to provide the macro a place to emit arbitrary items, in case it needs to prove +/// additional properties. +#[doc(hidden)] +pub struct UnsafeForLtImpl<T: ?Sized, WF, const N: usize>(PhantomData<(WF, T)>); + +// This is a helper trait for implementation `ForLt` to be able to use HRTB. +#[doc(hidden)] +pub trait WithLt<'a> { + type Of: 'a; +} + +// SAFETY: In `ForLt!` macro, a covariance proof is generated when naming `UnsafeForLtImpl` +// and it will fail to evaluate if the type is not covariant. +unsafe impl<T: ?Sized + for<'a> WithLt<'a>, WF> ForLt for UnsafeForLtImpl<T, WF, 0> { + type Of<'a> = <T as WithLt<'a>>::Of; +} diff --git a/rust/kernel/usb.rs b/rust/kernel/usb.rs index 9c17a672cd27..7aff0c82d0af 100644 --- a/rust/kernel/usb.rs +++ b/rust/kernel/usb.rs @@ -36,18 +36,18 @@ pub struct Adapter<T: Driver>(T); // SAFETY: // - `bindings::usb_driver` is a C type declared as `repr(C)`. -// - `T` is the type of the driver's device private data. +// - `T::Data` is the type of the driver's device private data. // - `struct usb_driver` embeds a `struct device_driver`. // - `DEVICE_DRIVER_OFFSET` is the correct byte offset to the embedded `struct device_driver`. -unsafe impl<T: Driver + 'static> driver::DriverLayout for Adapter<T> { +unsafe impl<T: Driver> driver::DriverLayout for Adapter<T> { type DriverType = bindings::usb_driver; - type DriverData = T; + type DriverData<'bound> = T::Data<'bound>; const DEVICE_DRIVER_OFFSET: usize = core::mem::offset_of!(Self::DriverType, driver); } // SAFETY: A call to `unregister` for a given instance of `DriverType` is guaranteed to be valid if // a preceding call to `register` has been successful. -unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { +unsafe impl<T: Driver> driver::RegistrationOps for Adapter<T> { unsafe fn register( udrv: &Opaque<Self::DriverType>, name: &'static CStr, @@ -73,7 +73,7 @@ unsafe impl<T: Driver + 'static> driver::RegistrationOps for Adapter<T> { } } -impl<T: Driver + 'static> Adapter<T> { +impl<T: Driver> Adapter<T> { extern "C" fn probe_callback( intf: *mut bindings::usb_interface, id: *const bindings::usb_device_id, @@ -82,7 +82,7 @@ impl<T: Driver + 'static> Adapter<T> { // `struct usb_interface` and `struct usb_device_id`. // // INVARIANT: `intf` is valid for the duration of `probe_callback()`. - let intf = unsafe { &*intf.cast::<Interface<device::CoreInternal>>() }; + let intf = unsafe { &*intf.cast::<Interface<device::CoreInternal<'_>>>() }; from_result(|| { // SAFETY: `DeviceId` is a `#[repr(transparent)]` wrapper of `struct usb_device_id` and @@ -92,7 +92,7 @@ impl<T: Driver + 'static> Adapter<T> { let info = T::ID_TABLE.info(id.index()); let data = T::probe(intf, id, info); - let dev: &device::Device<device::CoreInternal> = intf.as_ref(); + let dev: &device::Device<device::CoreInternal<'_>> = intf.as_ref(); dev.set_drvdata(data)?; Ok(0) }) @@ -103,14 +103,14 @@ impl<T: Driver + 'static> Adapter<T> { // `struct usb_interface`. // // INVARIANT: `intf` is valid for the duration of `disconnect_callback()`. - let intf = unsafe { &*intf.cast::<Interface<device::CoreInternal>>() }; + let intf = unsafe { &*intf.cast::<Interface<device::CoreInternal<'_>>>() }; - let dev: &device::Device<device::CoreInternal> = intf.as_ref(); + let dev: &device::Device<device::CoreInternal<'_>> = intf.as_ref(); // SAFETY: `disconnect_callback` is only ever called after a successful call to // `probe_callback`, hence it's guaranteed that `Device::set_drvdata()` has been called - // and stored a `Pin<KBox<T>>`. - let data = unsafe { dev.drvdata_borrow::<T>() }; + // and stored a `Pin<KBox<T::Data<'_>>>`. + let data = unsafe { dev.drvdata_borrow::<T::Data<'_>>() }; T::disconnect(intf, data); } @@ -287,23 +287,31 @@ macro_rules! usb_device_table { /// /// impl usb::Driver for MyDriver { /// type IdInfo = (); +/// type Data<'bound> = Self; /// const ID_TABLE: usb::IdTable<Self::IdInfo> = &USB_TABLE; /// -/// fn probe( -/// _interface: &usb::Interface<Core>, +/// fn probe<'bound>( +/// _interface: &'bound usb::Interface<Core<'_>>, /// _id: &usb::DeviceId, -/// _info: &Self::IdInfo, -/// ) -> impl PinInit<Self, Error> { +/// _info: &'bound Self::IdInfo, +/// ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound { /// Err(ENODEV) /// } /// -/// fn disconnect(_interface: &usb::Interface<Core>, _data: Pin<&Self>) {} +/// fn disconnect<'bound>( +/// _interface: &'bound usb::Interface<Core<'_>>, +/// _data: Pin<&Self::Data<'bound>>, +/// ) { +/// } /// } ///``` pub trait Driver { /// The type holding information about each one of the device ids supported by the driver. type IdInfo: 'static; + /// The type of the driver's bus device private data. + type Data<'bound>: Send + 'bound; + /// The table of device ids supported by the driver. const ID_TABLE: IdTable<Self::IdInfo>; @@ -311,16 +319,19 @@ pub trait Driver { /// /// Called when a new USB interface is bound to this driver. /// Implementers should attempt to initialize the interface here. - fn probe( - interface: &Interface<device::Core>, + fn probe<'bound>( + interface: &'bound Interface<device::Core<'_>>, id: &DeviceId, - id_info: &Self::IdInfo, - ) -> impl PinInit<Self, Error>; + id_info: &'bound Self::IdInfo, + ) -> impl PinInit<Self::Data<'bound>, Error> + 'bound; /// USB driver disconnect. /// /// Called when the USB interface is about to be unbound from this driver. - fn disconnect(interface: &Interface<device::Core>, data: Pin<&Self>); + fn disconnect<'bound>( + interface: &'bound Interface<device::Core<'_>>, + data: Pin<&Self::Data<'bound>>, + ); } /// A USB interface. @@ -464,6 +475,10 @@ unsafe impl Send for Device {} // allow any mutation through a shared reference. unsafe impl Sync for Device {} +// SAFETY: Same as `Device<Normal>` -- the underlying `struct usb_device` is the same; +// `Bound` is a zero-sized type-state marker that does not affect thread safety. +unsafe impl Sync for Device<device::Bound> {} + /// Declares a kernel module that exposes a single USB driver. /// /// # Examples diff --git a/rust/kernel/xarray.rs b/rust/kernel/xarray.rs index 46e5f43223fe..987c9c0c2198 100644 --- a/rust/kernel/xarray.rs +++ b/rust/kernel/xarray.rs @@ -5,10 +5,16 @@ //! C header: [`include/linux/xarray.h`](srctree/include/linux/xarray.h) use crate::{ - alloc, bindings, build_assert, + alloc, + bindings, + build_assert::build_assert, error::{Error, Result}, ffi::c_void, - types::{ForeignOwnable, NotThreadSafe, Opaque}, + types::{ + ForeignOwnable, + NotThreadSafe, + Opaque, // + }, // }; use core::{iter, marker::PhantomData, pin::Pin, ptr::NonNull}; use pin_init::{pin_data, pin_init, pinned_drop, PinInit}; |
