diff --git a/arch/Kconfig b/arch/Kconfig
index 4f654c149709ffe6f26dbfb4afe71179b0c841f6..54a240b61f560103e42f14b2634fb8c02a1f2b73 100644
--- a/arch/Kconfig
+++ b/arch/Kconfig
@@ -648,6 +648,13 @@ config HAVE_IRQ_TIME_ACCOUNTING
 	  Archs need to ensure they use a high enough resolution clock to
 	  support irq time accounting and then call enable_sched_clock_irqtime().
 
+config HAVE_MOVE_PUD
+	bool
+	help
+	  Architectures that select this are able to move page tables at the
+	  PUD level. If there are only 3 page table levels, the move effectively
+	  happens at the PGD level.
+
 config HAVE_MOVE_PMD
 	bool
 	help
diff --git a/mm/mremap.c b/mm/mremap.c
index 138abbae4f758763c5be19578b2909d5d003db28..078f731277b6e8059a3472f1038eeec4ee1cd902 100644
--- a/mm/mremap.c
+++ b/mm/mremap.c
@@ -30,12 +30,11 @@
 
 #include "internal.h"
 
-static pmd_t *get_old_pmd(struct mm_struct *mm, unsigned long addr)
+static pud_t *get_old_pud(struct mm_struct *mm, unsigned long addr)
 {
 	pgd_t *pgd;
 	p4d_t *p4d;
 	pud_t *pud;
-	pmd_t *pmd;
 
 	pgd = pgd_offset(mm, addr);
 	if (pgd_none_or_clear_bad(pgd))
@@ -49,6 +48,18 @@ static pmd_t *get_old_pmd(struct mm_struct *mm, unsigned long addr)
 	if (pud_none_or_clear_bad(pud))
 		return NULL;
 
+	return pud;
+}
+
+static pmd_t *get_old_pmd(struct mm_struct *mm, unsigned long addr)
+{
+	pud_t *pud;
+	pmd_t *pmd;
+
+	pud = get_old_pud(mm, addr);
+	if (!pud)
+		return NULL;
+
 	pmd = pmd_offset(pud, addr);
 	if (pmd_none(*pmd))
 		return NULL;
@@ -56,19 +67,27 @@ static pmd_t *get_old_pmd(struct mm_struct *mm, unsigned long addr)
 	return pmd;
 }
 
-static pmd_t *alloc_new_pmd(struct mm_struct *mm, struct vm_area_struct *vma,
+static pud_t *alloc_new_pud(struct mm_struct *mm, struct vm_area_struct *vma,
 			    unsigned long addr)
 {
 	pgd_t *pgd;
 	p4d_t *p4d;
-	pud_t *pud;
-	pmd_t *pmd;
 
 	pgd = pgd_offset(mm, addr);
 	p4d = p4d_alloc(mm, pgd, addr);
 	if (!p4d)
 		return NULL;
-	pud = pud_alloc(mm, p4d, addr);
+
+	return pud_alloc(mm, p4d, addr);
+}
+
+static pmd_t *alloc_new_pmd(struct mm_struct *mm, struct vm_area_struct *vma,
+			    unsigned long addr)
+{
+	pud_t *pud;
+	pmd_t *pmd;
+
+	pud = alloc_new_pud(mm, vma, addr);
 	if (!pud)
 		return NULL;
 
@@ -249,14 +268,148 @@ static bool move_normal_pmd(struct vm_area_struct *vma, unsigned long old_addr,
 
 	return true;
 }
+#else
+static inline bool move_normal_pmd(struct vm_area_struct *vma,
+		unsigned long old_addr, unsigned long new_addr, pmd_t *old_pmd,
+		pmd_t *new_pmd)
+{
+	return false;
+}
 #endif
 
+#ifdef CONFIG_HAVE_MOVE_PUD
+static bool move_normal_pud(struct vm_area_struct *vma, unsigned long old_addr,
+		  unsigned long new_addr, pud_t *old_pud, pud_t *new_pud)
+{
+	spinlock_t *old_ptl, *new_ptl;
+	struct mm_struct *mm = vma->vm_mm;
+	pud_t pud;
+
+	/*
+	 * The destination pud shouldn't be established, free_pgtables()
+	 * should have released it.
+	 */
+	if (WARN_ON_ONCE(!pud_none(*new_pud)))
+		return false;
+
+	/*
+	 * We don't have to worry about the ordering of src and dst
+	 * ptlocks because exclusive mmap_lock prevents deadlock.
+	 */
+	old_ptl = pud_lock(vma->vm_mm, old_pud);
+	new_ptl = pud_lockptr(mm, new_pud);
+	if (new_ptl != old_ptl)
+		spin_lock_nested(new_ptl, SINGLE_DEPTH_NESTING);
+
+	/* Clear the pud */
+	pud = *old_pud;
+	pud_clear(old_pud);
+
+	VM_BUG_ON(!pud_none(*new_pud));
+
+	/* Set the new pud */
+	set_pud_at(mm, new_addr, new_pud, pud);
+	flush_tlb_range(vma, old_addr, old_addr + PUD_SIZE);
+	if (new_ptl != old_ptl)
+		spin_unlock(new_ptl);
+	spin_unlock(old_ptl);
+
+	return true;
+}
+#else
+static inline bool move_normal_pud(struct vm_area_struct *vma,
+		unsigned long old_addr, unsigned long new_addr, pud_t *old_pud,
+		pud_t *new_pud)
+{
+	return false;
+}
+#endif
+
+enum pgt_entry {
+	NORMAL_PMD,
+	HPAGE_PMD,
+	NORMAL_PUD,
+};
+
+/*
+ * Returns an extent of the corresponding size for the pgt_entry specified if
+ * valid. Else returns a smaller extent bounded by the end of the source and
+ * destination pgt_entry.
+ */
+static unsigned long get_extent(enum pgt_entry entry, unsigned long old_addr,
+			unsigned long old_end, unsigned long new_addr)
+{
+	unsigned long next, extent, mask, size;
+
+	switch (entry) {
+	case HPAGE_PMD:
+	case NORMAL_PMD:
+		mask = PMD_MASK;
+		size = PMD_SIZE;
+		break;
+	case NORMAL_PUD:
+		mask = PUD_MASK;
+		size = PUD_SIZE;
+		break;
+	default:
+		BUILD_BUG();
+		break;
+	}
+
+	next = (old_addr + size) & mask;
+	/* even if next overflowed, extent below will be ok */
+	extent = (next > old_end) ? old_end - old_addr : next - old_addr;
+	next = (new_addr + size) & mask;
+	if (extent > next - new_addr)
+		extent = next - new_addr;
+	return extent;
+}
+
+/*
+ * Attempts to speedup the move by moving entry at the level corresponding to
+ * pgt_entry. Returns true if the move was successful, else false.
+ */
+static bool move_pgt_entry(enum pgt_entry entry, struct vm_area_struct *vma,
+			unsigned long old_addr, unsigned long new_addr,
+			void *old_entry, void *new_entry, bool need_rmap_locks)
+{
+	bool moved = false;
+
+	/* See comment in move_ptes() */
+	if (need_rmap_locks)
+		take_rmap_locks(vma);
+
+	switch (entry) {
+	case NORMAL_PMD:
+		moved = move_normal_pmd(vma, old_addr, new_addr, old_entry,
+					new_entry);
+		break;
+	case NORMAL_PUD:
+		moved = move_normal_pud(vma, old_addr, new_addr, old_entry,
+					new_entry);
+		break;
+	case HPAGE_PMD:
+		moved = IS_ENABLED(CONFIG_TRANSPARENT_HUGEPAGE) &&
+			move_huge_pmd(vma, old_addr, new_addr, old_entry,
+				      new_entry);
+		break;
+	default:
+		WARN_ON_ONCE(1);
+		break;
+	}
+
+	if (need_rmap_locks)
+		drop_rmap_locks(vma);
+
+	return moved;
+}
+
 unsigned long move_page_tables(struct vm_area_struct *vma,
 		unsigned long old_addr, struct vm_area_struct *new_vma,
 		unsigned long new_addr, unsigned long len,
 		bool need_rmap_locks)
 {
-	unsigned long extent, next, old_end;
+	unsigned long extent, old_end;
 	struct mmu_notifier_range range;
 	pmd_t *old_pmd, *new_pmd;
 
@@ -269,53 +422,50 @@ unsigned long move_page_tables(struct vm_area_struct *vma,
 
 	for (; old_addr < old_end; old_addr += extent, new_addr += extent) {
 		cond_resched();
-		next = (old_addr + PMD_SIZE) & PMD_MASK;
-		/* even if next overflowed, extent below will be ok */
-		extent = next - old_addr;
-		if (extent > old_end - old_addr)
-			extent = old_end - old_addr;
-		next = (new_addr + PMD_SIZE) & PMD_MASK;
-		if (extent > next - new_addr)
-			extent = next - new_addr;
+		/*
+		 * If extent is PUD-sized try to speed up the move by moving at the
+		 * PUD level if possible.
+		 */
+		extent = get_extent(NORMAL_PUD, old_addr, old_end, new_addr);
+		if (IS_ENABLED(CONFIG_HAVE_MOVE_PUD) && extent == PUD_SIZE) {
+			pud_t *old_pud, *new_pud;
+
+			old_pud = get_old_pud(vma->vm_mm, old_addr);
+			if (!old_pud)
+				continue;
+			new_pud = alloc_new_pud(vma->vm_mm, vma, new_addr);
+			if (!new_pud)
+				break;
+			if (move_pgt_entry(NORMAL_PUD, vma, old_addr, new_addr,
+					   old_pud, new_pud, need_rmap_locks))
+				continue;
+		}
+
+		extent = get_extent(NORMAL_PMD, old_addr, old_end, new_addr);
 		old_pmd = get_old_pmd(vma->vm_mm, old_addr);
 		if (!old_pmd)
 			continue;
 		new_pmd = alloc_new_pmd(vma->vm_mm, vma, new_addr);
 		if (!new_pmd)
 			break;
-		if (is_swap_pmd(*old_pmd) || pmd_trans_huge(*old_pmd) || pmd_devmap(*old_pmd)) {
-			if (extent == HPAGE_PMD_SIZE) {
-				bool moved;
-				/* See comment in move_ptes() */
-				if (need_rmap_locks)
-					take_rmap_locks(vma);
-				moved = move_huge_pmd(vma, old_addr, new_addr,
-						      old_pmd, new_pmd);
-				if (need_rmap_locks)
-					drop_rmap_locks(vma);
-				if (moved)
-					continue;
-			}
+		if (is_swap_pmd(*old_pmd) || pmd_trans_huge(*old_pmd) ||
+		    pmd_devmap(*old_pmd)) {
+			if (extent == HPAGE_PMD_SIZE &&
+			    move_pgt_entry(HPAGE_PMD, vma, old_addr, new_addr,
+					   old_pmd, new_pmd, need_rmap_locks))
+				continue;
 			split_huge_pmd(vma, old_pmd, old_addr);
 			if (pmd_trans_unstable(old_pmd))
 				continue;
-		} else if (extent == PMD_SIZE) {
-#ifdef CONFIG_HAVE_MOVE_PMD
+		} else if (IS_ENABLED(CONFIG_HAVE_MOVE_PMD) &&
+			   extent == PMD_SIZE) {
 			/*
 			 * If the extent is PMD-sized, try to speed the move by
 			 * moving at the PMD level if possible.
 			 */
-			bool moved;
-
-			if (need_rmap_locks)
-				take_rmap_locks(vma);
-			moved = move_normal_pmd(vma, old_addr, new_addr,
-						old_pmd, new_pmd);
-			if (need_rmap_locks)
-				drop_rmap_locks(vma);
-			if (moved)
+			if (move_pgt_entry(NORMAL_PMD, vma, old_addr, new_addr,
+					   old_pmd, new_pmd, need_rmap_locks))
 				continue;
-#endif
 		}
 
 		if (pte_alloc(new_vma->vm_mm, new_pmd))