diff --git a/arch/x86/include/asm/kvm_host.h b/arch/x86/include/asm/kvm_host.h
index 87ac4fba6d8e12f07e8a9f191bdb028a1c3e6234..f4d120a3e22e8aee8326bebeedeaa506c2291b47 100644
--- a/arch/x86/include/asm/kvm_host.h
+++ b/arch/x86/include/asm/kvm_host.h
@@ -492,6 +492,7 @@ struct kvm_vcpu_arch {
 	unsigned long cr4;
 	unsigned long cr4_guest_owned_bits;
 	unsigned long cr8;
+	u32 pkru;
 	u32 hflags;
 	u64 efer;
 	u64 apic_base;
diff --git a/arch/x86/kvm/kvm_cache_regs.h b/arch/x86/kvm/kvm_cache_regs.h
index 762cdf2595f992fd4ac8bb1e4c2c8914b344db04..e1e89ee4af750dc51f78cfbf8aa71e22d77b4cd1 100644
--- a/arch/x86/kvm/kvm_cache_regs.h
+++ b/arch/x86/kvm/kvm_cache_regs.h
@@ -84,11 +84,6 @@ static inline u64 kvm_read_edx_eax(struct kvm_vcpu *vcpu)
 		| ((u64)(kvm_register_read(vcpu, VCPU_REGS_RDX) & -1u) << 32);
 }
 
-static inline u32 kvm_read_pkru(struct kvm_vcpu *vcpu)
-{
-	return kvm_x86_ops->get_pkru(vcpu);
-}
-
 static inline void enter_guest_mode(struct kvm_vcpu *vcpu)
 {
 	vcpu->arch.hflags |= HF_GUEST_MASK;
diff --git a/arch/x86/kvm/mmu.h b/arch/x86/kvm/mmu.h
index d7d248a000dd6772681f3f5541e344f9677a2d1d..4b9a3ae6b725d37bdeeb5aeb24c0a9c2716228f4 100644
--- a/arch/x86/kvm/mmu.h
+++ b/arch/x86/kvm/mmu.h
@@ -185,7 +185,7 @@ static inline u8 permission_fault(struct kvm_vcpu *vcpu, struct kvm_mmu *mmu,
 		* index of the protection domain, so pte_pkey * 2 is
 		* is the index of the first bit for the domain.
 		*/
-		pkru_bits = (kvm_read_pkru(vcpu) >> (pte_pkey * 2)) & 3;
+		pkru_bits = (vcpu->arch.pkru >> (pte_pkey * 2)) & 3;
 
 		/* clear present bit, replace PFEC.RSVD with ACC_USER_MASK. */
 		offset = (pfec & ~1) +
diff --git a/arch/x86/kvm/svm.c b/arch/x86/kvm/svm.c
index 56ba05312759d3ed4568e546492d9d8bfad05b71..af256b786a70cccd14906fe926e2b790156967a4 100644
--- a/arch/x86/kvm/svm.c
+++ b/arch/x86/kvm/svm.c
@@ -1777,11 +1777,6 @@ static void svm_set_rflags(struct kvm_vcpu *vcpu, unsigned long rflags)
 	to_svm(vcpu)->vmcb->save.rflags = rflags;
 }
 
-static u32 svm_get_pkru(struct kvm_vcpu *vcpu)
-{
-	return 0;
-}
-
 static void svm_cache_reg(struct kvm_vcpu *vcpu, enum kvm_reg reg)
 {
 	switch (reg) {
@@ -5413,8 +5408,6 @@ static struct kvm_x86_ops svm_x86_ops __ro_after_init = {
 	.get_rflags = svm_get_rflags,
 	.set_rflags = svm_set_rflags,
 
-	.get_pkru = svm_get_pkru,
-
 	.tlb_flush = svm_flush_tlb,
 
 	.run = svm_vcpu_run,
diff --git a/arch/x86/kvm/vmx.c b/arch/x86/kvm/vmx.c
index 9b21b12230354e334900e6536b7612285f75b7e3..c6ef2940119bfdfb00f0547c321697280ebae0fa 100644
--- a/arch/x86/kvm/vmx.c
+++ b/arch/x86/kvm/vmx.c
@@ -636,8 +636,6 @@ struct vcpu_vmx {
 
 	u64 current_tsc_ratio;
 
-	bool guest_pkru_valid;
-	u32 guest_pkru;
 	u32 host_pkru;
 
 	/*
@@ -2383,11 +2381,6 @@ static void vmx_set_rflags(struct kvm_vcpu *vcpu, unsigned long rflags)
 		to_vmx(vcpu)->emulation_required = emulation_required(vcpu);
 }
 
-static u32 vmx_get_pkru(struct kvm_vcpu *vcpu)
-{
-	return to_vmx(vcpu)->guest_pkru;
-}
-
 static u32 vmx_get_interrupt_shadow(struct kvm_vcpu *vcpu)
 {
 	u32 interruptibility = vmcs_read32(GUEST_INTERRUPTIBILITY_INFO);
@@ -9020,8 +9013,10 @@ static void __noclone vmx_vcpu_run(struct kvm_vcpu *vcpu)
 	if (vcpu->guest_debug & KVM_GUESTDBG_SINGLESTEP)
 		vmx_set_interrupt_shadow(vcpu, 0);
 
-	if (vmx->guest_pkru_valid)
-		__write_pkru(vmx->guest_pkru);
+	if (static_cpu_has(X86_FEATURE_PKU) &&
+	    kvm_read_cr4_bits(vcpu, X86_CR4_PKE) &&
+	    vcpu->arch.pkru != vmx->host_pkru)
+		__write_pkru(vcpu->arch.pkru);
 
 	atomic_switch_perf_msrs(vmx);
 	debugctlmsr = get_debugctlmsr();
@@ -9169,13 +9164,11 @@ static void __noclone vmx_vcpu_run(struct kvm_vcpu *vcpu)
 	 * back on host, so it is safe to read guest PKRU from current
 	 * XSAVE.
 	 */
-	if (boot_cpu_has(X86_FEATURE_OSPKE)) {
-		vmx->guest_pkru = __read_pkru();
-		if (vmx->guest_pkru != vmx->host_pkru) {
-			vmx->guest_pkru_valid = true;
+	if (static_cpu_has(X86_FEATURE_PKU) &&
+	    kvm_read_cr4_bits(vcpu, X86_CR4_PKE)) {
+		vcpu->arch.pkru = __read_pkru();
+		if (vcpu->arch.pkru != vmx->host_pkru)
 			__write_pkru(vmx->host_pkru);
-		} else
-			vmx->guest_pkru_valid = false;
 	}
 
 	/*
@@ -11682,8 +11675,6 @@ static struct kvm_x86_ops vmx_x86_ops __ro_after_init = {
 	.get_rflags = vmx_get_rflags,
 	.set_rflags = vmx_set_rflags,
 
-	.get_pkru = vmx_get_pkru,
-
 	.tlb_flush = vmx_flush_tlb,
 
 	.run = vmx_vcpu_run,