diff --git a/mm/huge_memory.c b/mm/huge_memory.c
index c1411961167e1257c1a56b000512f36d1766471e..36c070167b7109f0639ff4188b6b67772e832591 100644
--- a/mm/huge_memory.c
+++ b/mm/huge_memory.c
@@ -858,7 +858,8 @@ static bool set_huge_zero_page(pgtable_t pgtable, struct mm_struct *mm,
 		return false;
 	entry = mk_pmd(zero_page, vma->vm_page_prot);
 	entry = pmd_mkhuge(entry);
-	pgtable_trans_huge_deposit(mm, pmd, pgtable);
+	if (pgtable)
+		pgtable_trans_huge_deposit(mm, pmd, pgtable);
 	set_pmd_at(mm, haddr, pmd, entry);
 	atomic_long_inc(&mm->nr_ptes);
 	return true;
@@ -1036,13 +1037,15 @@ int copy_huge_pmd(struct mm_struct *dst_mm, struct mm_struct *src_mm,
 	spinlock_t *dst_ptl, *src_ptl;
 	struct page *src_page;
 	pmd_t pmd;
-	pgtable_t pgtable;
+	pgtable_t pgtable = NULL;
 	int ret;
 
-	ret = -ENOMEM;
-	pgtable = pte_alloc_one(dst_mm, addr);
-	if (unlikely(!pgtable))
-		goto out;
+	if (!vma_is_dax(vma)) {
+		ret = -ENOMEM;
+		pgtable = pte_alloc_one(dst_mm, addr);
+		if (unlikely(!pgtable))
+			goto out;
+	}
 
 	dst_ptl = pmd_lock(dst_mm, dst_pmd);
 	src_ptl = pmd_lockptr(src_mm, src_pmd);
@@ -1073,7 +1076,7 @@ int copy_huge_pmd(struct mm_struct *dst_mm, struct mm_struct *src_mm,
 		goto out_unlock;
 	}
 
-	if (pmd_trans_huge(pmd)) {
+	if (!vma_is_dax(vma)) {
 		/* thp accounting separate from pmd_devmap accounting */
 		src_page = pmd_page(pmd);
 		VM_BUG_ON_PAGE(!PageHead(src_page), src_page);