diff --git a/arch/x86/kvm/mmu.c b/arch/x86/kvm/mmu.c
index bac8d228d82bb2dc0d40d3f6a911182212b2a41d..24c23c66b2263c5b20b86119c192711006e7e189 100644
--- a/arch/x86/kvm/mmu.c
+++ b/arch/x86/kvm/mmu.c
@@ -92,6 +92,7 @@ module_param(dbg, bool, 0644);
 #define SPTE_SPECIAL_MASK (3ULL << 52)
 #define SPTE_AD_ENABLED_MASK (0ULL << 52)
 #define SPTE_AD_DISABLED_MASK (1ULL << 52)
+#define SPTE_AD_WRPROT_ONLY_MASK (2ULL << 52)
 #define SPTE_MMIO_MASK (3ULL << 52)
 
 #define PT64_LEVEL_BITS 9
@@ -328,10 +329,27 @@ static inline bool sp_ad_disabled(struct kvm_mmu_page *sp)
 	return sp->role.ad_disabled;
 }
 
+static inline bool kvm_vcpu_ad_need_write_protect(struct kvm_vcpu *vcpu)
+{
+	/*
+	 * When using the EPT page-modification log, the GPAs in the log
+	 * would come from L2 rather than L1.  Therefore, we need to rely
+	 * on write protection to record dirty pages.  This also bypasses
+	 * PML, since writes now result in a vmexit.
+	 */
+	return vcpu->arch.mmu == &vcpu->arch.guest_mmu;
+}
+
 static inline bool spte_ad_enabled(u64 spte)
 {
 	MMU_WARN_ON(is_mmio_spte(spte));
-	return (spte & SPTE_SPECIAL_MASK) == SPTE_AD_ENABLED_MASK;
+	return (spte & SPTE_SPECIAL_MASK) != SPTE_AD_DISABLED_MASK;
+}
+
+static inline bool spte_ad_need_write_protect(u64 spte)
+{
+	MMU_WARN_ON(is_mmio_spte(spte));
+	return (spte & SPTE_SPECIAL_MASK) != SPTE_AD_ENABLED_MASK;
 }
 
 static inline u64 spte_shadow_accessed_mask(u64 spte)
@@ -1597,16 +1615,16 @@ static bool spte_clear_dirty(u64 *sptep)
 
 	rmap_printk("rmap_clear_dirty: spte %p %llx\n", sptep, *sptep);
 
+	MMU_WARN_ON(!spte_ad_enabled(spte));
 	spte &= ~shadow_dirty_mask;
-
 	return mmu_spte_update(sptep, spte);
 }
 
-static bool wrprot_ad_disabled_spte(u64 *sptep)
+static bool spte_wrprot_for_clear_dirty(u64 *sptep)
 {
 	bool was_writable = test_and_clear_bit(PT_WRITABLE_SHIFT,
 					       (unsigned long *)sptep);
-	if (was_writable)
+	if (was_writable && !spte_ad_enabled(*sptep))
 		kvm_set_pfn_dirty(spte_to_pfn(*sptep));
 
 	return was_writable;
@@ -1625,10 +1643,10 @@ static bool __rmap_clear_dirty(struct kvm *kvm, struct kvm_rmap_head *rmap_head)
 	bool flush = false;
 
 	for_each_rmap_spte(rmap_head, &iter, sptep)
-		if (spte_ad_enabled(*sptep))
-			flush |= spte_clear_dirty(sptep);
+		if (spte_ad_need_write_protect(*sptep))
+			flush |= spte_wrprot_for_clear_dirty(sptep);
 		else
-			flush |= wrprot_ad_disabled_spte(sptep);
+			flush |= spte_clear_dirty(sptep);
 
 	return flush;
 }
@@ -1639,6 +1657,11 @@ static bool spte_set_dirty(u64 *sptep)
 
 	rmap_printk("rmap_set_dirty: spte %p %llx\n", sptep, *sptep);
 
+	/*
+	 * Similar to the !kvm_x86_ops->slot_disable_log_dirty case,
+	 * do not bother adding back write access to pages marked
+	 * SPTE_AD_WRPROT_ONLY_MASK.
+	 */
 	spte |= shadow_dirty_mask;
 
 	return mmu_spte_update(sptep, spte);
@@ -2977,6 +3000,8 @@ static int set_spte(struct kvm_vcpu *vcpu, u64 *sptep,
 	sp = page_header(__pa(sptep));
 	if (sp_ad_disabled(sp))
 		spte |= SPTE_AD_DISABLED_MASK;
+	else if (kvm_vcpu_ad_need_write_protect(vcpu))
+		spte |= SPTE_AD_WRPROT_ONLY_MASK;
 
 	/*
 	 * For the EPT case, shadow_present_mask is 0 if hardware