io_uring: base SQPOLL handling off io_sq_data

Remove the SQPOLL thread from the ctx, and pass the io_sq_data in as
the thread's data instead. io_sq_data holds a list of ctx's that the
thread can then iterate over and handle.

As of now we're ready to handle multiple ctx's, though we still only
ever attach a single one after this patch.
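
For reference, here is a minimal sketch of the attach protocol this
introduces. It mirrors the io_sq_offload_create() hunk below and uses
only the park/unpark helpers added by this patch; it's an illustration
of the locking order, not a new API:

    /* quiesce the thread before touching the sqd ctx lists */
    io_sq_thread_park(sqd);     /* mutex_lock(&sqd->lock) + kthread_park() */
    mutex_lock(&sqd->ctx_lock);
    list_add(&ctx->sqd_list, &sqd->ctx_new_list);
    mutex_unlock(&sqd->ctx_lock);
    io_sq_thread_unpark(sqd);   /* kthread_unpark() + mutex_unlock(&sqd->lock) */

The thread then picks the new ctx up in io_sqd_init_new() on its next
iteration, moving it from ctx_new_list to ctx_list before servicing it.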

Signed-off-by: Jens Axboe <axboe@kernel.dk>
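
A note on the result aggregation in the new io_sq_thread() loop: the
per-ctx __io_sq_thread() return values are OR'ed together, so the idle
path is only taken once every ring was idle in the same pass. Assuming
sq_ret is the bitmask-style enum defined earlier in io_uring.c (values
reproduced here for illustration):

    enum sq_ret {
        SQT_IDLE        = 1,
        SQT_SPIN        = 2,
        SQT_DID_WORK    = 4,
    };

    /*
     * With "ret |= __io_sq_thread(ctx, ...)" for each ctx on ctx_list:
     *   - ret == SQT_IDLE holds only if every ctx returned SQT_IDLE
     *   - ret & SQT_SPIN is set if at least one ctx wants to keep spinning
     * so the thread only sets the wakeup flags and schedule()s when no
     * ring submitted work or asked to spin.
     */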
diff --git a/fs/io_uring.c b/fs/io_uring.c
index 9a7c645..0a9eced 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -232,6 +232,13 @@ struct io_restriction {
 
 struct io_sq_data {
 	refcount_t		refs;
+	struct mutex		lock;
+
+	/* ctx's that are using this sqd */
+	struct list_head	ctx_list;
+	struct list_head	ctx_new_list;
+	struct mutex		ctx_lock;
+
 	struct task_struct	*thread;
 	struct wait_queue_head	wait;
 };
@@ -295,6 +302,7 @@ struct io_ring_ctx {
 	struct io_sq_data	*sq_data;	/* if using sq thread polling */
 
 	struct wait_queue_entry	sqo_wait_entry;
+	struct list_head	sqd_list;
 
 	/*
 	 * If used, fixed file set. Writers must ensure that ->refs is dead,
@@ -1064,6 +1072,7 @@ static struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
 		goto err;
 
 	ctx->flags = p->flags;
+	INIT_LIST_HEAD(&ctx->sqd_list);
 	init_waitqueue_head(&ctx->cq_wait);
 	INIT_LIST_HEAD(&ctx->cq_overflow_list);
 	init_completion(&ctx->ref_comp);
@@ -6715,8 +6724,6 @@ static enum sq_ret __io_sq_thread(struct io_ring_ctx *ctx,
 			goto again;
 		}
 
-		io_ring_set_wakeup_flag(ctx);
-
 		to_submit = io_sqring_entries(ctx);
 		if (!to_submit || ret == -EBUSY)
 			return SQT_IDLE;
@@ -6732,42 +6739,72 @@ static enum sq_ret __io_sq_thread(struct io_ring_ctx *ctx,
 	return SQT_DID_WORK;
 }
 
+static void io_sqd_init_new(struct io_sq_data *sqd)
+{
+	struct io_ring_ctx *ctx;
+
+	while (!list_empty(&sqd->ctx_new_list)) {
+		ctx = list_first_entry(&sqd->ctx_new_list, struct io_ring_ctx, sqd_list);
+		init_wait(&ctx->sqo_wait_entry);
+		ctx->sqo_wait_entry.func = io_sq_wake_function;
+		list_move_tail(&ctx->sqd_list, &sqd->ctx_list);
+		complete(&ctx->sq_thread_comp);
+	}
+}
+
 static int io_sq_thread(void *data)
 {
-	struct io_ring_ctx *ctx = data;
-	const struct cred *old_cred;
+	const struct cred *old_cred = NULL;
+	struct io_sq_data *sqd = data;
+	struct io_ring_ctx *ctx;
 	unsigned long start_jiffies;
 
-	init_wait(&ctx->sqo_wait_entry);
-	ctx->sqo_wait_entry.func = io_sq_wake_function;
-
-	complete(&ctx->sq_thread_comp);
-
-	old_cred = override_creds(ctx->creds);
-
 	start_jiffies = jiffies;
-	while (!kthread_should_park()) {
-		enum sq_ret ret;
+	while (!kthread_should_stop()) {
+		enum sq_ret ret = 0;
 
-		ret = __io_sq_thread(ctx, start_jiffies);
-		switch (ret) {
-		case SQT_IDLE:
-			schedule();
-			start_jiffies = jiffies;
-			continue;
-		case SQT_SPIN:
+		/*
+		 * Any changes to the sqd lists are synchronized through the
+		 * kthread parking. This synchronizes the thread vs users;
+		 * the users themselves are synchronized on sqd->ctx_lock.
+		 */
+		if (kthread_should_park())
+			kthread_parkme();
+
+		if (unlikely(!list_empty(&sqd->ctx_new_list)))
+			io_sqd_init_new(sqd);
+
+		list_for_each_entry(ctx, &sqd->ctx_list, sqd_list) {
+			if (current->cred != ctx->creds) {
+				if (old_cred)
+					revert_creds(old_cred);
+				old_cred = override_creds(ctx->creds);
+			}
+
+			ret |= __io_sq_thread(ctx, start_jiffies);
+
+			io_sq_thread_drop_mm();
+		}
+
+		if (ret & SQT_SPIN) {
 			io_run_task_work();
 			cond_resched();
-			fallthrough;
-		case SQT_DID_WORK:
-			continue;
+		} else if (ret == SQT_IDLE) {
+			if (kthread_should_park())
+				continue;
+			list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
+				io_ring_set_wakeup_flag(ctx);
+			schedule();
+			start_jiffies = jiffies;
+			list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
+				io_ring_clear_wakeup_flag(ctx);
 		}
 	}
 
 	io_run_task_work();
 
-	io_sq_thread_drop_mm();
-	revert_creds(old_cred);
+	if (old_cred)
+		revert_creds(old_cred);
 
 	kthread_parkme();
 
@@ -6968,10 +7005,32 @@ static struct io_sq_data *io_get_sq_data(struct io_uring_params *p)
 		return ERR_PTR(-ENOMEM);
 
 	refcount_set(&sqd->refs, 1);
+	INIT_LIST_HEAD(&sqd->ctx_list);
+	INIT_LIST_HEAD(&sqd->ctx_new_list);
+	mutex_init(&sqd->ctx_lock);
+	mutex_init(&sqd->lock);
 	init_waitqueue_head(&sqd->wait);
 	return sqd;
 }
 
+static void io_sq_thread_unpark(struct io_sq_data *sqd)
+	__releases(&sqd->lock)
+{
+	if (!sqd->thread)
+		return;
+	kthread_unpark(sqd->thread);
+	mutex_unlock(&sqd->lock);
+}
+
+static void io_sq_thread_park(struct io_sq_data *sqd)
+	__acquires(&sqd->lock)
+{
+	if (!sqd->thread)
+		return;
+	mutex_lock(&sqd->lock);
+	kthread_park(sqd->thread);
+}
+
 static void io_sq_thread_stop(struct io_ring_ctx *ctx)
 {
 	struct io_sq_data *sqd = ctx->sq_data;
@@ -6986,6 +7045,17 @@ static void io_sq_thread_stop(struct io_ring_ctx *ctx)
 			 */
 			wake_up_process(sqd->thread);
 			wait_for_completion(&ctx->sq_thread_comp);
+
+			io_sq_thread_park(sqd);
+		}
+
+		mutex_lock(&sqd->ctx_lock);
+		list_del(&ctx->sqd_list);
+		mutex_unlock(&sqd->ctx_lock);
+
+		if (sqd->thread) {
+			finish_wait(&sqd->wait, &ctx->sqo_wait_entry);
+			io_sq_thread_unpark(sqd);
 		}
 
 		io_put_sq_data(sqd);
@@ -7669,7 +7739,13 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
 			ret = PTR_ERR(sqd);
 			goto err;
 		}
+
 		ctx->sq_data = sqd;
+		io_sq_thread_park(sqd);
+		mutex_lock(&sqd->ctx_lock);
+		list_add(&ctx->sqd_list, &sqd->ctx_new_list);
+		mutex_unlock(&sqd->ctx_lock);
+		io_sq_thread_unpark(sqd);
 
 		ctx->sq_thread_idle = msecs_to_jiffies(p->sq_thread_idle);
 		if (!ctx->sq_thread_idle)
@@ -7684,10 +7760,10 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
 			if (!cpu_online(cpu))
 				goto err;
 
-			sqd->thread = kthread_create_on_cpu(io_sq_thread, ctx,
+			sqd->thread = kthread_create_on_cpu(io_sq_thread, sqd,
 							cpu, "io_uring-sq");
 		} else {
-			sqd->thread = kthread_create(io_sq_thread, ctx,
+			sqd->thread = kthread_create(io_sq_thread, sqd,
 							"io_uring-sq");
 		}
 		if (IS_ERR(sqd->thread)) {