diff --git a/include/linux/mm.h b/include/linux/mm.h
index ad06d42adb1a2602094f894b69ed182cbac978fe..ae806dbc63eef412b3999b42e57fe19b8bd2e2bf 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -287,6 +287,12 @@ extern unsigned int kobjsize(const void *objp);
 /* This mask is used to clear all the VMA flags used by mlock */
 #define VM_LOCKED_CLEAR_MASK	(~(VM_LOCKED | VM_LOCKONFAULT))
 
+/* Arch-specific flags to clear when updating VM flags on protection change */
+#ifndef VM_ARCH_CLEAR
+# define VM_ARCH_CLEAR	VM_NONE
+#endif
+#define VM_FLAGS_CLEAR	(ARCH_VM_PKEY_FLAGS | VM_ARCH_CLEAR)
+
 /*
  * mapping from the currently active vm_flags protection bits (the
  * low four bits) to a page protection mask..
diff --git a/mm/mprotect.c b/mm/mprotect.c
index 088ea9c08678677fbd7d2318c22e67129b80e25a..c1d6af7455da542b9462cfc5ef8d64fa1d386a4f 100644
--- a/mm/mprotect.c
+++ b/mm/mprotect.c
@@ -475,7 +475,7 @@ static int do_mprotect_pkey(unsigned long start, size_t len,
 		 * cleared from the VMA.
 		 */
 		mask_off_old_flags = VM_READ | VM_WRITE | VM_EXEC |
-					ARCH_VM_PKEY_FLAGS;
+					VM_FLAGS_CLEAR;
 
 		new_vma_pkey = arch_override_mprotect_pkey(vma, prot, pkey);
 		newflags = calc_vm_prot_bits(prot, new_vma_pkey);