io_uring: split SQPOLL data into separate structure

Move all the necessary state out of io_ring_ctx, and into a new
structure, io_sq_data. The latter now deals with any state or
variables associated with the SQPOLL thread itself.

This is in preparation for supporting more than one io_ring_ctx per
SQPOLL thread.

Signed-off-by: Jens Axboe <axboe@kernel.dk>
diff --git a/fs/io_uring.c b/fs/io_uring.c
index 4958e78..9a7c645 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -230,6 +230,12 @@ struct io_restriction {
 	bool registered;
 };
 
+struct io_sq_data {
+	refcount_t		refs;		/* last put stops thread, frees sqd */
+	struct task_struct	*thread;	/* SQPOLL kthread, NULL if none */
+	struct wait_queue_head	wait;		/* SQPOLL thread idles on this */
+};
+
 struct io_ring_ctx {
 	struct {
 		struct percpu_ref	refs;
@@ -276,7 +282,6 @@ struct io_ring_ctx {
 
 	/* IO offload */
 	struct io_wq		*io_wq;
-	struct task_struct	*sqo_thread;	/* if using sq thread polling */
 
 	/*
 	 * For SQPOLL usage - we hold a reference to the parent task, so we
@@ -287,8 +292,8 @@ struct io_ring_ctx {
 	/* Only used for accounting purposes */
 	struct mm_struct	*mm_account;
 
-	struct wait_queue_head	*sqo_wait;
-	struct wait_queue_head	__sqo_wait;
+	struct io_sq_data	*sq_data;	/* if using sq thread polling */
+
 	struct wait_queue_entry	sqo_wait_entry;
 
 	/*
@@ -1059,8 +1064,6 @@ static struct io_ring_ctx *io_ring_ctx_alloc(struct io_uring_params *p)
 		goto err;
 
 	ctx->flags = p->flags;
-	init_waitqueue_head(&ctx->__sqo_wait);
-	ctx->sqo_wait = &ctx->__sqo_wait;
 	init_waitqueue_head(&ctx->cq_wait);
 	INIT_LIST_HEAD(&ctx->cq_overflow_list);
 	init_completion(&ctx->ref_comp);
@@ -1238,8 +1241,10 @@ static bool io_task_match(struct io_kiocb *req, struct task_struct *tsk)
 
 	if (!tsk || req->task == tsk)
 		return true;
-	if ((ctx->flags & IORING_SETUP_SQPOLL) && req->task == ctx->sqo_thread)
-		return true;
+	if (ctx->flags & IORING_SETUP_SQPOLL) {
+		if (ctx->sq_data && req->task == ctx->sq_data->thread)
+			return true;
+	}
 	return false;
 }
 
@@ -1343,8 +1348,8 @@ static void io_cqring_ev_posted(struct io_ring_ctx *ctx)
 {
 	if (waitqueue_active(&ctx->wait))
 		wake_up(&ctx->wait);
-	if (waitqueue_active(ctx->sqo_wait))
-		wake_up(ctx->sqo_wait);
+	if (ctx->sq_data && waitqueue_active(&ctx->sq_data->wait))
+		wake_up(&ctx->sq_data->wait);
 	if (io_should_trigger_evfd(ctx))
 		eventfd_signal(ctx->cq_ev_fd, 1);
 }
@@ -2451,8 +2456,9 @@ static void io_iopoll_req_issued(struct io_kiocb *req)
 	else
 		list_add_tail(&req->inflight_entry, &ctx->iopoll_list);
 
-	if ((ctx->flags & IORING_SETUP_SQPOLL) && wq_has_sleeper(ctx->sqo_wait))
-		wake_up(ctx->sqo_wait);
+	if ((ctx->flags & IORING_SETUP_SQPOLL) &&
+	    wq_has_sleeper(&ctx->sq_data->wait))
+		wake_up(&ctx->sq_data->wait);
 }
 
 static void __io_state_file_put(struct io_submit_state *state)
@@ -6652,6 +6658,7 @@ static enum sq_ret __io_sq_thread(struct io_ring_ctx *ctx,
 				  unsigned long start_jiffies)
 {
 	unsigned long timeout = start_jiffies + ctx->sq_thread_idle;
+	struct io_sq_data *sqd = ctx->sq_data;
 	unsigned int to_submit;
 	int ret = 0;
 
@@ -6692,7 +6699,7 @@ static enum sq_ret __io_sq_thread(struct io_ring_ctx *ctx,
 		    !percpu_ref_is_dying(&ctx->refs)))
 			return SQT_SPIN;
 
-		prepare_to_wait(ctx->sqo_wait, &ctx->sqo_wait_entry,
+		prepare_to_wait(&sqd->wait, &ctx->sqo_wait_entry,
 					TASK_INTERRUPTIBLE);
 
 		/*
@@ -6704,7 +6711,7 @@ static enum sq_ret __io_sq_thread(struct io_ring_ctx *ctx,
 		 */
 		if ((ctx->flags & IORING_SETUP_IOPOLL) &&
 		    !list_empty_careful(&ctx->iopoll_list)) {
-			finish_wait(ctx->sqo_wait, &ctx->sqo_wait_entry);
+			finish_wait(&sqd->wait, &ctx->sqo_wait_entry);
 			goto again;
 		}
 
@@ -6715,7 +6722,7 @@ static enum sq_ret __io_sq_thread(struct io_ring_ctx *ctx,
 			return SQT_IDLE;
 	}
 
-	finish_wait(ctx->sqo_wait, &ctx->sqo_wait_entry);
+	finish_wait(&sqd->wait, &ctx->sqo_wait_entry);
 	io_ring_clear_wakeup_flag(ctx);
 
 	mutex_lock(&ctx->uring_lock);
@@ -6935,26 +6942,54 @@ static int io_sqe_files_unregister(struct io_ring_ctx *ctx)
 	return 0;
 }
 
-static void io_sq_thread_stop(struct io_ring_ctx *ctx)
+static void io_put_sq_data(struct io_sq_data *sqd)
 {
-	if (ctx->sqo_thread) {
-		/*
-		 * We may arrive here from the error branch in
-		 * io_sq_offload_create() where the kthread is created
-		 * without being waked up, thus wake it up now to make
-		 * sure the wait will complete.
-		 */
-		wake_up_process(ctx->sqo_thread);
-
-		wait_for_completion(&ctx->sq_thread_comp);
+	if (refcount_dec_and_test(&sqd->refs)) {	/* last reference */
 		/*
 		 * The park is a bit of a work-around, without it we get
 		 * warning spews on shutdown with SQPOLL set and affinity
 		 * set to a single CPU.
 		 */
-		kthread_park(ctx->sqo_thread);
-		kthread_stop(ctx->sqo_thread);
-		ctx->sqo_thread = NULL;
+		if (sqd->thread) {	/* NULL if thread creation failed */
+			kthread_park(sqd->thread);
+			kthread_stop(sqd->thread);
+		}
+
+		kfree(sqd);	/* sqd must not be touched after this */
+	}
+}
+
+static struct io_sq_data *io_get_sq_data(struct io_uring_params *p)
+{
+	struct io_sq_data *sqd;	/* NOTE(review): @p is unused here */
+
+	sqd = kzalloc(sizeof(*sqd), GFP_KERNEL);	/* ->thread starts NULL */
+	if (!sqd)
+		return ERR_PTR(-ENOMEM);	/* callers check with IS_ERR() */
+
+	refcount_set(&sqd->refs, 1);	/* initial ref, dropped by io_put_sq_data() */
+	init_waitqueue_head(&sqd->wait);
+	return sqd;
+}
+
+static void io_sq_thread_stop(struct io_ring_ctx *ctx)
+{
+	struct io_sq_data *sqd = ctx->sq_data;
+
+	if (sqd) {
+		if (sqd->thread) {
+			/*
+			 * We may arrive here from the error branch in
+			 * io_sq_offload_create() where the kthread is created
+			 * without being woken up, thus wake it up now to make
+			 * sure the wait will complete.
+			 */
+			wake_up_process(sqd->thread);
+			wait_for_completion(&ctx->sq_thread_comp);
+		}
+
+		io_put_sq_data(sqd);	/* may stop the thread and free sqd */
+		ctx->sq_data = NULL;	/* ctx no longer references sqd */
 	}
 }
 
@@ -7623,10 +7658,19 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
 	int ret;
 
 	if (ctx->flags & IORING_SETUP_SQPOLL) {
+		struct io_sq_data *sqd;
+
 		ret = -EPERM;
 		if (!capable(CAP_SYS_ADMIN))
 			goto err;
 
+		sqd = io_get_sq_data(p);
+		if (IS_ERR(sqd)) {
+			ret = PTR_ERR(sqd);
+			goto err;
+		}
+		ctx->sq_data = sqd;
+
 		ctx->sq_thread_idle = msecs_to_jiffies(p->sq_thread_idle);
 		if (!ctx->sq_thread_idle)
 			ctx->sq_thread_idle = HZ;
@@ -7640,19 +7684,18 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
 			if (!cpu_online(cpu))
 				goto err;
 
-			ctx->sqo_thread = kthread_create_on_cpu(io_sq_thread,
-							ctx, cpu,
-							"io_uring-sq");
+			sqd->thread = kthread_create_on_cpu(io_sq_thread, ctx,
+							cpu, "io_uring-sq");
 		} else {
-			ctx->sqo_thread = kthread_create(io_sq_thread, ctx,
+			sqd->thread = kthread_create(io_sq_thread, ctx,
 							"io_uring-sq");
 		}
-		if (IS_ERR(ctx->sqo_thread)) {
-			ret = PTR_ERR(ctx->sqo_thread);
-			ctx->sqo_thread = NULL;
+		if (IS_ERR(sqd->thread)) {
+			ret = PTR_ERR(sqd->thread);
+			sqd->thread = NULL;
 			goto err;
 		}
-		ret = io_uring_alloc_task_context(ctx->sqo_thread);
+		ret = io_uring_alloc_task_context(sqd->thread);
 		if (ret)
 			goto err;
 	} else if (p->flags & IORING_SETUP_SQ_AFF) {
@@ -7673,8 +7716,10 @@ static int io_sq_offload_create(struct io_ring_ctx *ctx,
 
 static void io_sq_offload_start(struct io_ring_ctx *ctx)
 {
-	if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sqo_thread)
-		wake_up_process(ctx->sqo_thread);
+	struct io_sq_data *sqd = ctx->sq_data;	/* only dereferenced under SQPOLL */
+
+	if ((ctx->flags & IORING_SETUP_SQPOLL) && sqd->thread)
+		wake_up_process(sqd->thread);	/* start the created SQPOLL thread */
 }
 
 static inline void __io_unaccount_mem(struct user_struct *user,
@@ -8402,8 +8447,8 @@ static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
 {
 	struct task_struct *task = current;
 
-	if (ctx->flags & IORING_SETUP_SQPOLL)
-		task = ctx->sqo_thread;
+	if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data)
+		task = ctx->sq_data->thread;
 
 	io_cqring_overflow_flush(ctx, true, task, files);
 
@@ -8688,7 +8733,7 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
 		if (!list_empty_careful(&ctx->cq_overflow_list))
 			io_cqring_overflow_flush(ctx, false, NULL, NULL);
 		if (flags & IORING_ENTER_SQ_WAKEUP)
-			wake_up(ctx->sqo_wait);
+			wake_up(&ctx->sq_data->wait);
 		submitted = to_submit;
 	} else if (to_submit) {
 		ret = io_uring_add_task_file(f.file);