From 4656b43c8e623d8b40ff253e4722623f1e342f0e Mon Sep 17 00:00:00 2001 From: Lyude Paul <lyude@redhat.com> Date: Tue, 16 Jul 2024 17:31:37 -0400 Subject: [PATCH] squash! rust: hrtimer: Add Timer::cancel() * I totally forgot that we actually need to both add a callback for handling cleaning up a timer that didn't execute as a result of being cancelled --- rust/kernel/hrtimer.rs | 48 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/rust/kernel/hrtimer.rs b/rust/kernel/hrtimer.rs index c93dcc224330a..b0d7c25e8ccbd 100644 --- a/rust/kernel/hrtimer.rs +++ b/rust/kernel/hrtimer.rs @@ -55,7 +55,7 @@ //! //! C header: [`include/linux/hrtimer.h`](srctree/include/linux/hrtimer.h) -use core::{marker::PhantomData, pin::Pin}; +use core::{marker::PhantomData, pin::Pin, ptr::{self, NonNull}}; use crate::{init::PinInit, prelude::*, sync::Arc, types::Opaque}; @@ -110,11 +110,28 @@ impl<T: TimerCallback> Timer<T> { /// Cancel a timer and wait for its callback to finish executing /// /// Returns `true` if the timer was active. - pub fn cancel(&self) -> bool { - // SAFETY: By struct invariant `self.timer` was initialized by - // `hrtimer_init` so by C API contract it is safe to call - // `hrtimer_cancel`. - unsafe { bindings::hrtimer_cancel(self.timer.get()) != 0 } + pub fn cancel(&self) -> bool + where + T: HasTimer<T> + { + let timer_ptr = self.timer.get(); + + // SAFETY: By struct invariant `self.timer` was initialized by `hrtimer_init` so by C API + // contract it is safe to call `hrtimer_cancel`. + let cancelled = unsafe { bindings::hrtimer_cancel(timer_ptr) != 0 }; + + // If the callback cancelled before execution, cleanup its resources + if cancelled { + // SAFETY: `timer` is valid for as long as our interface is exposed + let data_ptr = unsafe { + NonNull::new_unchecked(T::timer_container_of(timer_ptr.cast())) + }; + + // SAFETY: We're in the `hrtimer` crate + unsafe { <T::Receiver as RawTimer>::cleanup(data_ptr) }; + } + + cancelled } /// Return the current time from the base timer for this timer @@ -153,6 +170,16 @@ impl<T> PinnedDrop for Timer<T> { pub trait RawTimer: Sync { /// Schedule the timer after `expires` time units fn schedule(self, expires: u64); + + /// Cleanup resources for the timer if it was cancelled before executing + /// + /// # Safety + /// + /// This should only be called by the [`kernel::hrtimer`] crate + unsafe fn cleanup<T>(data: ptr::NonNull<T>) + where + Self: Sized + RawTimerCallback, + T: TimerCallback<Receiver = Self>; } /// Implemented by structs that contain timer nodes. @@ -283,6 +310,15 @@ where ); } } + + unsafe fn cleanup<D>(data: ptr::NonNull<D>) + where + Self: Sized + RawTimerCallback, + D: TimerCallback<Receiver = Self> + { + // SAFETY: The caller guarantees that `self` points to a valid instance of `Arc<T>` + drop(unsafe { Arc::from_raw(data.as_ptr()) }) + } } impl<T> kernel::hrtimer::RawTimerCallback for Arc<T> -- GitLab