summaryrefslogtreecommitdiff
path: root/rust/kernel/device.rs
diff options
context:
space:
mode:
Diffstat (limited to 'rust/kernel/device.rs')
-rw-r--r--rust/kernel/device.rs84
1 files changed, 81 insertions, 3 deletions
diff --git a/rust/kernel/device.rs b/rust/kernel/device.rs
index 106aa57a6385..1a307be953c2 100644
--- a/rust/kernel/device.rs
+++ b/rust/kernel/device.rs
@@ -10,13 +10,16 @@ use crate::{
sync::aref::ARef,
types::{ForeignOwnable, Opaque},
};
-use core::{marker::PhantomData, ptr};
+use core::{any::TypeId, marker::PhantomData, ptr};
#[cfg(CONFIG_PRINTK)]
use crate::c_str;
pub mod property;
+// Assert that we can `read()` / `write()` a `TypeId` instance from / into `struct driver_type`.
+static_assert!(core::mem::size_of::<bindings::driver_type>() >= core::mem::size_of::<TypeId>());
+
/// The core representation of a device in the kernel's driver model.
///
/// This structure represents the Rust abstraction for a C `struct device`. A [`Device`] can either
@@ -198,12 +201,29 @@ impl Device {
}
impl Device<CoreInternal> {
+ fn set_type_id<T: 'static>(&self) {
+ // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`.
+ let private = unsafe { (*self.as_raw()).p };
+
+ // SAFETY: For a bound device (implied by the `CoreInternal` device context), `private` is
+ // guaranteed to be a valid pointer to a `struct device_private`.
+ let driver_type = unsafe { &raw mut (*private).driver_type };
+
+ // SAFETY: `driver_type` is valid for (unaligned) writes of a `TypeId`.
+ unsafe {
+ driver_type
+ .cast::<TypeId>()
+ .write_unaligned(TypeId::of::<T>())
+ };
+ }
+
/// Store a pointer to the bound driver's private data.
pub fn set_drvdata<T: 'static>(&self, data: impl PinInit<T, Error>) -> Result {
let data = KBox::pin_init(data, GFP_KERNEL)?;
// SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`.
unsafe { bindings::dev_set_drvdata(self.as_raw(), data.into_foreign().cast()) };
+ self.set_type_id::<T>();
Ok(())
}
@@ -219,6 +239,9 @@ impl Device<CoreInternal> {
// SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`.
let ptr = unsafe { bindings::dev_get_drvdata(self.as_raw()) };
+ // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`.
+ unsafe { bindings::dev_set_drvdata(self.as_raw(), core::ptr::null_mut()) };
+
// SAFETY:
// - By the safety requirements of this function, `ptr` comes from a previous call to
// `into_foreign()`.
@@ -235,7 +258,23 @@ impl Device<CoreInternal> {
/// [`Device::drvdata_obtain`].
/// - The type `T` must match the type of the `ForeignOwnable` previously stored by
/// [`Device::set_drvdata`].
- pub unsafe fn drvdata_borrow<T: ForeignOwnable>(&self) -> T::Borrowed<'_> {
+ pub unsafe fn drvdata_borrow<T: 'static>(&self) -> Pin<&T> {
+ // SAFETY: `drvdata_unchecked()` has the exact same safety requirements as the ones
+ // required by this method.
+ unsafe { self.drvdata_unchecked() }
+ }
+}
+
+impl Device<Bound> {
+ /// Borrow the driver's private data bound to this [`Device`].
+ ///
+ /// # Safety
+ ///
+ /// - Must only be called after a preceding call to [`Device::set_drvdata`] and before
+ /// [`Device::drvdata_obtain`].
+ /// - The type `T` must match the type of the `ForeignOwnable` previously stored by
+ /// [`Device::set_drvdata`].
+ unsafe fn drvdata_unchecked<T: 'static>(&self) -> Pin<&T> {
// SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`.
let ptr = unsafe { bindings::dev_get_drvdata(self.as_raw()) };
@@ -244,7 +283,46 @@ impl Device<CoreInternal> {
// `into_foreign()`.
// - `dev_get_drvdata()` guarantees to return the same pointer given to `dev_set_drvdata()`
// in `into_foreign()`.
- unsafe { T::borrow(ptr.cast()) }
+ unsafe { Pin::<KBox<T>>::borrow(ptr.cast()) }
+ }
+
+ fn match_type_id<T: 'static>(&self) -> Result {
+ // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`.
+ let private = unsafe { (*self.as_raw()).p };
+
+ // SAFETY: For a bound device, `private` is guaranteed to be a valid pointer to a
+ // `struct device_private`.
+ let driver_type = unsafe { &raw mut (*private).driver_type };
+
+ // SAFETY:
+ // - `driver_type` is valid for (unaligned) reads of a `TypeId`.
+ // - A bound device guarantees that `driver_type` contains a valid `TypeId` value.
+ let type_id = unsafe { driver_type.cast::<TypeId>().read_unaligned() };
+
+ if type_id != TypeId::of::<T>() {
+ return Err(EINVAL);
+ }
+
+ Ok(())
+ }
+
+ /// Access a driver's private data.
+ ///
+ /// Returns a pinned reference to the driver's private data or [`EINVAL`] if it doesn't match
+ /// the asserted type `T`.
+ pub fn drvdata<T: 'static>(&self) -> Result<Pin<&T>> {
+ // SAFETY: By the type invariants, `self.as_raw()` is a valid pointer to a `struct device`.
+ if unsafe { bindings::dev_get_drvdata(self.as_raw()) }.is_null() {
+ return Err(ENOENT);
+ }
+
+ self.match_type_id::<T>()?;
+
+ // SAFETY:
+ // - The above check of `dev_get_drvdata()` guarantees that we are called after
+ // `set_drvdata()` and before `drvdata_obtain()`.
+ // - We've just checked that the type of the driver's private data is in fact `T`.
+ Ok(unsafe { self.drvdata_unchecked() })
}
}