From 4656b43c8e623d8b40ff253e4722623f1e342f0e Mon Sep 17 00:00:00 2001
From: Lyude Paul <lyude@redhat.com>
Date: Tue, 16 Jul 2024 17:31:37 -0400
Subject: [PATCH] squash! rust: hrtimer: Add Timer::cancel()

* I totally forgot that we actually need to both add a callback for
  handling cleaning up a timer that didn't execute as a result of being
  cancelled
---
 rust/kernel/hrtimer.rs | 48 ++++++++++++++++++++++++++++++++++++------
 1 file changed, 42 insertions(+), 6 deletions(-)

diff --git a/rust/kernel/hrtimer.rs b/rust/kernel/hrtimer.rs
index c93dcc224330a..b0d7c25e8ccbd 100644
--- a/rust/kernel/hrtimer.rs
+++ b/rust/kernel/hrtimer.rs
@@ -55,7 +55,7 @@
 //!
 //! C header: [`include/linux/hrtimer.h`](srctree/include/linux/hrtimer.h)
 
-use core::{marker::PhantomData, pin::Pin};
+use core::{marker::PhantomData, pin::Pin, ptr::{self, NonNull}};
 
 use crate::{init::PinInit, prelude::*, sync::Arc, types::Opaque};
 
@@ -110,11 +110,28 @@ impl<T: TimerCallback> Timer<T> {
     /// Cancel a timer and wait for its callback to finish executing
     ///
     /// Returns `true` if the timer was active.
-    pub fn cancel(&self) -> bool {
-        // SAFETY: By struct invariant `self.timer` was initialized by
-        // `hrtimer_init` so by C API contract it is safe to call
-        // `hrtimer_cancel`.
-        unsafe { bindings::hrtimer_cancel(self.timer.get()) != 0 }
+    pub fn cancel(&self) -> bool
+    where
+        T: HasTimer<T>
+    {
+        let timer_ptr = self.timer.get();
+
+        // SAFETY: By struct invariant `self.timer` was initialized by `hrtimer_init` so by C API
+        // contract it is safe to call `hrtimer_cancel`.
+        let cancelled = unsafe { bindings::hrtimer_cancel(timer_ptr) != 0 };
+
+        // If the callback cancelled before execution, cleanup its resources
+        if cancelled {
+            // SAFETY: `timer` is valid for as long as our interface is exposed
+            let data_ptr = unsafe {
+                NonNull::new_unchecked(T::timer_container_of(timer_ptr.cast()))
+            };
+
+            // SAFETY: We're in the `hrtimer` crate
+            unsafe { <T::Receiver as RawTimer>::cleanup(data_ptr) };
+        }
+
+        cancelled
     }
 
     /// Return the current time from the base timer for this timer
@@ -153,6 +170,16 @@ impl<T> PinnedDrop for Timer<T> {
 pub trait RawTimer: Sync {
     /// Schedule the timer after `expires` time units
     fn schedule(self, expires: u64);
+
+    /// Cleanup resources for the timer if it was cancelled before executing
+    ///
+    /// # Safety
+    ///
+    /// This should only be called by the [`kernel::hrtimer`] crate
+    unsafe fn cleanup<T>(data: ptr::NonNull<T>)
+    where
+        Self: Sized + RawTimerCallback,
+        T: TimerCallback<Receiver = Self>;
 }
 
 /// Implemented by structs that contain timer nodes.
@@ -283,6 +310,15 @@ where
             );
         }
     }
+
+    unsafe fn cleanup<D>(data: ptr::NonNull<D>)
+    where
+        Self: Sized + RawTimerCallback,
+        D: TimerCallback<Receiver = Self>
+    {
+        // SAFETY: The caller guarantees that `self` points to a valid instance of `Arc<T>`
+        drop(unsafe { Arc::from_raw(data.as_ptr()) })
+    }
 }
 
 impl<T> kernel::hrtimer::RawTimerCallback for Arc<T>
-- 
GitLab