io_uring: fix multi ctx cancellation

io_uring_try_cancel_requests() loops until there is nothing left to do
with the ring, however there might be several rings and they might have
dependencies between them, e.g. via poll requests.

Instead of cancelling rings one by one, try to cancel them all and only
then loop over if we still potenially some work to do.

Signed-off-by: Pavel Begunkov <asml.silence@gmail.com>
Link: https://lore.kernel.org/r/8d491fe02d8ac4c77ff38061cf86b9a827e8845c.1655684496.git.asml.silence@gmail.com
Signed-off-by: Jens Axboe <axboe@kernel.dk>
diff --git a/io_uring/io_uring.c b/io_uring/io_uring.c
index 16a625e..707b599 100644
--- a/io_uring/io_uring.c
+++ b/io_uring/io_uring.c
@@ -132,7 +132,7 @@ struct io_defer_entry {
 #define IO_DISARM_MASK (REQ_F_ARM_LTIMEOUT | REQ_F_LINK_TIMEOUT | REQ_F_FAIL)
 #define IO_REQ_LINK_FLAGS (REQ_F_LINK | REQ_F_HARDLINK)
 
-static void io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
+static bool io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
 					 struct task_struct *task,
 					 bool cancel_all);
 
@@ -2648,7 +2648,9 @@ static __cold void io_ring_exit_work(struct work_struct *work)
 	 * as nobody else will be looking for them.
 	 */
 	do {
-		io_uring_try_cancel_requests(ctx, NULL, true);
+		while (io_uring_try_cancel_requests(ctx, NULL, true))
+			cond_resched();
+
 		if (ctx->sq_data) {
 			struct io_sq_data *sqd = ctx->sq_data;
 			struct task_struct *tsk;
@@ -2806,53 +2808,48 @@ static __cold bool io_uring_try_cancel_iowq(struct io_ring_ctx *ctx)
 	return ret;
 }
 
-static __cold void io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
+static __cold bool io_uring_try_cancel_requests(struct io_ring_ctx *ctx,
 						struct task_struct *task,
 						bool cancel_all)
 {
 	struct io_task_cancel cancel = { .task = task, .all = cancel_all, };
 	struct io_uring_task *tctx = task ? task->io_uring : NULL;
+	enum io_wq_cancel cret;
+	bool ret = false;
 
 	/* failed during ring init, it couldn't have issued any requests */
 	if (!ctx->rings)
-		return;
+		return false;
 
-	while (1) {
-		enum io_wq_cancel cret;
-		bool ret = false;
-
-		if (!task) {
-			ret |= io_uring_try_cancel_iowq(ctx);
-		} else if (tctx && tctx->io_wq) {
-			/*
-			 * Cancels requests of all rings, not only @ctx, but
-			 * it's fine as the task is in exit/exec.
-			 */
-			cret = io_wq_cancel_cb(tctx->io_wq, io_cancel_task_cb,
-					       &cancel, true);
-			ret |= (cret != IO_WQ_CANCEL_NOTFOUND);
-		}
-
-		/* SQPOLL thread does its own polling */
-		if ((!(ctx->flags & IORING_SETUP_SQPOLL) && cancel_all) ||
-		    (ctx->sq_data && ctx->sq_data->thread == current)) {
-			while (!wq_list_empty(&ctx->iopoll_list)) {
-				io_iopoll_try_reap_events(ctx);
-				ret = true;
-			}
-		}
-
-		ret |= io_cancel_defer_files(ctx, task, cancel_all);
-		mutex_lock(&ctx->uring_lock);
-		ret |= io_poll_remove_all(ctx, task, cancel_all);
-		mutex_unlock(&ctx->uring_lock);
-		ret |= io_kill_timeouts(ctx, task, cancel_all);
-		if (task)
-			ret |= io_run_task_work();
-		if (!ret)
-			break;
-		cond_resched();
+	if (!task) {
+		ret |= io_uring_try_cancel_iowq(ctx);
+	} else if (tctx && tctx->io_wq) {
+		/*
+		 * Cancels requests of all rings, not only @ctx, but
+		 * it's fine as the task is in exit/exec.
+		 */
+		cret = io_wq_cancel_cb(tctx->io_wq, io_cancel_task_cb,
+				       &cancel, true);
+		ret |= (cret != IO_WQ_CANCEL_NOTFOUND);
 	}
+
+	/* SQPOLL thread does its own polling */
+	if ((!(ctx->flags & IORING_SETUP_SQPOLL) && cancel_all) ||
+	    (ctx->sq_data && ctx->sq_data->thread == current)) {
+		while (!wq_list_empty(&ctx->iopoll_list)) {
+			io_iopoll_try_reap_events(ctx);
+			ret = true;
+		}
+	}
+
+	ret |= io_cancel_defer_files(ctx, task, cancel_all);
+	mutex_lock(&ctx->uring_lock);
+	ret |= io_poll_remove_all(ctx, task, cancel_all);
+	mutex_unlock(&ctx->uring_lock);
+	ret |= io_kill_timeouts(ctx, task, cancel_all);
+	if (task)
+		ret |= io_run_task_work();
+	return ret;
 }
 
 static s64 tctx_inflight(struct io_uring_task *tctx, bool tracked)
@@ -2882,6 +2879,8 @@ __cold void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd)
 
 	atomic_inc(&tctx->in_idle);
 	do {
+		bool loop = false;
+
 		io_uring_drop_tctx_refs(current);
 		/* read completions before cancelations */
 		inflight = tctx_inflight(tctx, !cancel_all);
@@ -2896,13 +2895,19 @@ __cold void io_uring_cancel_generic(bool cancel_all, struct io_sq_data *sqd)
 				/* sqpoll task will cancel all its requests */
 				if (node->ctx->sq_data)
 					continue;
-				io_uring_try_cancel_requests(node->ctx, current,
-							     cancel_all);
+				loop |= io_uring_try_cancel_requests(node->ctx,
+							current, cancel_all);
 			}
 		} else {
 			list_for_each_entry(ctx, &sqd->ctx_list, sqd_list)
-				io_uring_try_cancel_requests(ctx, current,
-							     cancel_all);
+				loop |= io_uring_try_cancel_requests(ctx,
+								     current,
+								     cancel_all);
+		}
+
+		if (loop) {
+			cond_resched();
+			continue;
 		}
 
 		prepare_to_wait(&tctx->wait, &wait, TASK_INTERRUPTIBLE);