diff --git a/io_uring/io_uring.c b/io_uring/io_uring.c
index 6cc16e39b27f001018abcc43b4ade8107e42d893..ac8c488e3077d9e1fb88cb9fd91063a4dc1c4b52 100644
--- a/io_uring/io_uring.c
+++ b/io_uring/io_uring.c
@@ -1173,7 +1173,7 @@ static void __cold io_move_task_work_from_local(struct io_ring_ctx *ctx)
 	}
 }
 
-int __io_run_local_work(struct io_ring_ctx *ctx, bool locked)
+int __io_run_local_work(struct io_ring_ctx *ctx, bool *locked)
 {
 	struct llist_node *node;
 	struct llist_node fake;
@@ -1192,7 +1192,7 @@ int __io_run_local_work(struct io_ring_ctx *ctx, bool locked)
 		struct io_kiocb *req = container_of(node, struct io_kiocb,
 						    io_task_work.node);
 		prefetch(container_of(next, struct io_kiocb, io_task_work.node));
-		req->io_task_work.func(req, &locked);
+		req->io_task_work.func(req, locked);
 		ret++;
 		node = next;
 	}
@@ -1208,7 +1208,7 @@ int __io_run_local_work(struct io_ring_ctx *ctx, bool locked)
 		goto again;
 	}
 
-	if (locked)
+	if (*locked)
 		io_submit_flush_completions(ctx);
 	trace_io_uring_local_work_run(ctx, ret, loops);
 	return ret;
@@ -1225,7 +1225,7 @@ int io_run_local_work(struct io_ring_ctx *ctx)
 
 	__set_current_state(TASK_RUNNING);
 	locked = mutex_trylock(&ctx->uring_lock);
-	ret = __io_run_local_work(ctx, locked);
+	ret = __io_run_local_work(ctx, &locked);
 	if (locked)
 		mutex_unlock(&ctx->uring_lock);
 
@@ -1446,8 +1446,7 @@ static int io_iopoll_check(struct io_ring_ctx *ctx, long min)
 		    io_task_work_pending(ctx)) {
 			u32 tail = ctx->cached_cq_tail;
 
-			if (!llist_empty(&ctx->work_llist))
-				__io_run_local_work(ctx, true);
+			(void) io_run_local_work_locked(ctx);
 
 			if (task_work_pending(current) ||
 			    wq_list_empty(&ctx->iopoll_list)) {
diff --git a/io_uring/io_uring.h b/io_uring/io_uring.h
index ef77d2aa3172ca470a7364af92abf57232ddbc8b..e99a79f2df9b18365559d6623133f77fdb60e3a6 100644
--- a/io_uring/io_uring.h
+++ b/io_uring/io_uring.h
@@ -27,7 +27,7 @@ enum {
 struct io_uring_cqe *__io_get_cqe(struct io_ring_ctx *ctx, bool overflow);
 bool io_req_cqe_overflow(struct io_kiocb *req);
 int io_run_task_work_sig(struct io_ring_ctx *ctx);
-int __io_run_local_work(struct io_ring_ctx *ctx, bool locked);
+int __io_run_local_work(struct io_ring_ctx *ctx, bool *locked);
 int io_run_local_work(struct io_ring_ctx *ctx);
 void io_req_complete_failed(struct io_kiocb *req, s32 res);
 void __io_req_complete(struct io_kiocb *req, unsigned issue_flags);
@@ -277,9 +277,18 @@ static inline int io_run_task_work_ctx(struct io_ring_ctx *ctx)
 
 static inline int io_run_local_work_locked(struct io_ring_ctx *ctx)
 {
+	bool locked;
+	int ret;
+
 	if (llist_empty(&ctx->work_llist))
 		return 0;
-	return __io_run_local_work(ctx, true);
+
+	locked = true;
+	ret = __io_run_local_work(ctx, &locked);
+	/* shouldn't happen! */
+	if (WARN_ON_ONCE(!locked))
+		mutex_lock(&ctx->uring_lock);
+	return ret;
 }
 
 static inline void io_tw_lock(struct io_ring_ctx *ctx, bool *locked)
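
For context on the contract change above: `req->io_task_work.func(req, locked)` callbacks may drop and retake `ctx->uring_lock` while they run, so passing `locked` by value meant `__io_run_local_work()` and its callers could lose track of the real lock state. Making it `bool *locked` turns the flag into an in/out parameter, and `io_run_local_work_locked()` restores its "returns with the lock held" invariant by re-acquiring the mutex under a `WARN_ON_ONCE`, since callbacks are not expected to leave it unlocked on that path.

Below is a minimal userspace sketch of the same in/out lock-state pattern. It uses pthreads and hypothetical names (`run_work`, `lock`), not the kernel APIs, and is only meant to illustrate why the flag must travel by pointer:

#include <stdbool.h>
#include <stdio.h>
#include <pthread.h>

/*
 * Toy model of the pattern the patch introduces: the worker takes a
 * pointer to the caller's lock state so it can drop the lock while
 * processing, and the caller can see the final state afterwards.
 */
static pthread_mutex_t lock = PTHREAD_MUTEX_INITIALIZER;

static int run_work(bool *locked)
{
	int ret = 0;

	for (int i = 0; i < 3; i++) {
		if (*locked && i == 1) {
			/* e.g. one work item needed to drop the lock */
			pthread_mutex_unlock(&lock);
			*locked = false;
		}
		ret++;
	}
	return ret;
}

int main(void)
{
	bool locked = (pthread_mutex_trylock(&lock) == 0);
	int ret = run_work(&locked);

	/* restore the caller's invariant if the worker dropped the lock */
	if (!locked)
		pthread_mutex_lock(&lock);
	pthread_mutex_unlock(&lock);

	printf("processed %d items\n", ret);
	return 0;
}

Had `run_work()` taken `bool locked` by value, the unlock inside the loop would leave `main()` believing it still held the mutex, and the final `pthread_mutex_unlock()` would be an unlock of a lock it no longer owns; this is exactly the state-tracking hole the patch closes for `ctx->uring_lock`.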