io-wq: make worker creation resilient against signals

If a task is queueing async work and also handling signals, then we can
run into the case where create_io_thread() is interrupted and returns
failure because of that. If this happens for creating the first worker
in a group, then that worker will never get created and we can hang the
ring.

If we do get a fork failure, retry from task_work. With signals we have
to be a bit careful: we cannot simply queue the retry as task_work right
away, since signals will still be pending at that point. Punt to a normal
workqueue first and then create the worker from task_work after that.

Lastly, ensure that we handle fatal worker creation failures. Worker
creation failures are normally not fatal; only if we fail to create one in
an empty worker group can we not make progress. Right now that case is
ignored, so ensure that we handle it and run cancel on the work item.

There are two paths that create new workers - one is the "existing worker
going to sleep", and the other is "no workers found for this work, create
one". The former is never fatal, as workers do exist in the group. Only
the latter needs to be carefully handled.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
diff --git a/fs/io-wq.c b/fs/io-wq.c
index 50ea077..d80e4a7 100644
--- a/fs/io-wq.c
+++ b/fs/io-wq.c
@@ -54,7 +54,10 @@ struct io_worker {
 	struct callback_head create_work;
 	int create_index;
 
-	struct rcu_head rcu;
+	union {
+		struct rcu_head rcu;
+		struct work_struct work;
+	};
 };
 
 #if BITS_PER_LONG == 64
@@ -131,8 +134,11 @@ struct io_cb_cancel_data {
 	bool cancel_all;
 };
 
-static void create_io_worker(struct io_wq *wq, struct io_wqe *wqe, int index);
+static bool create_io_worker(struct io_wq *wq, struct io_wqe *wqe, int index);
 static void io_wqe_dec_running(struct io_worker *worker);
+static bool io_acct_cancel_pending_work(struct io_wqe *wqe,
+					struct io_wqe_acct *acct,
+					struct io_cb_cancel_data *match);
 
 static bool io_worker_get(struct io_worker *worker)
 {
@@ -238,7 +244,7 @@ static bool io_wqe_activate_free_worker(struct io_wqe *wqe,
  * We need a worker. If we find a free one, we're good. If not, and we're
  * below the max number of workers, create one.
  */
-static void io_wqe_create_worker(struct io_wqe *wqe, struct io_wqe_acct *acct)
+static bool io_wqe_create_worker(struct io_wqe *wqe, struct io_wqe_acct *acct)
 {
 	bool do_create = false;
 
@@ -258,8 +264,10 @@ static void io_wqe_create_worker(struct io_wqe *wqe, struct io_wqe_acct *acct)
 	if (do_create) {
 		atomic_inc(&acct->nr_running);
 		atomic_inc(&wqe->wq->worker_refs);
-		create_io_worker(wqe->wq, wqe, acct->index);
+		return create_io_worker(wqe->wq, wqe, acct->index);
 	}
+
+	return true;
 }
 
 static void io_wqe_inc_running(struct io_worker *worker)
@@ -297,9 +305,11 @@ static void create_worker_cb(struct callback_head *cb)
 	io_worker_release(worker);
 }
 
-static void io_queue_worker_create(struct io_wqe *wqe, struct io_worker *worker,
-				   struct io_wqe_acct *acct)
+static bool io_queue_worker_create(struct io_worker *worker,
+				   struct io_wqe_acct *acct,
+				   task_work_func_t func)
 {
+	struct io_wqe *wqe = worker->wqe;
 	struct io_wq *wq = wqe->wq;
 
 	/* raced with exit, just ignore create call */
@@ -317,16 +327,17 @@ static void io_queue_worker_create(struct io_wqe *wqe, struct io_worker *worker,
 	    test_and_set_bit_lock(0, &worker->create_state))
 		goto fail_release;
 
-	init_task_work(&worker->create_work, create_worker_cb);
+	init_task_work(&worker->create_work, func);
 	worker->create_index = acct->index;
 	if (!task_work_add(wq->task, &worker->create_work, TWA_SIGNAL))
-		return;
+		return true;
 	clear_bit_unlock(0, &worker->create_state);
 fail_release:
 	io_worker_release(worker);
 fail:
 	atomic_dec(&acct->nr_running);
 	io_worker_ref_put(wq);
+	return false;
 }
 
 static void io_wqe_dec_running(struct io_worker *worker)
@@ -341,7 +352,7 @@ static void io_wqe_dec_running(struct io_worker *worker)
 	if (atomic_dec_and_test(&acct->nr_running) && io_acct_run_queue(acct)) {
 		atomic_inc(&acct->nr_running);
 		atomic_inc(&wqe->wq->worker_refs);
-		io_queue_worker_create(wqe, worker, acct);
+		io_queue_worker_create(worker, acct, create_worker_cb);
 	}
 }
 
@@ -633,36 +644,9 @@ void io_wq_worker_sleeping(struct task_struct *tsk)
 	raw_spin_unlock(&worker->wqe->lock);
 }
 
-static void create_io_worker(struct io_wq *wq, struct io_wqe *wqe, int index)
+static void io_init_new_worker(struct io_wqe *wqe, struct io_worker *worker,
+			       struct task_struct *tsk)
 {
-	struct io_wqe_acct *acct = &wqe->acct[index];
-	struct io_worker *worker;
-	struct task_struct *tsk;
-
-	__set_current_state(TASK_RUNNING);
-
-	worker = kzalloc_node(sizeof(*worker), GFP_KERNEL, wqe->node);
-	if (!worker)
-		goto fail;
-
-	refcount_set(&worker->ref, 1);
-	worker->nulls_node.pprev = NULL;
-	worker->wqe = wqe;
-	spin_lock_init(&worker->lock);
-	init_completion(&worker->ref_done);
-
-	tsk = create_io_thread(io_wqe_worker, worker, wqe->node);
-	if (IS_ERR(tsk)) {
-		kfree(worker);
-fail:
-		atomic_dec(&acct->nr_running);
-		raw_spin_lock(&wqe->lock);
-		acct->nr_workers--;
-		raw_spin_unlock(&wqe->lock);
-		io_worker_ref_put(wq);
-		return;
-	}
-
 	tsk->pf_io_worker = worker;
 	worker->task = tsk;
 	set_cpus_allowed_ptr(tsk, wqe->cpu_mask);
@@ -672,12 +656,147 @@ static void create_io_worker(struct io_wq *wq, struct io_wqe *wqe, int index)
 	hlist_nulls_add_head_rcu(&worker->nulls_node, &wqe->free_list);
 	list_add_tail_rcu(&worker->all_list, &wqe->all_list);
 	worker->flags |= IO_WORKER_F_FREE;
-	if (index == IO_WQ_ACCT_BOUND)
-		worker->flags |= IO_WORKER_F_BOUND;
 	raw_spin_unlock(&wqe->lock);
 	wake_up_new_task(tsk);
 }
 
+/* Match function that accepts every work item. */
+static bool io_wq_work_match_all(struct io_wq_work *work, void *data)
+{
+	return true;
+}
+
+/*
+ * Thread creation errors that are treated as transient (for example the
+ * fork being interrupted by a pending signal) and are worth retrying.
+ */
+static inline bool io_should_retry_thread(long err)
+{
+	switch (err) {
+	case -EAGAIN:
+	case -ERESTARTSYS:
+	case -ERESTARTNOINTR:
+	case -ERESTARTNOHAND:
+		return true;
+	default:
+		return false;
+	}
+}
+
+/*
+ * task_work callback that retries thread creation for a worker whose
+ * first create_io_thread() attempt failed with a retryable error.
+ */
+static void create_worker_cont(struct callback_head *cb)
+{
+	struct io_worker *worker;
+	struct task_struct *tsk;
+	struct io_wqe *wqe;
+
+	worker = container_of(cb, struct io_worker, create_work);
+	clear_bit_unlock(0, &worker->create_state);
+	wqe = worker->wqe;
+	tsk = create_io_thread(io_wqe_worker, worker, wqe->node);
+	if (!IS_ERR(tsk)) {
+		io_init_new_worker(wqe, worker, tsk);
+		io_worker_release(worker);
+		return;
+	} else if (!io_should_retry_thread(PTR_ERR(tsk))) {
+		struct io_wqe_acct *acct = io_wqe_get_acct(worker);
+
+		atomic_dec(&acct->nr_running);
+		raw_spin_lock(&wqe->lock);
+		acct->nr_workers--;
+		if (!acct->nr_workers) {
+			struct io_cb_cancel_data match = {
+				.fn		= io_wq_work_match_all,
+				.cancel_all	= true,
+			};
+
+			/*
+			 * No workers left in this group, nothing can make
+			 * progress: cancel everything that is pending. The
+			 * helper drops wqe->lock whenever it cancels an item.
+			 */
+			while (io_acct_cancel_pending_work(wqe, acct, &match))
+				raw_spin_lock(&wqe->lock);
+		}
+		raw_spin_unlock(&wqe->lock);
+		io_worker_ref_put(wqe->wq);
+		/* the worker never got a thread, nothing else will free it */
+		kfree(worker);
+		return;
+	}
+
+	/* re-create attempts grab a new worker ref, drop the existing one */
+	io_worker_release(worker);
+	schedule_work(&worker->work);
+}
+
+/*
+ * Workqueue handler for a deferred worker creation: queue the actual
+ * creation attempt as task_work, running create_worker_cont().
+ */
+static void io_workqueue_create(struct work_struct *work)
+{
+	struct io_worker *worker = container_of(work, struct io_worker, work);
+	struct io_wqe_acct *acct = io_wqe_get_acct(worker);
+
+	/*
+	 * On failure io_queue_worker_create() has already dropped the
+	 * references it took, and the worker never got a thread: free it.
+	 */
+	if (!io_queue_worker_create(worker, acct, create_worker_cont))
+		kfree(worker);
+}
+
+/*
+ * Allocate a worker and try to create its thread. Returns false if the
+ * worker could not be created, true if it was created or if creation was
+ * deferred to a workqueue for a retry.
+ */
+static bool create_io_worker(struct io_wq *wq, struct io_wqe *wqe, int index)
+{
+	struct io_wqe_acct *acct = &wqe->acct[index];
+	struct io_worker *worker;
+	struct task_struct *tsk;
+
+	__set_current_state(TASK_RUNNING);
+
+	worker = kzalloc_node(sizeof(*worker), GFP_KERNEL, wqe->node);
+	if (!worker) {
+fail:
+		atomic_dec(&acct->nr_running);
+		raw_spin_lock(&wqe->lock);
+		acct->nr_workers--;
+		raw_spin_unlock(&wqe->lock);
+		io_worker_ref_put(wq);
+		return false;
+	}
+
+	refcount_set(&worker->ref, 1);
+	worker->wqe = wqe;
+	spin_lock_init(&worker->lock);
+	init_completion(&worker->ref_done);
+
+	if (index == IO_WQ_ACCT_BOUND)
+		worker->flags |= IO_WORKER_F_BOUND;
+
+	tsk = create_io_thread(io_wqe_worker, worker, wqe->node);
+	if (!IS_ERR(tsk)) {
+		io_init_new_worker(wqe, worker, tsk);
+	} else if (!io_should_retry_thread(PTR_ERR(tsk))) {
+		/* hard failure: free the worker before the shared unwind */
+		kfree(worker);
+		goto fail;
+	} else {
+		INIT_WORK(&worker->work, io_workqueue_create);
+		schedule_work(&worker->work);
+	}
+
+	return true;
+}
+
 /*
  * Iterate the passed in list and call the specific function for each
  * worker that isn't exiting
@@ -710,11 +800,6 @@ static bool io_wq_worker_wake(struct io_worker *worker, void *data)
 	return false;
 }
 
-static bool io_wq_work_match_all(struct io_wq_work *work, void *data)
-{
-	return true;
-}
-
 static void io_run_cancel(struct io_wq_work *work, struct io_wqe *wqe)
 {
 	struct io_wq *wq = wqe->wq;
@@ -759,6 +844,7 @@ static void io_wqe_enqueue(struct io_wqe *wqe, struct io_wq_work *work)
 	 */
 	if (test_bit(IO_WQ_BIT_EXIT, &wqe->wq->state) ||
 	    (work->flags & IO_WQ_WORK_CANCEL)) {
+run_cancel:
 		io_run_cancel(work, wqe);
 		return;
 	}
@@ -774,8 +860,20 @@ static void io_wqe_enqueue(struct io_wqe *wqe, struct io_wq_work *work)
 	raw_spin_unlock(&wqe->lock);
 
 	if (do_create && ((work_flags & IO_WQ_WORK_CONCURRENT) ||
-	    !atomic_read(&acct->nr_running)))
-		io_wqe_create_worker(wqe, acct);
+	    !atomic_read(&acct->nr_running))) {
+		bool did_create;
+
+		did_create = io_wqe_create_worker(wqe, acct);
+		if (unlikely(!did_create)) {
+			raw_spin_lock(&wqe->lock);
+			/* fatal condition, failed to create the first worker */
+			if (!acct->nr_workers) {
+				raw_spin_unlock(&wqe->lock);
+				goto run_cancel;
+			}
+			raw_spin_unlock(&wqe->lock);
+		}
+	}
 }
 
 void io_wq_enqueue(struct io_wq *wq, struct io_wq_work *work)
@@ -835,31 +933,42 @@ static inline void io_wqe_remove_pending(struct io_wqe *wqe,
 	wq_list_del(&acct->work_list, &work->list, prev);
 }
 
-static void io_wqe_cancel_pending_work(struct io_wqe *wqe,
-				       struct io_cb_cancel_data *match)
+static bool io_acct_cancel_pending_work(struct io_wqe *wqe,
+					struct io_wqe_acct *acct,
+					struct io_cb_cancel_data *match)
+	__releases(wqe->lock)
 {
 	struct io_wq_work_node *node, *prev;
 	struct io_wq_work *work;
-	int i;
 
+	wq_list_for_each(node, prev, &acct->work_list) {
+		work = container_of(node, struct io_wq_work, list);
+		if (!match->fn(work, match->data))
+			continue;
+		io_wqe_remove_pending(wqe, work, prev);
+		raw_spin_unlock(&wqe->lock);
+		io_run_cancel(work, wqe);
+		match->nr_pending++;
+		/* not safe to continue after unlock */
+		return true;
+	}
+
+	return false;
+}
+
+static void io_wqe_cancel_pending_work(struct io_wqe *wqe,
+				       struct io_cb_cancel_data *match)
+{
+	int i;
 retry:
 	raw_spin_lock(&wqe->lock);
 	for (i = 0; i < IO_WQ_ACCT_NR; i++) {
 		struct io_wqe_acct *acct = io_get_acct(wqe, i == 0);
 
-		wq_list_for_each(node, prev, &acct->work_list) {
-			work = container_of(node, struct io_wq_work, list);
-			if (!match->fn(work, match->data))
-				continue;
-			io_wqe_remove_pending(wqe, work, prev);
-			raw_spin_unlock(&wqe->lock);
-			io_run_cancel(work, wqe);
-			match->nr_pending++;
-			if (!match->cancel_all)
-				return;
-
-			/* not safe to continue after unlock */
-			goto retry;
+		if (io_acct_cancel_pending_work(wqe, acct, match)) {
+			if (match->cancel_all)
+				goto retry;
+			return;
 		}
 	}
 	raw_spin_unlock(&wqe->lock);
@@ -1013,7 +1122,8 @@ static bool io_task_work_match(struct callback_head *cb, void *data)
 {
 	struct io_worker *worker;
 
-	if (cb->func != create_worker_cb)
+	/* skip task_work that is not one of our worker creation callbacks */
+	if (cb->func != create_worker_cb && cb->func != create_worker_cont)
 		return false;
 	worker = container_of(cb, struct io_worker, create_work);
 	return worker->wqe->wq == data;