diff --git a/kernel/bpf/core.c b/kernel/bpf/core.c
index aa1e64196d8d2ffb1ab92a55a78dfee6e8af4bf5..3a283bf97f2f674927d17da17efec2f5906c6580 100644
--- a/kernel/bpf/core.c
+++ b/kernel/bpf/core.c
@@ -2344,6 +2344,10 @@ bool __weak bpf_helper_changes_pkt_data(void *func)
 /* Return TRUE if the JIT backend wants verifier to enable sub-register usage
  * analysis code and wants explicit zero extension inserted by verifier.
  * Otherwise, return FALSE.
+ *
+ * The verifier inserts an explicit zero extension after BPF_CMPXCHGs even if
+ * you don't override this. JITs that don't want these extra insns can detect
+ * them using insn_is_zext.
  */
 bool __weak bpf_jit_needs_zext(void)
 {
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index bb3eaab934f3975afbd5f3004455bdcf11082cbd..c56e3fcb5f1a07d2e0df00d82f01c4e6a1b12ff9 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -504,6 +504,13 @@ static bool is_ptr_cast_function(enum bpf_func_id func_id)
 		func_id == BPF_FUNC_skc_to_tcp_request_sock;
 }
 
+static bool is_cmpxchg_insn(const struct bpf_insn *insn)
+{
+	return BPF_CLASS(insn->code) == BPF_STX &&
+	       BPF_MODE(insn->code) == BPF_ATOMIC &&
+	       insn->imm == BPF_CMPXCHG;
+}
+
 /* string representation of 'enum bpf_reg_type' */
 static const char * const reg_type_str[] = {
 	[NOT_INIT]		= "?",
@@ -11067,7 +11074,17 @@ static int opt_subreg_zext_lo32_rnd_hi32(struct bpf_verifier_env *env,
 			goto apply_patch_buffer;
 		}
 
-		if (!bpf_jit_needs_zext())
+		/* Add in an zero-extend instruction if a) the JIT has requested
+		 * it or b) it's a CMPXCHG.
+		 *
+		 * The latter is because: BPF_CMPXCHG always loads a value into
+		 * R0, therefore always zero-extends. However some archs'
+		 * equivalent instruction only does this load when the
+		 * comparison is successful. This detail of CMPXCHG is
+		 * orthogonal to the general zero-extension behaviour of the
+		 * CPU, so it's treated independently of bpf_jit_needs_zext.
+		 */
+		if (!bpf_jit_needs_zext() && !is_cmpxchg_insn(&insn))
 			continue;
 
 		if (WARN_ON(load_reg == -1)) {
diff --git a/tools/testing/selftests/bpf/verifier/atomic_cmpxchg.c b/tools/testing/selftests/bpf/verifier/atomic_cmpxchg.c
index 2efd8bcf57a1e4c514afa4aad30b5b95422da30b..6e52dfc644153f08af2587ed27eda74fa66bfa5c 100644
--- a/tools/testing/selftests/bpf/verifier/atomic_cmpxchg.c
+++ b/tools/testing/selftests/bpf/verifier/atomic_cmpxchg.c
@@ -94,3 +94,28 @@
 	.result = REJECT,
 	.errstr = "invalid read from stack",
 },
+{
+	"BPF_W cmpxchg should zero top 32 bits",
+	.insns = {
+		/* r0 = U64_MAX; */
+		BPF_MOV64_IMM(BPF_REG_0, 0),
+		BPF_ALU64_IMM(BPF_SUB, BPF_REG_0, 1),
+		/* u64 val = r0; */
+		BPF_STX_MEM(BPF_DW, BPF_REG_10, BPF_REG_0, -8),
+		/* r0 = (u32)atomic_cmpxchg((u32 *)&val, r0, 1); */
+		BPF_MOV32_IMM(BPF_REG_1, 1),
+		BPF_ATOMIC_OP(BPF_W, BPF_CMPXCHG, BPF_REG_10, BPF_REG_1, -8),
+		/* r1 = 0x00000000FFFFFFFFull; */
+		BPF_MOV64_IMM(BPF_REG_1, 1),
+		BPF_ALU64_IMM(BPF_LSH, BPF_REG_1, 32),
+		BPF_ALU64_IMM(BPF_SUB, BPF_REG_1, 1),
+		/* if (r0 != r1) exit(1); */
+		BPF_JMP_REG(BPF_JEQ, BPF_REG_0, BPF_REG_1, 2),
+		BPF_MOV32_IMM(BPF_REG_0, 1),
+		BPF_EXIT_INSN(),
+		/* exit(0); */
+		BPF_MOV32_IMM(BPF_REG_0, 0),
+		BPF_EXIT_INSN(),
+	},
+	.result = ACCEPT,
+},
diff --git a/tools/testing/selftests/bpf/verifier/atomic_or.c b/tools/testing/selftests/bpf/verifier/atomic_or.c
index 70f982e1f9f05fc41c9da5f887d26f05df6e742b..9d0716ac508080a86db2109a96cbfd9ac2707055 100644
--- a/tools/testing/selftests/bpf/verifier/atomic_or.c
+++ b/tools/testing/selftests/bpf/verifier/atomic_or.c
@@ -75,3 +75,28 @@
 	},
 	.result = ACCEPT,
 },
+{
+	"BPF_W atomic_fetch_or should zero top 32 bits",
+	.insns = {
+		/* r1 = U64_MAX; */
+		BPF_MOV64_IMM(BPF_REG_1, 0),
+		BPF_ALU64_IMM(BPF_SUB, BPF_REG_1, 1),
+		/* u64 val = r1; */
+		BPF_STX_MEM(BPF_DW, BPF_REG_10, BPF_REG_1, -8),
+		/* r1 = (u32)atomic_fetch_or((u32 *)&val, 2); */
+		BPF_MOV32_IMM(BPF_REG_1, 2),
+		BPF_ATOMIC_OP(BPF_W, BPF_OR | BPF_FETCH, BPF_REG_10, BPF_REG_1, -8),
+		/* r2 = 0x00000000FFFFFFFF; */
+		BPF_MOV64_IMM(BPF_REG_2, 1),
+		BPF_ALU64_IMM(BPF_LSH, BPF_REG_2, 32),
+		BPF_ALU64_IMM(BPF_SUB, BPF_REG_2, 1),
+		/* if (r2 != r1) exit(1); */
+		BPF_JMP_REG(BPF_JEQ, BPF_REG_2, BPF_REG_1, 2),
+		BPF_MOV64_REG(BPF_REG_0, BPF_REG_1),
+		BPF_EXIT_INSN(),
+		/* exit(0); */
+		BPF_MOV32_IMM(BPF_REG_0, 0),
+		BPF_EXIT_INSN(),
+	},
+	.result = ACCEPT,
+},