diff --git a/arch/x86/include/asm/fpu/internal.h b/arch/x86/include/asm/fpu/internal.h
index 3e0c2c496f2d605d6a6e32eb8bf9beafa05a2138..6eb4a0b1ad0e5c9c1d082c07d0120e9c401e6af3 100644
--- a/arch/x86/include/asm/fpu/internal.h
+++ b/arch/x86/include/asm/fpu/internal.h
@@ -14,6 +14,7 @@
 #include <linux/compat.h>
 #include <linux/sched.h>
 #include <linux/slab.h>
+#include <linux/mm.h>
 
 #include <asm/user.h>
 #include <asm/fpu/api.h>
@@ -534,8 +535,27 @@ switch_fpu_prepare(struct fpu *old_fpu, int cpu)
  */
 static inline void switch_fpu_finish(struct fpu *new_fpu, int cpu)
 {
-	if (static_cpu_has(X86_FEATURE_FPU))
-		__fpregs_load_activate(new_fpu, cpu);
+	u32 pkru_val = init_pkru_value;
+	struct pkru_state *pk;
+
+	if (!static_cpu_has(X86_FEATURE_FPU))
+		return;
+
+	__fpregs_load_activate(new_fpu, cpu);
+
+	if (!cpu_feature_enabled(X86_FEATURE_OSPKE))
+		return;
+
+	/*
+	 * PKRU state is switched eagerly because it needs to be valid before we
+	 * return to userland e.g. for a copy_to_user() operation.
+	 */
+	if (current->mm) {
+		pk = get_xsave_addr(&new_fpu->state.xsave, XFEATURE_PKRU);
+		if (pk)
+			pkru_val = pk->pkru;
+	}
+	__write_pkru(pkru_val);
 }
 
 /*
diff --git a/arch/x86/include/asm/fpu/xstate.h b/arch/x86/include/asm/fpu/xstate.h
index fbe41f808e5d827baaaf78d36ada77424336c2e6..7e42b285c8562009510abd84d0e071e7ff084030 100644
--- a/arch/x86/include/asm/fpu/xstate.h
+++ b/arch/x86/include/asm/fpu/xstate.h
@@ -2,9 +2,11 @@
 #ifndef __ASM_X86_XSAVE_H
 #define __ASM_X86_XSAVE_H
 
+#include <linux/uaccess.h>
 #include <linux/types.h>
+
 #include <asm/processor.h>
-#include <linux/uaccess.h>
+#include <asm/user.h>
 
 /* Bit 63 of XCR0 is reserved for future expansion */
 #define XFEATURE_MASK_EXTEND	(~(XFEATURE_MASK_FPSSE | (1ULL << 63)))
diff --git a/arch/x86/include/asm/pgtable.h b/arch/x86/include/asm/pgtable.h
index e8875ca7562360ced86c7990431b160224eff658..9beb371b1adf5359d73c5369283d41c896aeacb7 100644
--- a/arch/x86/include/asm/pgtable.h
+++ b/arch/x86/include/asm/pgtable.h
@@ -1355,6 +1355,12 @@ static inline pmd_t pmd_swp_clear_soft_dirty(pmd_t pmd)
 #define PKRU_WD_BIT 0x2
 #define PKRU_BITS_PER_PKEY 2
 
+#ifdef CONFIG_X86_INTEL_MEMORY_PROTECTION_KEYS
+extern u32 init_pkru_value;
+#else
+#define init_pkru_value	0
+#endif
+
 static inline bool __pkru_allows_read(u32 pkru, u16 pkey)
 {
 	int pkru_pkey_bits = pkey * PKRU_BITS_PER_PKEY;
diff --git a/arch/x86/mm/pkeys.c b/arch/x86/mm/pkeys.c
index 50f65fc1b9a3f5ce296a1a2730329faaaf626c75..2ecbf4155f98fc434b029cb5feae58569229a138 100644
--- a/arch/x86/mm/pkeys.c
+++ b/arch/x86/mm/pkeys.c
@@ -126,7 +126,6 @@ int __arch_override_mprotect_pkey(struct vm_area_struct *vma, int prot, int pkey
  * in the process's lifetime will not accidentally get access
  * to data which is pkey-protected later on.
  */
-static
 u32 init_pkru_value = PKRU_AD_KEY( 1) | PKRU_AD_KEY( 2) | PKRU_AD_KEY( 3) |
 		      PKRU_AD_KEY( 4) | PKRU_AD_KEY( 5) | PKRU_AD_KEY( 6) |
 		      PKRU_AD_KEY( 7) | PKRU_AD_KEY( 8) | PKRU_AD_KEY( 9) |