diff --git a/include/linux/huge_mm.h b/include/linux/huge_mm.h
index 9b9f65d9987393d456911f41eacb4bdfa9fe0284..e35e6de633b9a7bc2a080d6b3596aa16e8c8582f 100644
--- a/include/linux/huge_mm.h
+++ b/include/linux/huge_mm.h
@@ -22,7 +22,7 @@ extern int mincore_huge_pmd(struct vm_area_struct *vma, pmd_t *pmd,
 			unsigned char *vec);
 extern bool move_huge_pmd(struct vm_area_struct *vma, unsigned long old_addr,
 			 unsigned long new_addr, unsigned long old_end,
-			 pmd_t *old_pmd, pmd_t *new_pmd);
+			 pmd_t *old_pmd, pmd_t *new_pmd, bool *need_flush);
 extern int change_huge_pmd(struct vm_area_struct *vma, pmd_t *pmd,
 			unsigned long addr, pgprot_t newprot,
 			int prot_numa);
diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index cdcd25cb30fea3f2ad2c660e547d014b7378b3dd..eff3de359d50a30588abf70676dc6993c751471d 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -1426,11 +1426,12 @@ int zap_huge_pmd(struct mmu_gather *tlb, struct vm_area_struct *vma,
 
 bool move_huge_pmd(struct vm_area_struct *vma, unsigned long old_addr,
 		  unsigned long new_addr, unsigned long old_end,
-		  pmd_t *old_pmd, pmd_t *new_pmd)
+		  pmd_t *old_pmd, pmd_t *new_pmd, bool *need_flush)
 {
 	spinlock_t *old_ptl, *new_ptl;
 	pmd_t pmd;
 	struct mm_struct *mm = vma->vm_mm;
+	bool force_flush = false;
 
 	if ((old_addr & ~HPAGE_PMD_MASK) ||
 	    (new_addr & ~HPAGE_PMD_MASK) ||
@@ -1455,6 +1456,8 @@ bool move_huge_pmd(struct vm_area_struct *vma, unsigned long old_addr,
 		new_ptl = pmd_lockptr(mm, new_pmd);
 		if (new_ptl != old_ptl)
 			spin_lock_nested(new_ptl, SINGLE_DEPTH_NESTING);
+		if (pmd_present(*old_pmd) && pmd_dirty(*old_pmd))
+			force_flush = true;
 		pmd = pmdp_huge_get_and_clear(mm, old_addr, old_pmd);
 		VM_BUG_ON(!pmd_none(*new_pmd));
 
@@ -1467,6 +1470,10 @@ bool move_huge_pmd(struct vm_area_struct *vma, unsigned long old_addr,
 		set_pmd_at(mm, new_addr, new_pmd, pmd_mksoft_dirty(pmd));
 		if (new_ptl != old_ptl)
 			spin_unlock(new_ptl);
+		if (force_flush)
+			flush_tlb_range(vma, old_addr, old_addr + PMD_SIZE);
+		else
+			*need_flush = true;
 		spin_unlock(old_ptl);
 		return true;
 	}
diff --git a/mm/mremap.c b/mm/mremap.c
index da22ad2a5678265ea9f2d0aa5ece9e14c519a494..6ccecc03f56ad05940484a481c5a092f1ad245ef 100644
--- a/mm/mremap.c
+++ b/mm/mremap.c
@@ -104,11 +104,13 @@ static pte_t move_soft_dirty_pte(pte_t pte)
 static void move_ptes(struct vm_area_struct *vma, pmd_t *old_pmd,
 		unsigned long old_addr, unsigned long old_end,
 		struct vm_area_struct *new_vma, pmd_t *new_pmd,
-		unsigned long new_addr, bool need_rmap_locks)
+		unsigned long new_addr, bool need_rmap_locks, bool *need_flush)
 {
 	struct mm_struct *mm = vma->vm_mm;
 	pte_t *old_pte, *new_pte, pte;
 	spinlock_t *old_ptl, *new_ptl;
+	bool force_flush = false;
+	unsigned long len = old_end - old_addr;
 
 	/*
 	 * When need_rmap_locks is true, we take the i_mmap_rwsem and anon_vma
@@ -146,6 +148,14 @@ static void move_ptes(struct vm_area_struct *vma, pmd_t *old_pmd,
 				   new_pte++, new_addr += PAGE_SIZE) {
 		if (pte_none(*old_pte))
 			continue;
+
+		/*
+		 * We are remapping a dirty PTE, make sure to
+		 * flush TLB before we drop the PTL for the
+		 * old PTE or we may race with page_mkclean().
+		 */
+		if (pte_present(*old_pte) && pte_dirty(*old_pte))
+			force_flush = true;
 		pte = ptep_get_and_clear(mm, old_addr, old_pte);
 		pte = move_pte(pte, new_vma->vm_page_prot, old_addr, new_addr);
 		pte = move_soft_dirty_pte(pte);
@@ -156,6 +166,10 @@ static void move_ptes(struct vm_area_struct *vma, pmd_t *old_pmd,
 	if (new_ptl != old_ptl)
 		spin_unlock(new_ptl);
 	pte_unmap(new_pte - 1);
+	if (force_flush)
+		flush_tlb_range(vma, old_end - len, old_end);
+	else
+		*need_flush = true;
 	pte_unmap_unlock(old_pte - 1, old_ptl);
 	if (need_rmap_locks)
 		drop_rmap_locks(vma);
@@ -201,13 +215,12 @@ unsigned long move_page_tables(struct vm_area_struct *vma,
 				if (need_rmap_locks)
 					take_rmap_locks(vma);
 				moved = move_huge_pmd(vma, old_addr, new_addr,
-						    old_end, old_pmd, new_pmd);
+						    old_end, old_pmd, new_pmd,
+						    &need_flush);
 				if (need_rmap_locks)
 					drop_rmap_locks(vma);
-				if (moved) {
-					need_flush = true;
+				if (moved)
 					continue;
-				}
 			}
 			split_huge_pmd(vma, old_pmd, old_addr);
 			if (pmd_trans_unstable(old_pmd))
@@ -220,11 +233,10 @@ unsigned long move_page_tables(struct vm_area_struct *vma,
 			extent = next - new_addr;
 		if (extent > LATENCY_LIMIT)
 			extent = LATENCY_LIMIT;
-		move_ptes(vma, old_pmd, old_addr, old_addr + extent,
-			  new_vma, new_pmd, new_addr, need_rmap_locks);
-		need_flush = true;
+		move_ptes(vma, old_pmd, old_addr, old_addr + extent, new_vma,
+			  new_pmd, new_addr, need_rmap_locks, &need_flush);
 	}
-	if (likely(need_flush))
+	if (need_flush)
 		flush_tlb_range(vma, old_end-len, old_addr);
 
 	mmu_notifier_invalidate_range_end(vma->vm_mm, mmun_start, mmun_end);