diff --git a/arch/x86/kvm/mmu/mmu.c b/arch/x86/kvm/mmu/mmu.c
index b2c1c4eb600708e7633fee42926c5c4697940bfb..32c6d4b33d032704d38060c14ffddba9d801c3f2 100644
--- a/arch/x86/kvm/mmu/mmu.c
+++ b/arch/x86/kvm/mmu/mmu.c
@@ -2353,7 +2353,7 @@ static bool __kvm_mmu_prepare_zap_page(struct kvm *kvm,
 		 * treats invalid shadow pages as being obsolete.
 		 */
 		if (!is_obsolete_sp(kvm, sp))
-			kvm_reload_remote_mmus(kvm);
+			kvm_make_all_cpus_request(kvm, KVM_REQ_MMU_RELOAD);
 	}
 
 	if (sp->lpage_disallowed)
@@ -5639,11 +5639,11 @@ static void kvm_mmu_zap_all_fast(struct kvm *kvm)
 	 */
 	kvm->arch.mmu_valid_gen = kvm->arch.mmu_valid_gen ? 0 : 1;
 
-	/* In order to ensure all threads see this change when
-	 * handling the MMU reload signal, this must happen in the
-	 * same critical section as kvm_reload_remote_mmus, and
-	 * before kvm_zap_obsolete_pages as kvm_zap_obsolete_pages
-	 * could drop the MMU lock and yield.
+	/*
+	 * In order to ensure all vCPUs drop their soon-to-be invalid roots,
+	 * invalidating TDP MMU roots must be done while holding mmu_lock for
+	 * write and in the same critical section as making the reload request,
+	 * e.g. before kvm_zap_obsolete_pages() could drop mmu_lock and yield.
 	 */
 	if (is_tdp_mmu_enabled(kvm))
 		kvm_tdp_mmu_invalidate_all_roots(kvm);
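
For context, and not part of this patch: KVM_REQ_MMU_RELOAD is consumed on
the vCPU side before the next guest entry. A minimal sketch of the x86
handling, assuming the request-processing loop in vcpu_enter_guest() in
arch/x86/kvm/x86.c as of this series:

	/*
	 * Consumer side (sketch): kvm_check_request() tests and clears the
	 * pending request bit, then kvm_mmu_unload() drops the vCPU's root
	 * pages, so the next page fault allocates roots in the new, valid
	 * mmu_valid_gen generation.
	 */
	if (kvm_check_request(KVM_REQ_MMU_RELOAD, vcpu))
		kvm_mmu_unload(vcpu);

Raising the request in the same mmu_lock critical section as the root
invalidation is what guarantees a vCPU cannot acknowledge the reload and
then fault in a soon-to-be-invalid root.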
@@ -5656,7 +5656,7 @@ static void kvm_mmu_zap_all_fast(struct kvm *kvm)
 	 * Note: we need to do this under the protection of mmu_lock,
 	 * otherwise, vcpu would purge shadow page but miss tlb flush.
 	 */
-	kvm_reload_remote_mmus(kvm);
+	kvm_make_all_cpus_request(kvm, KVM_REQ_MMU_RELOAD);
 
 	kvm_zap_obsolete_pages(kvm);
 
diff --git a/include/linux/kvm_host.h b/include/linux/kvm_host.h
index f11039944c08ffd7d5bfe6675f95c8b9f2d3dd86..0aeb47cffd4373a156f5140b6e476a3a1fa3634e 100644
--- a/include/linux/kvm_host.h
+++ b/include/linux/kvm_host.h
@@ -1325,7 +1325,6 @@ int kvm_vcpu_yield_to(struct kvm_vcpu *target);
 void kvm_vcpu_on_spin(struct kvm_vcpu *vcpu, bool usermode_vcpu_not_eligible);
 
 void kvm_flush_remote_tlbs(struct kvm *kvm);
-void kvm_reload_remote_mmus(struct kvm *kvm);
 
 #ifdef KVM_ARCH_NR_OBJS_PER_MEMORY_CACHE
 int kvm_mmu_topup_memory_cache(struct kvm_mmu_memory_cache *mc, int min);
diff --git a/virt/kvm/kvm_main.c b/virt/kvm/kvm_main.c
index cdf1fa3c60ae2da3b666828376ed256004419950..64eb99444688f6f8e1567e105d8f0f5ee65df2e4 100644
--- a/virt/kvm/kvm_main.c
+++ b/virt/kvm/kvm_main.c
@@ -354,11 +354,6 @@ void kvm_flush_remote_tlbs(struct kvm *kvm)
 EXPORT_SYMBOL_GPL(kvm_flush_remote_tlbs);
 #endif
 
-void kvm_reload_remote_mmus(struct kvm *kvm)
-{
-	kvm_make_all_cpus_request(kvm, KVM_REQ_MMU_RELOAD);
-}
-
 #ifdef KVM_ARCH_NR_OBJS_PER_MEMORY_CACHE
 static inline void *mmu_memory_cache_alloc_obj(struct kvm_mmu_memory_cache *mc,
 					       gfp_t gfp_flags)
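
Also for context: kvm_make_all_cpus_request() is the generic remote-request
primitive the dropped wrapper was built on. A simplified sketch of its
effect, assuming the kvm_main.c helpers of this era; the real implementation
batches the kicks through a per-CPU cpumask rather than kicking each vCPU
individually:

	struct kvm_vcpu *vcpu;
	unsigned long i;

	kvm_for_each_vcpu(i, vcpu, kvm) {
		/* Set the request bit for this vCPU... */
		kvm_make_request(req, vcpu);
		/* ...and kick it out of guest mode so the bit is seen. */
		kvm_vcpu_kick(vcpu);
	}

Open-coding the request at the two x86 call sites (the only users touched
by this patch) makes the intent explicit at each site and removes a
one-line wrapper from generic code.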