diff --git a/rust/kernel/drm/gpuvm.rs b/rust/kernel/drm/gpuvm.rs
index ed52253fa5c18d5476ceebde46a2e78c1a98e4eb..1376fefe24ff0625cfb4c4e36b26a25a3c22f46e 100644
--- a/rust/kernel/drm/gpuvm.rs
+++ b/rust/kernel/drm/gpuvm.rs
@@ -50,6 +50,7 @@ pub trait DriverGpuVm: Sized {
     fn step_remap(
         self: &mut UpdatingGpuVm<'_, Self>,
         op: &mut OpReMap<Self>,
+        vm_bo: &GpuVmBo<Self>,
         ctx: &mut Self::StepContext,
     ) -> Result;
 }
@@ -207,17 +208,6 @@ impl<T: DriverGpuVm> GpuVa<T> {
     pub fn range(&self) -> u64 {
         self.gpuva.va.range
     }
-    pub fn vm_bo(&self) -> ARef<GpuVmBo<T>> {
-        // SAFETY: Container invariant is guaranteed for ops structs created for our types.
-        let p =
-            unsafe { crate::container_of!(self.gpuva.vm_bo, GpuVmBo<T>, bo) as *mut GpuVmBo<T> };
-
-        // SAFETY: We incref and wrap in an ARef, so the reference count is consistent
-        unsafe {
-            bindings::drm_gpuvm_bo_get(self.gpuva.vm_bo);
-            ARef::from_raw(NonNull::new_unchecked(p))
-        }
-    }
     pub fn offset(&self) -> u64 {
         self.gpuva.gem.offset
     }
@@ -349,10 +339,27 @@ pub(super) unsafe extern "C" fn step_remap_callback<T: DriverGpuVm>(
     // guaranteed to outlive this function.
     let ctx = unsafe { &mut *(_priv as *mut StepContext<'_, T>) };
 
-    from_result(|| {
-        UpdatingGpuVm(ctx.gpuvm).step_remap(remap, ctx.ctx)?;
-        Ok(0)
-    })
+    let p_vm_bo = remap.unmap().va().unwrap().gpuva.vm_bo;
+
+    let res = {
+        // SAFETY: vm_bo pointer must be valid and non-null by the step_remap invariants.
+        // Since we grab a ref, this reference's lifetime is until the decref.
+        let vm_bo_ref = unsafe {
+            bindings::drm_gpuvm_bo_get(p_vm_bo);
+            &*(crate::container_of!(p_vm_bo, GpuVmBo<T>, bo) as *mut GpuVmBo<T>)
+        };
+
+        from_result(|| {
+            UpdatingGpuVm(ctx.gpuvm).step_remap(remap, vm_bo_ref, ctx.ctx)?;
+            Ok(0)
+        })
+    };
+
+    // SAFETY: We incremented the refcount above, and the Rust reference we took is
+    // no longer in scope.
+    unsafe { bindings::drm_gpuvm_bo_put(p_vm_bo) };
+
+    res
 }
 pub(super) unsafe extern "C" fn step_unmap_callback<T: DriverGpuVm>(
     op: *mut bindings::drm_gpuva_op,