diff --git a/include/linux/khugepaged.h b/include/linux/khugepaged.h
index 384f034ae947ff427eae88c834de1b096c9f3f11..70162d707caf0c3713d03a801a8c5c54b39c64ab 100644
--- a/include/linux/khugepaged.h
+++ b/include/linux/khugepaged.h
@@ -16,11 +16,13 @@ extern void khugepaged_enter_vma(struct vm_area_struct *vma,
 				 unsigned long vm_flags);
 extern void khugepaged_min_free_kbytes_update(void);
 #ifdef CONFIG_SHMEM
-extern void collapse_pte_mapped_thp(struct mm_struct *mm, unsigned long addr);
+extern int collapse_pte_mapped_thp(struct mm_struct *mm, unsigned long addr,
+				   bool install_pmd);
 #else
-static inline void collapse_pte_mapped_thp(struct mm_struct *mm,
-					   unsigned long addr)
+static inline int collapse_pte_mapped_thp(struct mm_struct *mm,
+					  unsigned long addr, bool install_pmd)
 {
+	return 0;
 }
 #endif
 
@@ -46,9 +48,10 @@ static inline void khugepaged_enter_vma(struct vm_area_struct *vma,
 					unsigned long vm_flags)
 {
 }
-static inline void collapse_pte_mapped_thp(struct mm_struct *mm,
-					   unsigned long addr)
+static inline int collapse_pte_mapped_thp(struct mm_struct *mm,
+					  unsigned long addr, bool install_pmd)
 {
+	return 0;
 }
 
 static inline void khugepaged_min_free_kbytes_update(void)
diff --git a/include/trace/events/huge_memory.h b/include/trace/events/huge_memory.h
index fbbb25494d6039aaabebbddf705d67cfdb06a7ab..df33453b70fcf2a645c91b7cbb8a470a15df0183 100644
--- a/include/trace/events/huge_memory.h
+++ b/include/trace/events/huge_memory.h
@@ -11,6 +11,7 @@
 	EM( SCAN_FAIL,			"failed")			\
 	EM( SCAN_SUCCEED,		"succeeded")			\
 	EM( SCAN_PMD_NULL,		"pmd_null")			\
+	EM( SCAN_PMD_NONE,		"pmd_none")			\
 	EM( SCAN_PMD_MAPPED,		"page_pmd_mapped")		\
 	EM( SCAN_EXCEED_NONE_PTE,	"exceed_none_pte")		\
 	EM( SCAN_EXCEED_SWAP_PTE,	"exceed_swap_pte")		\
diff --git a/kernel/events/uprobes.c b/kernel/events/uprobes.c
index e0a9b945e7bc007acc15396a4d187e9d33f560bf..d9e357b7e17c91cfc461bb06b0c3f89a7b649135 100644
--- a/kernel/events/uprobes.c
+++ b/kernel/events/uprobes.c
@@ -555,7 +555,7 @@ int uprobe_write_opcode(struct arch_uprobe *auprobe, struct mm_struct *mm,
 
 	/* try collapse pmd for compound page */
 	if (!ret && orig_page_huge)
-		collapse_pte_mapped_thp(mm, vaddr);
+		collapse_pte_mapped_thp(mm, vaddr, false);
 
 	return ret;
 }
diff --git a/mm/khugepaged.c b/mm/khugepaged.c
index b1e3f83c4eb2d34288d032133ee61f6e0522ff48..3bd6e2a741631bd738bd3c32a8081ed0764e493f 100644
--- a/mm/khugepaged.c
+++ b/mm/khugepaged.c
@@ -29,6 +29,7 @@ enum scan_result {
 	SCAN_FAIL,
 	SCAN_SUCCEED,
 	SCAN_PMD_NULL,
+	SCAN_PMD_NONE,
 	SCAN_PMD_MAPPED,
 	SCAN_EXCEED_NONE_PTE,
 	SCAN_EXCEED_SWAP_PTE,
@@ -821,6 +822,7 @@ static bool hpage_collapse_alloc_page(struct page **hpage, gfp_t gfp, int node)
  */
 
 static int hugepage_vma_revalidate(struct mm_struct *mm, unsigned long address,
+				   bool expect_anon,
 				   struct vm_area_struct **vmap,
 				   struct collapse_control *cc)
 {
@@ -845,8 +847,8 @@ static int hugepage_vma_revalidate(struct mm_struct *mm, unsigned long address,
 	 * hugepage_vma_check may return true for qualified file
 	 * vmas.
 	 */
-	if (!vma->anon_vma || !vma_is_anonymous(vma))
-		return SCAN_VMA_CHECK;
+	if (expect_anon && (!(*vmap)->anon_vma || !vma_is_anonymous(*vmap)))
+		return SCAN_PAGE_ANON;
 	return SCAN_SUCCEED;
 }
 
@@ -866,8 +868,8 @@ static int find_pmd_or_thp_or_none(struct mm_struct *mm,
 	/* See comments in pmd_none_or_trans_huge_or_clear_bad() */
 	barrier();
 #endif
-	if (!pmd_present(pmde))
-		return SCAN_PMD_NULL;
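+	/*
+	 * pmd_none() means no pte table is installed: distinguish this
+	 * recoverable state (e.g. khugepaged already freed the pte table)
+	 * from SCAN_PMD_NULL so callers such as MADV_COLLAPSE can still
+	 * install a huge pmd.
+	 */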
+	if (pmd_none(pmde))
+		return SCAN_PMD_NONE;
 	if (pmd_trans_huge(pmde))
 		return SCAN_PMD_MAPPED;
 	if (pmd_bad(pmde))
@@ -995,7 +997,7 @@ static int collapse_huge_page(struct mm_struct *mm, unsigned long address,
 		goto out_nolock;
 
 	mmap_read_lock(mm);
-	result = hugepage_vma_revalidate(mm, address, &vma, cc);
+	result = hugepage_vma_revalidate(mm, address, true, &vma, cc);
 	if (result != SCAN_SUCCEED) {
 		mmap_read_unlock(mm);
 		goto out_nolock;
@@ -1026,7 +1028,7 @@ static int collapse_huge_page(struct mm_struct *mm, unsigned long address,
 	 * handled by the anon_vma lock + PG_lock.
 	 */
 	mmap_write_lock(mm);
-	result = hugepage_vma_revalidate(mm, address, &vma, cc);
+	result = hugepage_vma_revalidate(mm, address, true, &vma, cc);
 	if (result != SCAN_SUCCEED)
 		goto out_up_write;
 	/* check if the pmd is still valid */
@@ -1320,6 +1322,26 @@ static void collect_mm_slot(struct khugepaged_mm_slot *mm_slot)
 /*
  * Notify khugepaged that given addr of the mm is pte-mapped THP. Then
  * khugepaged should try to collapse the page table.
+ *
+ * Note that the following race exists:
+ * (1) khugepaged calls khugepaged_collapse_pte_mapped_thps() for mm_struct A,
+ *     emptying A's ->pte_mapped_thp[] array.
+ * (2) MADV_COLLAPSE collapses some file extent with target mm_struct B, and
+ *     retract_page_tables() finds a VMA in mm_struct A mapping the same extent
+ *     (at virtual address X) and adds an entry (for X) into mm_struct A's
+ *     ->pte_mapped_thp[] array.
+ * (3) khugepaged calls hpage_collapse_scan_file() for mm_struct A at X, sees
+ *     a pte-mapped THP (SCAN_PTE_MAPPED_HUGEPAGE) and adds an entry (for X)
+ *     into mm_struct A's ->pte_mapped_thp[] array.
+ * Thus, it's possible the same address is added multiple times for the same
+ * mm_struct.  Should this happen, we'll simply attempt
+ * collapse_pte_mapped_thp() multiple times for the same address, under the
+ * same exclusive mmap_lock.  Assuming the first call succeeds, subsequent
+ * attempts will return quickly (without grabbing any additional locks) once
+ * a huge pmd is found by find_pmd_or_thp_or_none().  Since this is a cheap
+ * check, and since this is a rare occurrence, preventing the "multiple-add"
+ * is thought to be more expensive than just handling it, should it occur.
  */
 static bool khugepaged_add_pte_mapped_thp(struct mm_struct *mm,
 					  unsigned long addr)
@@ -1341,6 +1363,27 @@ static bool khugepaged_add_pte_mapped_thp(struct mm_struct *mm,
 	return ret;
 }
 
+/* hpage must be locked, and mmap_lock must be held for write */
+static int set_huge_pmd(struct vm_area_struct *vma, unsigned long addr,
+			pmd_t *pmdp, struct page *hpage)
+{
+	struct vm_fault vmf = {
+		.vma = vma,
+		.address = addr,
+		.flags = 0,
+		.pmd = pmdp,
+	};
+
+	VM_BUG_ON(!PageTransHuge(hpage));
+	mmap_assert_write_locked(vma->vm_mm);
+
+	if (do_set_pmd(&vmf, hpage))
+		return SCAN_FAIL;
+
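+	/*
+	 * do_set_pmd() does not itself take a reference on hpage; pin it
+	 * here for the new pmd mapping, since the caller drops its own
+	 * reference once the collapse finishes.
+	 */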
+	get_page(hpage);
+	return SCAN_SUCCEED;
+}
+
 static void collapse_and_free_pmd(struct mm_struct *mm, struct vm_area_struct *vma,
 				  unsigned long addr, pmd_t *pmdp)
 {
@@ -1362,12 +1405,14 @@ static void collapse_and_free_pmd(struct mm_struct *mm, struct vm_area_struct *v
  *
  * @mm: process address space where collapse happens
  * @addr: THP collapse address
+ * @install_pmd: whether a huge PMD should be installed
  *
  * This function checks whether all the PTEs in the PMD are pointing to the
 * right THP. If so, retract the page table so the THP can refault in
- * as pmd-mapped.
+ * as pmd-mapped. Possibly install a huge PMD mapping the THP.
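+ *
+ * Return: SCAN_SUCCEED on success, SCAN_PMD_MAPPED if a huge pmd already
+ * maps the area, or another SCAN_* code describing why collapse failed.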
  */
-void collapse_pte_mapped_thp(struct mm_struct *mm, unsigned long addr)
+int collapse_pte_mapped_thp(struct mm_struct *mm, unsigned long addr,
+			    bool install_pmd)
 {
 	unsigned long haddr = addr & HPAGE_PMD_MASK;
 	struct vm_area_struct *vma = vma_lookup(mm, haddr);
@@ -1380,14 +1425,14 @@ void collapse_pte_mapped_thp(struct mm_struct *mm, unsigned long addr)
 
 	mmap_assert_write_locked(mm);
 
-	/* Fast check before locking page if not PMD mapping PTE table */
+	/* Fast check before locking page if already PMD-mapped */
 	result = find_pmd_or_thp_or_none(mm, haddr, &pmd);
-	if (result != SCAN_SUCCEED)
-		return;
+	if (result == SCAN_PMD_MAPPED)
+		return result;
 
 	if (!vma || !vma->vm_file ||
 	    !range_in_vma(vma, haddr, haddr + HPAGE_PMD_SIZE))
-		return;
+		return SCAN_VMA_CHECK;
 
 	/*
 	 * If we are here, we've succeeded in replacing all the native pages
@@ -1397,27 +1442,43 @@ void collapse_pte_mapped_thp(struct mm_struct *mm, unsigned long addr)
 	 * analogously elide sysfs THP settings here.
 	 */
 	if (!hugepage_vma_check(vma, vma->vm_flags, false, false, false))
-		return;
+		return SCAN_VMA_CHECK;
 
 	/* Keep pmd pgtable for uffd-wp; see comment in retract_page_tables() */
 	if (userfaultfd_wp(vma))
-		return;
+		return SCAN_PTE_UFFD_WP;
 
 	hpage = find_lock_page(vma->vm_file->f_mapping,
 			       linear_page_index(vma, haddr));
 	if (!hpage)
-		return;
+		return SCAN_PAGE_NULL;
 
-	if (!PageHead(hpage))
+	if (!PageHead(hpage)) {
+		result = SCAN_FAIL;
 		goto drop_hpage;
+	}
 
-	if (compound_order(hpage) != HPAGE_PMD_ORDER)
+	if (compound_order(hpage) != HPAGE_PMD_ORDER) {
+		result = SCAN_PAGE_COMPOUND;
 		goto drop_hpage;
+	}
 
-	if (find_pmd_or_thp_or_none(mm, haddr, &pmd) != SCAN_SUCCEED)
+	switch (result) {
+	case SCAN_SUCCEED:
+		break;
+	case SCAN_PMD_NONE:
+		/*
+		 * In MADV_COLLAPSE path, possible race with khugepaged where
+		 * all pte entries have been removed and pmd cleared.  If so,
+		 * skip all the pte checks and just update the pmd mapping.
+		 */
+		goto maybe_install_pmd;
+	default:
 		goto drop_hpage;
+	}
 
 	start_pte = pte_offset_map_lock(mm, pmd, haddr, &ptl);
+	result = SCAN_FAIL;
 
 	/* step 1: check all mapped PTEs are to the right huge page */
 	for (i = 0, addr = haddr, pte = start_pte;
@@ -1429,8 +1490,10 @@ void collapse_pte_mapped_thp(struct mm_struct *mm, unsigned long addr)
 			continue;
 
 		/* page swapped out, abort */
-		if (!pte_present(*pte))
+		if (!pte_present(*pte)) {
+			result = SCAN_PTE_NON_PRESENT;
 			goto abort;
+		}
 
 		page = vm_normal_page(vma, addr, *pte);
 		if (WARN_ON_ONCE(page && is_zone_device_page(page)))
@@ -1465,12 +1528,19 @@ void collapse_pte_mapped_thp(struct mm_struct *mm, unsigned long addr)
 		add_mm_counter(vma->vm_mm, mm_counter_file(hpage), -count);
 	}
 
-	/* step 4: collapse pmd */
+	/* step 4: remove pte entries */
 	collapse_and_free_pmd(mm, vma, haddr, pmd);
+
+maybe_install_pmd:
+	/* step 5: install pmd entry */
+	result = install_pmd
+			? set_huge_pmd(vma, haddr, pmd, hpage)
+			: SCAN_SUCCEED;
+
 drop_hpage:
 	unlock_page(hpage);
 	put_page(hpage);
-	return;
+	return result;
 
 abort:
 	pte_unmap_unlock(start_pte, ptl);
@@ -1493,22 +1563,29 @@ static void khugepaged_collapse_pte_mapped_thps(struct khugepaged_mm_slot *mm_sl
 		goto out;
 
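+	/*
+	 * Only retract the pte table here (install_pmd=false); the THP is
+	 * expected to refault in as pmd-mapped later.
+	 */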
 	for (i = 0; i < mm_slot->nr_pte_mapped_thp; i++)
-		collapse_pte_mapped_thp(mm, mm_slot->pte_mapped_thp[i]);
+		collapse_pte_mapped_thp(mm, mm_slot->pte_mapped_thp[i], false);
 
 out:
 	mm_slot->nr_pte_mapped_thp = 0;
 	mmap_write_unlock(mm);
 }
 
-static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
+static int retract_page_tables(struct address_space *mapping, pgoff_t pgoff,
+			       struct mm_struct *target_mm,
+			       unsigned long target_addr, struct page *hpage,
+			       struct collapse_control *cc)
 {
 	struct vm_area_struct *vma;
-	struct mm_struct *mm;
-	unsigned long addr;
-	pmd_t *pmd;
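+	/*
+	 * Only the result for the target mm/addr is reported back; failures
+	 * on other VMAs are left for khugepaged to retry later.
+	 */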
+	int target_result = SCAN_FAIL;
 
 	i_mmap_lock_write(mapping);
 	vma_interval_tree_foreach(vma, &mapping->i_mmap, pgoff, pgoff) {
+		int result = SCAN_FAIL;
+		struct mm_struct *mm = NULL;
+		unsigned long addr = 0;
+		pmd_t *pmd;
+		bool is_target = false;
+
 		/*
 		 * Check vma->anon_vma to exclude MAP_PRIVATE mappings that
 		 * got written to. These VMAs are likely not worth investing
@@ -1525,24 +1602,34 @@ static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
 		 * ptl. It has higher chance to recover THP for the VMA, but
 		 * has higher cost too.
 		 */
-		if (vma->anon_vma)
-			continue;
+		if (vma->anon_vma) {
+			result = SCAN_PAGE_ANON;
+			goto next;
+		}
 		addr = vma->vm_start + ((pgoff - vma->vm_pgoff) << PAGE_SHIFT);
-		if (addr & ~HPAGE_PMD_MASK)
-			continue;
-		if (vma->vm_end < addr + HPAGE_PMD_SIZE)
-			continue;
+		if (addr & ~HPAGE_PMD_MASK ||
+		    vma->vm_end < addr + HPAGE_PMD_SIZE) {
+			result = SCAN_VMA_CHECK;
+			goto next;
+		}
 		mm = vma->vm_mm;
-		if (find_pmd_or_thp_or_none(mm, addr, &pmd) != SCAN_SUCCEED)
-			continue;
+		is_target = mm == target_mm && addr == target_addr;
+		result = find_pmd_or_thp_or_none(mm, addr, &pmd);
+		if (result != SCAN_SUCCEED)
+			goto next;
 		/*
 		 * We need exclusive mmap_lock to retract page table.
 		 *
 		 * We use trylock due to lock inversion: we need to acquire
 		 * mmap_lock while holding page lock. Fault path does it in
 		 * reverse order. Trylock is a way to avoid deadlock.
+		 *
+		 * Also, it's not MADV_COLLAPSE's job to collapse other
+		 * mappings - let khugepaged take care of them later.
 		 */
-		if (mmap_write_trylock(mm)) {
+		result = SCAN_PTE_MAPPED_HUGEPAGE;
+		if ((cc->is_khugepaged || is_target) &&
+		    mmap_write_trylock(mm)) {
 			/*
 			 * When a vma is registered with uffd-wp, we can't
 			 * recycle the pmd pgtable because there can be pte
@@ -1551,22 +1638,45 @@ static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
 			 * it'll always mapped in small page size for uffd-wp
 			 * registered ranges.
 			 */
-			if (!hpage_collapse_test_exit(mm) &&
-			    !userfaultfd_wp(vma))
-				collapse_and_free_pmd(mm, vma, addr, pmd);
+			if (hpage_collapse_test_exit(mm)) {
+				result = SCAN_ANY_PROCESS;
+				goto unlock_next;
+			}
+			if (userfaultfd_wp(vma)) {
+				result = SCAN_PTE_UFFD_WP;
+				goto unlock_next;
+			}
+			collapse_and_free_pmd(mm, vma, addr, pmd);
+			if (!cc->is_khugepaged && is_target)
+				result = set_huge_pmd(vma, addr, pmd, hpage);
+			else
+				result = SCAN_SUCCEED;
+
+unlock_next:
 			mmap_write_unlock(mm);
-		} else {
-			/* Try again later */
+			goto next;
+		}
+		/*
+		 * The calling context will handle the target mm/addr; for
+		 * any other mapping, let khugepaged try again later.
+		 */
+		if (!is_target) {
 			khugepaged_add_pte_mapped_thp(mm, addr);
+			continue;
 		}
+next:
+		if (is_target)
+			target_result = result;
 	}
 	i_mmap_unlock_write(mapping);
+	return target_result;
 }
 
 /**
  * collapse_file - collapse filemap/tmpfs/shmem pages into huge one.
  *
  * @mm: process address space where collapse happens
+ * @addr: virtual collapse start address
 * @file: file that the collapse operates on
 * @start: page index in @file at which collapse starts
  * @cc: collapse context and scratchpad
@@ -1586,8 +1696,9 @@ static void retract_page_tables(struct address_space *mapping, pgoff_t pgoff)
  *    + restore gaps in the page cache;
  *    + unlock and free huge page;
  */
-static int collapse_file(struct mm_struct *mm, struct file *file,
-			 pgoff_t start, struct collapse_control *cc)
+static int collapse_file(struct mm_struct *mm, unsigned long addr,
+			 struct file *file, pgoff_t start,
+			 struct collapse_control *cc)
 {
 	struct address_space *mapping = file->f_mapping;
 	struct page *hpage;
@@ -1895,7 +2006,8 @@ static int collapse_file(struct mm_struct *mm, struct file *file,
 		/*
 		 * Remove pte page tables, so we can re-fault the page as huge.
 		 */
-		retract_page_tables(mapping, start);
+		result = retract_page_tables(mapping, start, mm, addr, hpage,
+					     cc);
 		unlock_page(hpage);
 		hpage = NULL;
 	} else {
@@ -1951,8 +2063,9 @@ static int collapse_file(struct mm_struct *mm, struct file *file,
 	return result;
 }
 
-static int khugepaged_scan_file(struct mm_struct *mm, struct file *file,
-				pgoff_t start, struct collapse_control *cc)
+static int hpage_collapse_scan_file(struct mm_struct *mm, unsigned long addr,
+				    struct file *file, pgoff_t start,
+				    struct collapse_control *cc)
 {
 	struct page *page = NULL;
 	struct address_space *mapping = file->f_mapping;
@@ -2040,7 +2153,7 @@ static int khugepaged_scan_file(struct mm_struct *mm, struct file *file,
 			result = SCAN_EXCEED_NONE_PTE;
 			count_vm_event(THP_SCAN_EXCEED_NONE_PTE);
 		} else {
-			result = collapse_file(mm, file, start, cc);
+			result = collapse_file(mm, addr, file, start, cc);
 		}
 	}
 
@@ -2048,8 +2161,9 @@ static int khugepaged_scan_file(struct mm_struct *mm, struct file *file,
 	return result;
 }
 #else
-static int khugepaged_scan_file(struct mm_struct *mm, struct file *file,
-				pgoff_t start, struct collapse_control *cc)
+static int hpage_collapse_scan_file(struct mm_struct *mm, unsigned long addr,
+				    struct file *file, pgoff_t start,
+				    struct collapse_control *cc)
 {
 	BUILD_BUG();
 }
@@ -2145,8 +2259,9 @@ static unsigned int khugepaged_scan_mm_slot(unsigned int pages, int *result,
 						khugepaged_scan.address);
 
 				mmap_read_unlock(mm);
-				*result = khugepaged_scan_file(mm, file, pgoff,
-							       cc);
+				*result = hpage_collapse_scan_file(mm,
+								   khugepaged_scan.address,
+								   file, pgoff, cc);
 				mmap_locked = false;
 				fput(file);
 			} else {
@@ -2453,10 +2568,6 @@ int madvise_collapse(struct vm_area_struct *vma, struct vm_area_struct **prev,
 
 	*prev = vma;
 
-	/* TODO: Support file/shmem */
-	if (!vma->anon_vma || !vma_is_anonymous(vma))
-		return -EINVAL;
-
 	if (!hugepage_vma_check(vma, vma->vm_flags, false, false, false))
 		return -EINVAL;
 
@@ -2479,7 +2590,8 @@ int madvise_collapse(struct vm_area_struct *vma, struct vm_area_struct **prev,
 			cond_resched();
 			mmap_read_lock(mm);
 			mmap_locked = true;
-			result = hugepage_vma_revalidate(mm, addr, &vma, cc);
+			result = hugepage_vma_revalidate(mm, addr, false, &vma,
+							 cc);
 			if (result != SCAN_SUCCEED) {
 				last_fail = result;
 				goto out_nolock;
@@ -2489,16 +2601,35 @@ int madvise_collapse(struct vm_area_struct *vma, struct vm_area_struct **prev,
 		}
 		mmap_assert_locked(mm);
 		memset(cc->node_load, 0, sizeof(cc->node_load));
-		result = hpage_collapse_scan_pmd(mm, vma, addr, &mmap_locked,
-						 cc);
+		if (IS_ENABLED(CONFIG_SHMEM) && vma->vm_file) {
+			struct file *file = get_file(vma->vm_file);
+			pgoff_t pgoff = linear_page_index(vma, addr);
+
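+			/*
+			 * File-backed memory is scanned via the page cache,
+			 * so drop mmap_lock first, mirroring the khugepaged
+			 * file path.
+			 */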
+			mmap_read_unlock(mm);
+			mmap_locked = false;
+			result = hpage_collapse_scan_file(mm, addr, file, pgoff,
+							  cc);
+			fput(file);
+		} else {
+			result = hpage_collapse_scan_pmd(mm, vma, addr,
+							 &mmap_locked, cc);
+		}
 		if (!mmap_locked)
 			*prev = NULL;  /* Tell caller we dropped mmap_lock */
 
+handle_result:
 		switch (result) {
 		case SCAN_SUCCEED:
 		case SCAN_PMD_MAPPED:
 			++thps;
 			break;
+		case SCAN_PTE_MAPPED_HUGEPAGE:
+			BUG_ON(mmap_locked);
+			BUG_ON(*prev);
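+			/*
+			 * The scan found a pte-mapped THP and dropped
+			 * mmap_lock; retake it exclusively, attempt to
+			 * finish the collapse by installing a huge pmd, and
+			 * re-dispatch on the new result.
+			 */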
+			mmap_write_lock(mm);
+			result = collapse_pte_mapped_thp(mm, addr, true);
+			mmap_write_unlock(mm);
+			goto handle_result;
 		/* Whitelisted set of results where continuing OK */
 		case SCAN_PMD_NULL:
 		case SCAN_PTE_NON_PRESENT: