diff --git a/mm/memory.c b/mm/memory.c
index ad0ea1af1f4497684ba49a39185a10d8ee997bda..7c521a6ec7c685d548b6498edbf6c68eab6c11ca 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -3961,6 +3961,7 @@ static int __handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
 		.pgoff = linear_page_index(vma, address),
 		.gfp_mask = __get_fault_gfp_mask(vma),
 	};
+	unsigned int dirty = flags & FAULT_FLAG_WRITE;
 	struct mm_struct *mm = vma->vm_mm;
 	pgd_t *pgd;
 	p4d_t *p4d;
@@ -3983,7 +3984,6 @@ static int __handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
 
 		barrier();
 		if (pud_trans_huge(orig_pud) || pud_devmap(orig_pud)) {
-			unsigned int dirty = flags & FAULT_FLAG_WRITE;
 
 			/* NUMA case for anonymous PUDs would go here */
 
@@ -4020,8 +4020,7 @@ static int __handle_mm_fault(struct vm_area_struct *vma, unsigned long address,
 			if (pmd_protnone(orig_pmd) && vma_is_accessible(vma))
 				return do_huge_pmd_numa_page(&vmf, orig_pmd);
 
-			if ((vmf.flags & FAULT_FLAG_WRITE) &&
-					!pmd_write(orig_pmd)) {
+			if (dirty && !pmd_write(orig_pmd)) {
 				ret = wp_huge_pmd(&vmf, orig_pmd);
 				if (!(ret & VM_FAULT_FALLBACK))
 					return ret;