diff --git a/include/linux/khugepaged.h b/include/linux/khugepaged.h
index 082d1d2a5216977262d3a1817b30a9131efe873a..bc45ea1efbf7973362c6cb99b433ecba4f95612a 100644
--- a/include/linux/khugepaged.h
+++ b/include/linux/khugepaged.h
@@ -15,6 +15,14 @@ extern int __khugepaged_enter(struct mm_struct *mm);
 extern void __khugepaged_exit(struct mm_struct *mm);
 extern int khugepaged_enter_vma_merge(struct vm_area_struct *vma,
 				      unsigned long vm_flags);
+#ifdef CONFIG_SHMEM
+extern void collapse_pte_mapped_thp(struct mm_struct *mm, unsigned long addr);
+#else
+static inline void collapse_pte_mapped_thp(struct mm_struct *mm,
+					   unsigned long addr)
+{
+}
+#endif
 
 #define khugepaged_enabled()					       \
 	(transparent_hugepage_flags &				       \
@@ -73,6 +81,10 @@ static inline int khugepaged_enter_vma_merge(struct vm_area_struct *vma,
 {
 	return 0;
 }
+static inline void collapse_pte_mapped_thp(struct mm_struct *mm,
+					   unsigned long addr)
+{
+}
 #endif /* CONFIG_TRANSPARENT_HUGEPAGE */
 
 #endif /* _LINUX_KHUGEPAGED_H */
diff --git a/mm/khugepaged.c b/mm/khugepaged.c
index e89430ec5267f941416b1c3818ff10c57624e612..0a1b4b484ac5b4a0eed5e5148f04849d9e09607b 100644
--- a/mm/khugepaged.c
+++ b/mm/khugepaged.c
@@ -77,6 +77,8 @@ static __read_mostly DEFINE_HASHTABLE(mm_slots_hash, MM_SLOTS_HASH_BITS);
 
 static struct kmem_cache *mm_slot_cache __read_mostly;
 
+#define MAX_PTE_MAPPED_THP 8
+
 /**
  * struct mm_slot - hash lookup from mm to mm_slot
  * @hash: hash collision list
@@ -87,6 +89,10 @@ struct mm_slot {
 	struct hlist_node hash;
 	struct list_head mm_node;
 	struct mm_struct *mm;
+
+	/* pte-mapped THP in this mm */
+	int nr_pte_mapped_thp;
+	unsigned long pte_mapped_thp[MAX_PTE_MAPPED_THP];
 };
 
 /**
@@ -1254,6 +1260,159 @@ static void collect_mm_slot(struct mm_slot *mm_slot)
 }
 
 #if defined(CONFIG_SHMEM) && defined(CONFIG_TRANSPARENT_HUGE_PAGECACHE)
+/*
+ * Notify khugepaged that given addr of the mm is pte-mapped THP. Then
+ * khugepaged should try to collapse the page table.
+ */
+static int khugepaged_add_pte_mapped_thp(struct mm_struct *mm,
+					 unsigned long addr)
+{
+	struct mm_slot *mm_slot;
+
+	VM_BUG_ON(addr & ~HPAGE_PMD_MASK);
+
+	spin_lock(&khugepaged_mm_lock);
+	mm_slot = get_mm_slot(mm);
+	if (likely(mm_slot && mm_slot->nr_pte_mapped_thp < MAX_PTE_MAPPED_THP))
+		mm_slot->pte_mapped_thp[mm_slot->nr_pte_mapped_thp++] = addr;
+	spin_unlock(&khugepaged_mm_lock);
+	return 0;
+}
+
+/**
+ * Try to collapse a pte-mapped THP for mm at address haddr.
+ *
+ * This function checks whether all the PTEs in the PMD are pointing to the
+ * right THP. If so, retract the page table so the THP can refault in with
+ * as pmd-mapped.
+ */
+void collapse_pte_mapped_thp(struct mm_struct *mm, unsigned long addr)
+{
+	unsigned long haddr = addr & HPAGE_PMD_MASK;
+	struct vm_area_struct *vma = find_vma(mm, haddr);
+	struct page *hpage = NULL;
+	pte_t *start_pte, *pte;
+	pmd_t *pmd, _pmd;
+	spinlock_t *ptl;
+	int count = 0;
+	int i;
+
+	if (!vma || !vma->vm_file ||
+	    vma->vm_start > haddr || vma->vm_end < haddr + HPAGE_PMD_SIZE)
+		return;
+
+	/*
+	 * This vm_flags may not have VM_HUGEPAGE if the page was not
+	 * collapsed by this mm. But we can still collapse if the page is
+	 * the valid THP. Add extra VM_HUGEPAGE so hugepage_vma_check()
+	 * will not fail the vma for missing VM_HUGEPAGE
+	 */
+	if (!hugepage_vma_check(vma, vma->vm_flags | VM_HUGEPAGE))
+		return;
+
+	pmd = mm_find_pmd(mm, haddr);
+	if (!pmd)
+		return;
+
+	start_pte = pte_offset_map_lock(mm, pmd, haddr, &ptl);
+
+	/* step 1: check all mapped PTEs are to the right huge page */
+	for (i = 0, addr = haddr, pte = start_pte;
+	     i < HPAGE_PMD_NR; i++, addr += PAGE_SIZE, pte++) {
+		struct page *page;
+
+		/* empty pte, skip */
+		if (pte_none(*pte))
+			continue;
+
+		/* page swapped out, abort */
+		if (!pte_present(*pte))
+			goto abort;
+
+		page = vm_normal_page(vma, addr, *pte);
+
+		if (!page || !PageCompound(page))
+			goto abort;
+
+		if (!hpage) {
+			hpage = compound_head(page);
+			/*
+			 * The mapping of the THP should not change.
+			 *
+			 * Note that uprobe, debugger, or MAP_PRIVATE may
+			 * change the page table, but the new page will
+			 * not pass PageCompound() check.
+			 */
+			if (WARN_ON(hpage->mapping != vma->vm_file->f_mapping))
+				goto abort;
+		}
+
+		/*
+		 * Confirm the page maps to the correct subpage.
+		 *
+		 * Note that uprobe, debugger, or MAP_PRIVATE may change
+		 * the page table, but the new page will not pass
+		 * PageCompound() check.
+		 */
+		if (WARN_ON(hpage + i != page))
+			goto abort;
+		count++;
+	}
+
+	/* step 2: adjust rmap */
+	for (i = 0, addr = haddr, pte = start_pte;
+	     i < HPAGE_PMD_NR; i++, addr += PAGE_SIZE, pte++) {
+		struct page *page;
+
+		if (pte_none(*pte))
+			continue;
+		page = vm_normal_page(vma, addr, *pte);
+		page_remove_rmap(page, false);
+	}
+
+	pte_unmap_unlock(start_pte, ptl);
+
+	/* step 3: set proper refcount and mm_counters. */
+	if (hpage) {
+		page_ref_sub(hpage, count);
+		add_mm_counter(vma->vm_mm, mm_counter_file(hpage), -count);
+	}
+
+	/* step 4: collapse pmd */
+	ptl = pmd_lock(vma->vm_mm, pmd);
+	_pmd = pmdp_collapse_flush(vma, addr, pmd);
+	spin_unlock(ptl);
+	mm_dec_nr_ptes(mm);
+	pte_free(mm, pmd_pgtable(_pmd));
+	return;
+
+abort:
+	pte_unmap_unlock(start_pte, ptl);
+}
+
+static int khugepaged_collapse_pte_mapped_thps(struct mm_slot *mm_slot)
+{
+	struct mm_struct *mm = mm_slot->mm;
+	int i;
+
+	if (likely(mm_slot->nr_pte_mapped_thp == 0))
+		return 0;
+
+	if (!down_write_trylock(&mm->mmap_sem))
+		return -EBUSY;
+
+	if (unlikely(khugepaged_test_exit(mm)))
+		goto out;
+
+	for (i = 0; i < mm_slot->nr_pte_mapped_thp; i++)
+		collapse_pte_mapped_thp(mm, mm_slot->pte_mapped_thp[i]);
+
+out:
+	mm_slot->nr_pte_mapped_thp = 0;
+	up_write(&mm->mmap_sem);
+	return 0;
+}
+
 static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
 {
 	struct vm_area_struct *vma;
@@ -1262,7 +1421,22 @@ static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
 
 	i_mmap_lock_write(mapping);
 	vma_interval_tree_foreach(vma, &mapping->i_mmap, pgoff, pgoff) {
-		/* probably overkill */
+		/*
+		 * Check vma->anon_vma to exclude MAP_PRIVATE mappings that
+		 * got written to. These VMAs are likely not worth investing
+		 * down_write(mmap_sem) as PMD-mapping is likely to be split
+		 * later.
+		 *
+		 * Not that vma->anon_vma check is racy: it can be set up after
+		 * the check but before we took mmap_sem by the fault path.
+		 * But page lock would prevent establishing any new ptes of the
+		 * page, so we are safe.
+		 *
+		 * An alternative would be drop the check, but check that page
+		 * table is clear before calling pmdp_collapse_flush() under
+		 * ptl. It has higher chance to recover THP for the VMA, but
+		 * has higher cost too.
+		 */
 		if (vma->anon_vma)
 			continue;
 		addr = vma->vm_start + ((pgoff - vma->vm_pgoff) << PAGE_SHIFT);
@@ -1275,9 +1449,10 @@ static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
 			continue;
 		/*
 		 * We need exclusive mmap_sem to retract page table.
-		 * If trylock fails we would end up with pte-mapped THP after
-		 * re-fault. Not ideal, but it's more important to not disturb
-		 * the system too much.
+		 *
+		 * We use trylock due to lock inversion: we need to acquire
+		 * mmap_sem while holding page lock. Fault path does it in
+		 * reverse order. Trylock is a way to avoid deadlock.
 		 */
 		if (down_write_trylock(&vma->vm_mm->mmap_sem)) {
 			spinlock_t *ptl = pmd_lock(vma->vm_mm, pmd);
@@ -1287,6 +1462,9 @@ static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
 			up_write(&vma->vm_mm->mmap_sem);
 			mm_dec_nr_ptes(vma->vm_mm);
 			pte_free(vma->vm_mm, pmd_pgtable(_pmd));
+		} else {
+			/* Try again later */
+			khugepaged_add_pte_mapped_thp(vma->vm_mm, addr);
 		}
 	}
 	i_mmap_unlock_write(mapping);
@@ -1709,6 +1887,11 @@ static void khugepaged_scan_file(struct mm_struct *mm,
 {
 	BUILD_BUG();
 }
+
+static int khugepaged_collapse_pte_mapped_thps(struct mm_slot *mm_slot)
+{
+	return 0;
+}
 #endif
 
 static unsigned int khugepaged_scan_mm_slot(unsigned int pages,
@@ -1733,6 +1916,7 @@ static unsigned int khugepaged_scan_mm_slot(unsigned int pages,
 		khugepaged_scan.mm_slot = mm_slot;
 	}
 	spin_unlock(&khugepaged_mm_lock);
+	khugepaged_collapse_pte_mapped_thps(mm_slot);
 
 	mm = mm_slot->mm;
 	/*