diff --git a/block/blk-wbt.c b/block/blk-wbt.c
index 4575b4650370131e795cc05678f851cc4b7ff574..bfb0d21d19ce3799550764b19d68a9f66df7328d 100644
--- a/block/blk-wbt.c
+++ b/block/blk-wbt.c
@@ -161,7 +161,7 @@ static void wbt_rqw_done(struct rq_wb *rwb, struct rq_wait *rqw,
 		int diff = limit - inflight;
 
 		if (!inflight || diff >= rwb->wb_background / 2)
-			wake_up(&rqw->wait);
+			wake_up_all(&rqw->wait);
 	}
 }
 
@@ -488,6 +488,34 @@ static inline unsigned int get_limit(struct rq_wb *rwb, unsigned long rw)
 	return limit;
 }
 
+struct wbt_wait_data {
+	struct wait_queue_entry wq;
+	struct task_struct *task;
+	struct rq_wb *rwb;
+	struct rq_wait *rqw;
+	unsigned long rw;
+	bool got_token;
+};
+
+static int wbt_wake_function(struct wait_queue_entry *curr, unsigned int mode,
+			     int wake_flags, void *key)
+{
+	struct wbt_wait_data *data = container_of(curr, struct wbt_wait_data,
+							wq);
+
+	/*
+	 * If we fail to get a budget, return -1 to interrupt the wake up
+	 * loop in __wake_up_common.
+	 */
+	if (!rq_wait_inc_below(data->rqw, get_limit(data->rwb, data->rw)))
+		return -1;
+
+	data->got_token = true;
+	list_del_init(&curr->entry);
+	wake_up_process(data->task);
+	return 1;
+}
+
 /*
  * Block if we will exceed our limit, or if we are currently waiting for
  * the timer to kick off queuing again.
@@ -498,19 +526,40 @@ static void __wbt_wait(struct rq_wb *rwb, enum wbt_flags wb_acct,
 	__acquires(lock)
 {
 	struct rq_wait *rqw = get_rq_wait(rwb, wb_acct);
-	DECLARE_WAITQUEUE(wait, current);
+	struct wbt_wait_data data = {
+		.wq = {
+			.func	= wbt_wake_function,
+			.entry	= LIST_HEAD_INIT(data.wq.entry),
+		},
+		.task = current,
+		.rwb = rwb,
+		.rqw = rqw,
+		.rw = rw,
+	};
 	bool has_sleeper;
 
 	has_sleeper = wq_has_sleeper(&rqw->wait);
 	if (!has_sleeper && rq_wait_inc_below(rqw, get_limit(rwb, rw)))
 		return;
 
-	add_wait_queue_exclusive(&rqw->wait, &wait);
+	prepare_to_wait_exclusive(&rqw->wait, &data.wq, TASK_UNINTERRUPTIBLE);
 	do {
-		set_current_state(TASK_UNINTERRUPTIBLE);
+		if (data.got_token)
+			break;
 
-		if (!has_sleeper && rq_wait_inc_below(rqw, get_limit(rwb, rw)))
+		if (!has_sleeper &&
+		    rq_wait_inc_below(rqw, get_limit(rwb, rw))) {
+			finish_wait(&rqw->wait, &data.wq);
+
+			/*
+			 * We raced with wbt_wake_function() getting a token,
+			 * which means we now have two. Put our local token
+			 * and wake anyone else potentially waiting for one.
+			 */
+			if (data.got_token)
+				wbt_rqw_done(rwb, rqw, wb_acct);
 			break;
+		}
 
 		if (lock) {
 			spin_unlock_irq(lock);
@@ -518,11 +567,11 @@ static void __wbt_wait(struct rq_wb *rwb, enum wbt_flags wb_acct,
 			spin_lock_irq(lock);
 		} else
 			io_schedule();
+
 		has_sleeper = false;
 	} while (1);
 
-	__set_current_state(TASK_RUNNING);
-	remove_wait_queue(&rqw->wait, &wait);
+	finish_wait(&rqw->wait, &data.wq);
 }
 
 static inline bool wbt_should_throttle(struct rq_wb *rwb, struct bio *bio)