diff --git a/fs/dax.c b/fs/dax.c
index 306c2b603fb8aa8845558a2bf8226523e5282cdb..865d42c63e23e4746c6658fbe2746fd17ade11f9 100644
--- a/fs/dax.c
+++ b/fs/dax.c
@@ -1383,6 +1383,16 @@ static int dax_iomap_pmd_fault(struct vm_fault *vmf,
 
 	trace_dax_pmd_fault(inode, vmf, max_pgoff, 0);
 
+	/*
+	 * Make sure that the faulting address's PMD offset (color) matches
+	 * the PMD offset from the start of the file.  This is necessary so
+	 * that a PMD range in the page table overlaps exactly with a PMD
+	 * range in the radix tree.
+	 */
+	if ((vmf->pgoff & PG_PMD_COLOUR) !=
+	    ((vmf->address >> PAGE_SHIFT) & PG_PMD_COLOUR))
+		goto fallback;
+
 	/* Fall back to PTEs if we're going to COW */
 	if (write && !(vma->vm_flags & VM_SHARED))
 		goto fallback;