diff --git a/mm/ksm.c b/mm/ksm.c
index 0805221e1d4c05fc7829da0b1514b31444609395..dd02780c387f02b3176d088fdcc1bb9d6db487e8 100644
--- a/mm/ksm.c
+++ b/mm/ksm.c
@@ -39,6 +39,7 @@
 #include <linux/freezer.h>
 #include <linux/oom.h>
 #include <linux/numa.h>
+#include <linux/pagewalk.h>
 
 #include <asm/tlbflush.h>
 #include "internal.h"
@@ -419,6 +420,39 @@ static inline bool ksm_test_exit(struct mm_struct *mm)
 	return atomic_read(&mm->mm_users) == 0;
 }
 
+static int break_ksm_pmd_entry(pmd_t *pmd, unsigned long addr, unsigned long next,
+			struct mm_walk *walk)
+{
+	struct page *page = NULL;
+	spinlock_t *ptl;
+	pte_t *pte;
+	int ret;
+
+	if (pmd_leaf(*pmd) || !pmd_present(*pmd))
+		return 0;
+
+	pte = pte_offset_map_lock(walk->mm, pmd, addr, &ptl);
+	if (pte_present(*pte)) {
+		page = vm_normal_page(walk->vma, addr, *pte);
+	} else if (!pte_none(*pte)) {
+		swp_entry_t entry = pte_to_swp_entry(*pte);
+
+		/*
+		 * As KSM pages remain KSM pages until freed, no need to wait
+		 * here for migration to end.
+		 */
+		if (is_migration_entry(entry))
+			page = pfn_swap_entry_to_page(entry);
+	}
+	ret = page && PageKsm(page);
+	pte_unmap_unlock(pte, ptl);
+	return ret;
+}
+
+static const struct mm_walk_ops break_ksm_ops = {
+	.pmd_entry = break_ksm_pmd_entry,
+};
+
 /*
  * We use break_ksm to break COW on a ksm page by triggering unsharing,
  * such that the ksm page will get replaced by an exclusive anonymous page.
@@ -434,21 +468,16 @@ static inline bool ksm_test_exit(struct mm_struct *mm)
  */
 static int break_ksm(struct vm_area_struct *vma, unsigned long addr)
 {
-	struct page *page;
 	vm_fault_t ret = 0;
 
 	do {
-		bool ksm_page = false;
+		int ksm_page;
 
 		cond_resched();
-		page = follow_page(vma, addr,
-				FOLL_GET | FOLL_MIGRATION | FOLL_REMOTE);
-		if (IS_ERR_OR_NULL(page))
-			break;
-		if (PageKsm(page))
-			ksm_page = true;
-		put_page(page);
-
+		ksm_page = walk_page_range_vma(vma, addr, addr + 1,
+					       &break_ksm_ops, NULL);
+		if (WARN_ON_ONCE(ksm_page < 0))
+			return ksm_page;
 		if (!ksm_page)
 			return 0;
 		ret = handle_mm_fault(vma, addr,