diff --git a/fs/bcachefs/alloc_background.c b/fs/bcachefs/alloc_background.c
index 614da759226b553fb3bdf15769d700b9e6d24293..10704f2d3af5302f71a931e13bd0ba5432d46fe2 100644
--- a/fs/bcachefs/alloc_background.c
+++ b/fs/bcachefs/alloc_background.c
@@ -1604,13 +1604,36 @@ int bch2_check_alloc_to_lru_refs(struct bch_fs *c)
 	return ret;
 }
 
+struct discard_buckets_state {
+	u64		seen;
+	u64		open;
+	u64		need_journal_commit;
+	u64		discarded;
+	struct bch_dev	*ca;
+	u64		need_journal_commit_this_dev;
+};
+
+static void discard_buckets_next_dev(struct bch_fs *c, struct discard_buckets_state *s, struct bch_dev *ca)
+{
+	if (s->ca == ca)
+		return;
+
+	if (s->ca && s->need_journal_commit_this_dev >
+	    bch2_dev_usage_read(s->ca).d[BCH_DATA_free].buckets)
+		bch2_journal_flush_async(&c->journal, NULL);
+
+	if (s->ca)
+		percpu_ref_put(&s->ca->ref);
+	if (ca)
+		percpu_ref_get(&ca->ref);
+	s->ca = ca;
+	s->need_journal_commit_this_dev = 0;
+}
+
 static int bch2_discard_one_bucket(struct btree_trans *trans,
 				   struct btree_iter *need_discard_iter,
 				   struct bpos *discard_pos_done,
-				   u64 *seen,
-				   u64 *open,
-				   u64 *need_journal_commit,
-				   u64 *discarded)
+				   struct discard_buckets_state *s)
 {
 	struct bch_fs *c = trans->c;
 	struct bpos pos = need_discard_iter->pos;
@@ -1622,20 +1645,24 @@ static int bch2_discard_one_bucket(struct btree_trans *trans,
 	int ret = 0;
 
 	ca = bch_dev_bkey_exists(c, pos.inode);
+
 	if (!percpu_ref_tryget(&ca->io_ref)) {
 		bch2_btree_iter_set_pos(need_discard_iter, POS(pos.inode + 1, 0));
 		return 0;
 	}
 
+	discard_buckets_next_dev(c, s, ca);
+
 	if (bch2_bucket_is_open_safe(c, pos.inode, pos.offset)) {
-		(*open)++;
+		s->open++;
 		goto out;
 	}
 
 	if (bch2_bucket_needs_journal_commit(&c->buckets_waiting_for_journal,
 			c->journal.flushed_seq_ondisk,
 			pos.inode, pos.offset)) {
-		(*need_journal_commit)++;
+		s->need_journal_commit++;
+		s->need_journal_commit_this_dev++;
 		goto out;
 	}
 
@@ -1711,9 +1738,9 @@ static int bch2_discard_one_bucket(struct btree_trans *trans,
 		goto out;
 
 	count_event(c, bucket_discard);
-	(*discarded)++;
+	s->discarded++;
 out:
-	(*seen)++;
+	s->seen++;
 	bch2_trans_iter_exit(trans, &iter);
 	percpu_ref_put(&ca->io_ref);
 	printbuf_exit(&buf);
@@ -1723,7 +1750,7 @@ static int bch2_discard_one_bucket(struct btree_trans *trans,
 static void bch2_do_discards_work(struct work_struct *work)
 {
 	struct bch_fs *c = container_of(work, struct bch_fs, discard_work);
-	u64 seen = 0, open = 0, need_journal_commit = 0, discarded = 0;
+	struct discard_buckets_state s = {};
 	struct bpos discard_pos_done = POS_MAX;
 	int ret;
 
@@ -1735,19 +1762,14 @@ static void bch2_do_discards_work(struct work_struct *work)
 	ret = bch2_trans_run(c,
 		for_each_btree_key(trans, iter,
 				   BTREE_ID_need_discard, POS_MIN, 0, k,
-			bch2_discard_one_bucket(trans, &iter, &discard_pos_done,
-						&seen,
-						&open,
-						&need_journal_commit,
-						&discarded)));
-
-	if (need_journal_commit * 2 > seen)
-		bch2_journal_flush_async(&c->journal, NULL);
+			bch2_discard_one_bucket(trans, &iter, &discard_pos_done, &s)));
 
-	bch2_write_ref_put(c, BCH_WRITE_REF_discard);
+	discard_buckets_next_dev(c, &s, NULL);
 
-	trace_discard_buckets(c, seen, open, need_journal_commit, discarded,
+	trace_discard_buckets(c, s.seen, s.open, s.need_journal_commit, s.discarded,
 			      bch2_err_str(ret));
+
+	bch2_write_ref_put(c, BCH_WRITE_REF_discard);
 }
 
 void bch2_do_discards(struct bch_fs *c)