diff --git a/drivers/gpu/drm/xe/xe_vm.c b/drivers/gpu/drm/xe/xe_vm.c
index 62d1ef8867a84351ae7444d63113d8867dfbb0c5..3d4c8f342e215ed39263ba5c4c01079072dfcbbd 100644
--- a/drivers/gpu/drm/xe/xe_vm.c
+++ b/drivers/gpu/drm/xe/xe_vm.c
@@ -1577,6 +1577,16 @@ void xe_vm_close_and_put(struct xe_vm *vm)
 		xe->usm.num_vm_in_fault_mode--;
 	else if (!(vm->flags & XE_VM_FLAG_MIGRATION))
 		xe->usm.num_vm_in_non_fault_mode--;
+
+	if (vm->usm.asid) {
+		void *lookup;
+
+		xe_assert(xe, xe->info.has_asid);
+		xe_assert(xe, !(vm->flags & XE_VM_FLAG_MIGRATION));
+
+		lookup = xa_erase(&xe->usm.asid_to_vm, vm->usm.asid);
+		xe_assert(xe, lookup == vm);
+	}
 	mutex_unlock(&xe->usm.lock);
 
 	for_each_tile(tile, xe, id)
@@ -1592,24 +1602,15 @@ static void vm_destroy_work_func(struct work_struct *w)
 	struct xe_device *xe = vm->xe;
 	struct xe_tile *tile;
 	u8 id;
-	void *lookup;
 
 	/* xe_vm_close_and_put was not called? */
 	xe_assert(xe, !vm->size);
 
 	mutex_destroy(&vm->snap_mutex);
 
-	if (!(vm->flags & XE_VM_FLAG_MIGRATION)) {
+	if (!(vm->flags & XE_VM_FLAG_MIGRATION))
 		xe_device_mem_access_put(xe);
 
-		if (xe->info.has_asid && vm->usm.asid) {
-			mutex_lock(&xe->usm.lock);
-			lookup = xa_erase(&xe->usm.asid_to_vm, vm->usm.asid);
-			xe_assert(xe, lookup == vm);
-			mutex_unlock(&xe->usm.lock);
-		}
-	}
-
 	for_each_tile(tile, xe, id)
 		XE_WARN_ON(vm->pt_root[id]);