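bcachefs: Fix a use after free in the journal code

Journal reclaim could still be running a pin's flush callback after the
pin's owner had dropped the pin and freed the structure embedding it
(e.g. struct btree_update). Track the pin whose flush callback is
currently running in j->flush_in_progress, only flush pins while holding
j->reclaim_lock, and add bch2_journal_pin_flush(), which waits on
j->pin_flush_wait until an in-progress flush of a given pin has finished.
bch2_btree_update_free() now calls it before freeing the btree_update.

Pins now record their journal sequence number instead of a pointer into
the pin fifo, and bch2_journal_pin_seq() is removed in favour of reading
pin->seq directly.
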
Signed-off-by: Kent Overstreet <kent.overstreet@linux.dev>
diff --git a/fs/bcachefs/btree_update_interior.c b/fs/bcachefs/btree_update_interior.c
index 1710efd..cc1f8b9 100644
--- a/fs/bcachefs/btree_update_interior.c
+++ b/fs/bcachefs/btree_update_interior.c
@@ -579,6 +579,8 @@ static void bch2_btree_update_free(struct btree_update *as)
 {
 	struct bch_fs *c = as->c;
 
+	bch2_journal_pin_flush(&c->journal, &as->journal);
+
 	BUG_ON(as->nr_new_nodes);
 	BUG_ON(as->nr_pending);
 
@@ -2151,7 +2153,7 @@ ssize_t bch2_btree_updates_print(struct bch_fs *c, char *buf)
 				 as->mode,
 				 as->nodes_written,
 				 atomic_read(&as->cl.remaining) & CLOSURE_REMAINING_MASK,
-				 bch2_journal_pin_seq(&c->journal, &as->journal));
+				 as->journal.seq);
 	mutex_unlock(&c->btree_interior_update_lock);
 
 	return out - buf;
diff --git a/fs/bcachefs/btree_update_leaf.c b/fs/bcachefs/btree_update_leaf.c
index 6c48518..5cd20b5 100644
--- a/fs/bcachefs/btree_update_leaf.c
+++ b/fs/bcachefs/btree_update_leaf.c
@@ -111,8 +111,7 @@ static void __btree_node_flush(struct journal *j, struct journal_entry_pin *pin,
 
 	btree_node_lock_type(c, b, SIX_LOCK_read);
 	bch2_btree_node_write_cond(c, b,
-			(btree_current_write(b) == w &&
-			 w->journal.pin_list == journal_seq_pin(j, seq)));
+		(btree_current_write(b) == w && w->journal.seq == seq));
 	six_unlock_read(&b->lock);
 }
 
diff --git a/fs/bcachefs/fifo.h b/fs/bcachefs/fifo.h
index bd1534e..00d245e 100644
--- a/fs/bcachefs/fifo.h
+++ b/fs/bcachefs/fifo.h
@@ -109,17 +109,17 @@ do {									\
 #define fifo_peek(fifo)		fifo_peek_front(fifo)
 
 #define fifo_for_each_entry(_entry, _fifo, _iter)			\
-	for (((void) (&(_iter) == &(_fifo)->front)),			\
-	     _iter = (_fifo)->front;					\
+	for (typecheck(typeof((_fifo)->front), _iter),			\
+	     (_iter) = (_fifo)->front;					\
 	     ((_iter != (_fifo)->back) &&				\
 	      (_entry = (_fifo)->data[(_iter) & (_fifo)->mask], true));	\
-	     _iter++)
+	     (_iter)++)
 
 #define fifo_for_each_entry_ptr(_ptr, _fifo, _iter)			\
-	for (((void) (&(_iter) == &(_fifo)->front)),			\
-	     _iter = (_fifo)->front;					\
+	for (typecheck(typeof((_fifo)->front), _iter),			\
+	     (_iter) = (_fifo)->front;					\
 	     ((_iter != (_fifo)->back) &&				\
 	      (_ptr = &(_fifo)->data[(_iter) & (_fifo)->mask], true));	\
-	     _iter++)
+	     (_iter)++)
 
 #endif /* _BCACHEFS_FIFO_H */
diff --git a/fs/bcachefs/journal.c b/fs/bcachefs/journal.c
index fe29260..3878ceb 100644
--- a/fs/bcachefs/journal.c
+++ b/fs/bcachefs/journal.c
@@ -138,8 +138,26 @@ static enum {
 		c->opts.block_size;
 	BUG_ON(j->prev_buf_sectors > j->cur_buf_sectors);
 
+	/*
+	 * We have to set last_seq here, _before_ opening a new journal entry:
+	 *
+	 * A thread may replace an old pin with a new pin on its current
+	 * journal reservation - the expectation being that the journal will
+	 * contain either what the old pin protected or what the new pin
+	 * protects.
+	 *
+	 * After the old pin is dropped journal_last_seq() won't include the old
+	 * pin, so we can only write the updated last_seq on the entry that
+	 * contains whatever the new pin protects.
+	 *
+	 * Restated, we can _not_ update last_seq for a given entry if there
+	 * could be a newer entry open with reservations/pins that have been
+	 * taken against it.
+	 *
+	 * Hence, we want to update/set last_seq on the current journal entry right
+	 * before we open a new one:
+	 */
 	bch2_journal_reclaim_fast(j);
-	/* XXX: why set this here, and not in bch2_journal_write()? */
 	buf->data->last_seq	= cpu_to_le64(journal_last_seq(j));
 
 	if (journal_entry_empty(buf->data))
@@ -1022,6 +1040,7 @@ int bch2_fs_journal_init(struct journal *j)
 	init_waitqueue_head(&j->wait);
 	INIT_DELAYED_WORK(&j->write_work, journal_write_work);
 	INIT_DELAYED_WORK(&j->reclaim_work, bch2_journal_reclaim_work);
+	init_waitqueue_head(&j->pin_flush_wait);
 	mutex_init(&j->blacklist_lock);
 	INIT_LIST_HEAD(&j->seq_blacklist);
 	mutex_init(&j->reclaim_lock);
diff --git a/fs/bcachefs/journal_reclaim.c b/fs/bcachefs/journal_reclaim.c
index e5b8666..e1d5d41 100644
--- a/fs/bcachefs/journal_reclaim.c
+++ b/fs/bcachefs/journal_reclaim.c
@@ -11,34 +11,18 @@
  * entry, holding it open to ensure it gets replayed during recovery:
  */
 
-static inline u64 journal_pin_seq(struct journal *j,
-				  struct journal_entry_pin_list *pin_list)
-{
-	return fifo_entry_idx_abs(&j->pin, pin_list);
-}
-
-u64 bch2_journal_pin_seq(struct journal *j, struct journal_entry_pin *pin)
-{
-	u64 ret = 0;
-
-	spin_lock(&j->lock);
-	if (journal_pin_active(pin))
-		ret = journal_pin_seq(j, pin->pin_list);
-	spin_unlock(&j->lock);
-
-	return ret;
-}
-
 static inline void __journal_pin_add(struct journal *j,
-				     struct journal_entry_pin_list *pin_list,
+				     u64 seq,
 				     struct journal_entry_pin *pin,
 				     journal_pin_flush_fn flush_fn)
 {
+	struct journal_entry_pin_list *pin_list = journal_seq_pin(j, seq);
+
 	BUG_ON(journal_pin_active(pin));
 	BUG_ON(!atomic_read(&pin_list->count));
 
 	atomic_inc(&pin_list->count);
-	pin->pin_list	= pin_list;
+	pin->seq	= seq;
 	pin->flush	= flush_fn;
 
 	if (flush_fn)
@@ -58,19 +42,20 @@ void bch2_journal_pin_add(struct journal *j, u64 seq,
 			  journal_pin_flush_fn flush_fn)
 {
 	spin_lock(&j->lock);
-	__journal_pin_add(j, journal_seq_pin(j, seq), pin, flush_fn);
+	__journal_pin_add(j, seq, pin, flush_fn);
 	spin_unlock(&j->lock);
 }
 
 static inline void __journal_pin_drop(struct journal *j,
 				      struct journal_entry_pin *pin)
 {
-	struct journal_entry_pin_list *pin_list = pin->pin_list;
+	struct journal_entry_pin_list *pin_list;
 
 	if (!journal_pin_active(pin))
 		return;
 
-	pin->pin_list = NULL;
+	pin_list = journal_seq_pin(j, pin->seq);
+	pin->seq = 0;
 	list_del_init(&pin->list);
 
 	/*
@@ -83,7 +68,7 @@ static inline void __journal_pin_drop(struct journal *j,
 }
 
 void bch2_journal_pin_drop(struct journal *j,
-			  struct journal_entry_pin *pin)
+			   struct journal_entry_pin *pin)
 {
 	spin_lock(&j->lock);
 	__journal_pin_drop(j, pin);
@@ -99,15 +84,21 @@ void bch2_journal_pin_add_if_older(struct journal *j,
 
 	if (journal_pin_active(src_pin) &&
 	    (!journal_pin_active(pin) ||
-	     journal_pin_seq(j, src_pin->pin_list) <
-	     journal_pin_seq(j, pin->pin_list))) {
+	     src_pin->seq < pin->seq)) {
 		__journal_pin_drop(j, pin);
-		__journal_pin_add(j, src_pin->pin_list, pin, flush_fn);
+		__journal_pin_add(j, src_pin->seq, pin, flush_fn);
 	}
 
 	spin_unlock(&j->lock);
 }
 
+void bch2_journal_pin_flush(struct journal *j, struct journal_entry_pin *pin)
+{
+	BUG_ON(journal_pin_active(pin));
+
+	wait_event(j->pin_flush_wait, j->flush_in_progress != pin);
+}
+
 /*
  * Journal reclaim: flush references to open journal entries to reclaim space in
  * the journal
@@ -145,41 +136,42 @@ void bch2_journal_reclaim_fast(struct journal *j)
 		journal_wake(j);
 }
 
-static struct journal_entry_pin *
-__journal_get_next_pin(struct journal *j, u64 seq_to_flush, u64 *seq)
+static void journal_pin_mark_flushing(struct journal *j,
+				      struct journal_entry_pin *pin,
+				      u64 seq)
 {
-	struct journal_entry_pin_list *pin_list;
-	struct journal_entry_pin *ret;
-	u64 iter;
+	lockdep_assert_held(&j->reclaim_lock);
 
-	/* no need to iterate over empty fifo entries: */
-	bch2_journal_reclaim_fast(j);
+	list_move(&pin->list, &journal_seq_pin(j, seq)->flushed);
+	BUG_ON(j->flush_in_progress);
+	j->flush_in_progress = pin;
+}
 
-	fifo_for_each_entry_ptr(pin_list, &j->pin, iter) {
-		if (iter > seq_to_flush)
-			break;
+static void journal_pin_flush(struct journal *j,
+			      struct journal_entry_pin *pin,
+			      u64 seq)
+{
+	pin->flush(j, pin, seq);
 
-		ret = list_first_entry_or_null(&pin_list->list,
-				struct journal_entry_pin, list);
-		if (ret) {
-			/* must be list_del_init(), see bch2_journal_pin_drop() */
-			list_move(&ret->list, &pin_list->flushed);
-			*seq = iter;
-			return ret;
-		}
-	}
-
-	return NULL;
+	BUG_ON(j->flush_in_progress != pin);
+	j->flush_in_progress = NULL;
+	wake_up(&j->pin_flush_wait);
 }
 
 static struct journal_entry_pin *
 journal_get_next_pin(struct journal *j, u64 seq_to_flush, u64 *seq)
 {
-	struct journal_entry_pin *ret;
+	struct journal_entry_pin_list *pin_list;
+	struct journal_entry_pin *ret = NULL;
 
-	spin_lock(&j->lock);
-	ret = __journal_get_next_pin(j, seq_to_flush, seq);
-	spin_unlock(&j->lock);
+	/* no need to iterate over empty fifo entries: */
+	bch2_journal_reclaim_fast(j);
+
+	fifo_for_each_entry_ptr(pin_list, &j->pin, *seq)
+		if (*seq > seq_to_flush ||
+		    (ret = list_first_entry_or_null(&pin_list->list,
+				struct journal_entry_pin, list)))
+			break;
 
 	return ret;
 }
@@ -279,15 +271,11 @@ void bch2_journal_reclaim_work(struct work_struct *work)
 		spin_unlock(&j->lock);
 	}
 
-	if (reclaim_lock_held)
-		mutex_unlock(&j->reclaim_lock);
-
 	/* Also flush if the pin fifo is more than half full */
 	spin_lock(&j->lock);
 	seq_to_flush = max_t(s64, seq_to_flush,
 			     (s64) journal_cur_seq(j) -
 			     (j->pin.size >> 1));
-	spin_unlock(&j->lock);
 
 	/*
 	 * If it's been longer than j->reclaim_delay_ms since we last flushed,
@@ -299,13 +287,31 @@ void bch2_journal_reclaim_work(struct work_struct *work)
 	while ((pin = journal_get_next_pin(j, need_flush
 					   ? U64_MAX
 					   : seq_to_flush, &seq))) {
-		__set_current_state(TASK_RUNNING);
-		pin->flush(j, pin, seq);
-		need_flush = false;
+		if (!reclaim_lock_held) {
+			spin_unlock(&j->lock);
+			__set_current_state(TASK_RUNNING);
+			mutex_lock(&j->reclaim_lock);
+			reclaim_lock_held = true;
+			spin_lock(&j->lock);
+			continue;
+		}
 
+		journal_pin_mark_flushing(j, pin, seq);
+		spin_unlock(&j->lock);
+
+		journal_pin_flush(j, pin, seq);
+
+		need_flush = false;
 		j->last_flushed = jiffies;
+
+		spin_lock(&j->lock);
 	}
 
+	spin_unlock(&j->lock);
+
+	if (reclaim_lock_held)
+		mutex_unlock(&j->reclaim_lock);
+
 	if (!test_bit(BCH_FS_RO, &c->flags))
 		queue_delayed_work(system_freezable_wq, &j->reclaim_work,
 				   msecs_to_jiffies(j->reclaim_delay_ms));
@@ -328,11 +334,14 @@ static int journal_flush_done(struct journal *j, u64 seq_to_flush,
 	 * If journal replay hasn't completed, the unreplayed journal entries
 	 * hold refs on their corresponding sequence numbers
 	 */
-	ret = (*pin = __journal_get_next_pin(j, seq_to_flush, pin_seq)) != NULL ||
+	ret = (*pin = journal_get_next_pin(j, seq_to_flush, pin_seq)) != NULL ||
 		!test_bit(JOURNAL_REPLAY_DONE, &j->flags) ||
 		journal_last_seq(j) > seq_to_flush ||
 		(fifo_used(&j->pin) == 1 &&
 		 atomic_read(&fifo_peek_front(&j->pin).count) == 1);
+	if (*pin)
+		journal_pin_mark_flushing(j, *pin, *pin_seq);
+
 	spin_unlock(&j->lock);
 
 	return ret;
@@ -346,14 +355,18 @@ void bch2_journal_flush_pins(struct journal *j, u64 seq_to_flush)
 	if (!test_bit(JOURNAL_STARTED, &j->flags))
 		return;
 
+	mutex_lock(&j->reclaim_lock);
+
 	while (1) {
 		wait_event(j->wait, journal_flush_done(j, seq_to_flush,
 						       &pin, &pin_seq));
 		if (!pin)
 			break;
 
-		pin->flush(j, pin, pin_seq);
+		journal_pin_flush(j, pin, pin_seq);
 	}
+
+	mutex_unlock(&j->reclaim_lock);
 }
 
 int bch2_journal_flush_device_pins(struct journal *j, int dev_idx)
diff --git a/fs/bcachefs/journal_reclaim.h b/fs/bcachefs/journal_reclaim.h
index a93ed43..f5af425 100644
--- a/fs/bcachefs/journal_reclaim.h
+++ b/fs/bcachefs/journal_reclaim.h
@@ -6,19 +6,17 @@
 
 static inline bool journal_pin_active(struct journal_entry_pin *pin)
 {
-	return pin->pin_list != NULL;
+	return pin->seq != 0;
 }
 
 static inline struct journal_entry_pin_list *
 journal_seq_pin(struct journal *j, u64 seq)
 {
-	BUG_ON(seq < j->pin.front || seq >= j->pin.back);
+	EBUG_ON(seq < j->pin.front || seq >= j->pin.back);
 
 	return &j->pin.data[seq & j->pin.mask];
 }
 
-u64 bch2_journal_pin_seq(struct journal *, struct journal_entry_pin *);
-
 void bch2_journal_pin_add(struct journal *, u64, struct journal_entry_pin *,
 			  journal_pin_flush_fn);
 void bch2_journal_pin_drop(struct journal *, struct journal_entry_pin *);
@@ -26,6 +24,7 @@ void bch2_journal_pin_add_if_older(struct journal *,
 				  struct journal_entry_pin *,
 				  struct journal_entry_pin *,
 				  journal_pin_flush_fn);
+void bch2_journal_pin_flush(struct journal *, struct journal_entry_pin *);
 
 void bch2_journal_reclaim_fast(struct journal *);
 void bch2_journal_reclaim_work(struct work_struct *);
diff --git a/fs/bcachefs/journal_types.h b/fs/bcachefs/journal_types.h
index cf29122..dae8b8a 100644
--- a/fs/bcachefs/journal_types.h
+++ b/fs/bcachefs/journal_types.h
@@ -48,7 +48,7 @@ typedef void (*journal_pin_flush_fn)(struct journal *j,
 struct journal_entry_pin {
 	struct list_head		list;
 	journal_pin_flush_fn		flush;
-	struct journal_entry_pin_list	*pin_list;
+	u64				seq;
 };
 
 /* corresponds to a btree node with a blacklisted bset: */
@@ -174,6 +174,10 @@ struct journal {
 		u64 front, back, size, mask;
 		struct journal_entry_pin_list *data;
 	}			pin;
+
+	struct journal_entry_pin *flush_in_progress;
+	wait_queue_head_t	pin_flush_wait;
+
 	u64			replay_journal_seq;
 
 	struct mutex		blacklist_lock;