diff --git a/arch/x86/include/asm/fpu/internal.h b/arch/x86/include/asm/fpu/internal.h
index 255645f60ca2b4be333de67a10c6780f364e6fb2..554cdb205d17586887e6b599c991b2fc090ba38a 100644
--- a/arch/x86/include/asm/fpu/internal.h
+++ b/arch/x86/include/asm/fpu/internal.h
@@ -450,10 +450,10 @@ static inline int copy_fpregs_to_fpstate(struct fpu *fpu)
 	return 0;
 }
 
-static inline void __copy_kernel_to_fpregs(union fpregs_state *fpstate)
+static inline void __copy_kernel_to_fpregs(union fpregs_state *fpstate, u64 mask)
 {
 	if (use_xsave()) {
-		copy_kernel_to_xregs(&fpstate->xsave, -1);
+		copy_kernel_to_xregs(&fpstate->xsave, mask);
 	} else {
 		if (use_fxsr())
 			copy_kernel_to_fxregs(&fpstate->fxsave);
@@ -477,7 +477,7 @@ static inline void copy_kernel_to_fpregs(union fpregs_state *fpstate)
 			: : [addr] "m" (fpstate));
 	}
 
-	__copy_kernel_to_fpregs(fpstate);
+	__copy_kernel_to_fpregs(fpstate, -1);
 }
 
 extern int copy_fpstate_to_sigframe(void __user *buf, void __user *fp, int size);
diff --git a/arch/x86/kvm/x86.c b/arch/x86/kvm/x86.c
index d734aa8c5b4f7290e365badd00ea962fd0af9acd..05a5e57c6f39770e81e2af3f760f9ce44783f9b3 100644
--- a/arch/x86/kvm/x86.c
+++ b/arch/x86/kvm/x86.c
@@ -3245,7 +3245,12 @@ static void fill_xsave(u8 *dest, struct kvm_vcpu *vcpu)
 			u32 size, offset, ecx, edx;
 			cpuid_count(XSTATE_CPUID, index,
 				    &size, &offset, &ecx, &edx);
-			memcpy(dest + offset, src, size);
+			if (feature == XFEATURE_MASK_PKRU)
+				memcpy(dest + offset, &vcpu->arch.pkru,
+				       sizeof(vcpu->arch.pkru));
+			else
+				memcpy(dest + offset, src, size);
+
 		}
 
 		valid -= feature;
@@ -3283,7 +3288,11 @@ static void load_xsave(struct kvm_vcpu *vcpu, u8 *src)
 			u32 size, offset, ecx, edx;
 			cpuid_count(XSTATE_CPUID, index,
 				    &size, &offset, &ecx, &edx);
-			memcpy(dest, src + offset, size);
+			if (feature == XFEATURE_MASK_PKRU)
+				memcpy(&vcpu->arch.pkru, src + offset,
+				       sizeof(vcpu->arch.pkru));
+			else
+				memcpy(dest, src + offset, size);
 		}
 
 		valid -= feature;
@@ -7633,7 +7642,9 @@ void kvm_load_guest_fpu(struct kvm_vcpu *vcpu)
 	 */
 	vcpu->guest_fpu_loaded = 1;
 	__kernel_fpu_begin();
-	__copy_kernel_to_fpregs(&vcpu->arch.guest_fpu.state);
+	/* PKRU is separately restored in kvm_x86_ops->run.  */
+	__copy_kernel_to_fpregs(&vcpu->arch.guest_fpu.state,
+				~XFEATURE_MASK_PKRU);
 	trace_kvm_fpu(1);
 }