io_uring: properly handle SQPOLL request cancelations
Track if a given task io_uring context contains SQPOLL instances, so we
can iterate those for cancelation (and request counts). This ensures that
we properly wait on SQPOLL contexts, and find everything that needs
canceling.
Signed-off-by: Jens Axboe <axboe@kernel.dk>
diff --git a/fs/io_uring.c b/fs/io_uring.c
index a7429c9..b398394 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -1668,7 +1668,8 @@ static void __io_cqring_fill_event(struct io_kiocb *req, long res, long cflags)
WRITE_ONCE(cqe->user_data, req->user_data);
WRITE_ONCE(cqe->res, res);
WRITE_ONCE(cqe->flags, cflags);
- } else if (ctx->cq_overflow_flushed || req->task->io_uring->in_idle) {
+ } else if (ctx->cq_overflow_flushed ||
+ atomic_read(&req->task->io_uring->in_idle)) {
/*
* If we're in ring overflow flush mode, or in task cancel mode,
* then we cannot store the request for later flushing, we need
@@ -1838,7 +1839,7 @@ static void __io_free_req(struct io_kiocb *req)
io_dismantle_req(req);
percpu_counter_dec(&tctx->inflight);
- if (tctx->in_idle)
+ if (atomic_read(&tctx->in_idle))
wake_up(&tctx->wait);
put_task_struct(req->task);
@@ -7695,7 +7696,8 @@ static int io_uring_alloc_task_context(struct task_struct *task)
xa_init(&tctx->xa);
init_waitqueue_head(&tctx->wait);
tctx->last = NULL;
- tctx->in_idle = 0;
+ atomic_set(&tctx->in_idle, 0);
+ tctx->sqpoll = false;
io_init_identity(&tctx->__identity);
tctx->identity = &tctx->__identity;
task->io_uring = tctx;
@@ -8598,8 +8600,11 @@ static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
{
struct task_struct *task = current;
- if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data)
+ if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data) {
task = ctx->sq_data->thread;
+ atomic_inc(&task->io_uring->in_idle);
+ io_sq_thread_park(ctx->sq_data);
+ }
io_cqring_overflow_flush(ctx, true, task, files);
@@ -8607,12 +8612,23 @@ static void io_uring_cancel_task_requests(struct io_ring_ctx *ctx,
io_run_task_work();
cond_resched();
}
+
+ if ((ctx->flags & IORING_SETUP_SQPOLL) && ctx->sq_data) {
+ atomic_dec(&task->io_uring->in_idle);
+ /*
+ * If the files that are going away are the ones in the thread
+ * identity, clear them out.
+ */
+ if (task->io_uring->identity->files == files)
+ task->io_uring->identity->files = NULL;
+ io_sq_thread_unpark(ctx->sq_data);
+ }
}
/*
* Note that this task has used io_uring. We use it for cancelation purposes.
*/
-static int io_uring_add_task_file(struct file *file)
+static int io_uring_add_task_file(struct io_ring_ctx *ctx, struct file *file)
{
struct io_uring_task *tctx = current->io_uring;
@@ -8634,6 +8650,14 @@ static int io_uring_add_task_file(struct file *file)
tctx->last = file;
}
+ /*
+ * This is race safe in that the task itself is doing this, hence it
+ * cannot be going through the exit/cancel paths at the same time.
+ * This cannot be modified while exit/cancel is running.
+ */
+ if (!tctx->sqpoll && (ctx->flags & IORING_SETUP_SQPOLL))
+ tctx->sqpoll = true;
+
return 0;
}
@@ -8675,7 +8699,7 @@ void __io_uring_files_cancel(struct files_struct *files)
unsigned long index;
/* make sure overflow events are dropped */
- tctx->in_idle = true;
+ atomic_inc(&tctx->in_idle);
xa_for_each(&tctx->xa, index, file) {
struct io_ring_ctx *ctx = file->private_data;
@@ -8684,6 +8708,35 @@ void __io_uring_files_cancel(struct files_struct *files)
if (files)
io_uring_del_task_file(file);
}
+
+ atomic_dec(&tctx->in_idle);
+}
+
+static s64 tctx_inflight(struct io_uring_task *tctx)
+{
+ unsigned long index;
+ struct file *file;
+ s64 inflight;
+
+ inflight = percpu_counter_sum(&tctx->inflight);
+ if (!tctx->sqpoll)
+ return inflight;
+
+ /*
+ * If we have SQPOLL rings, then we need to iterate and find them, and
+ * add the pending count for those.
+ */
+ xa_for_each(&tctx->xa, index, file) {
+ struct io_ring_ctx *ctx = file->private_data;
+
+ if (ctx->flags & IORING_SETUP_SQPOLL) {
+ struct io_uring_task *__tctx = ctx->sqo_task->io_uring;
+
+ inflight += percpu_counter_sum(&__tctx->inflight);
+ }
+ }
+
+ return inflight;
}
/*
@@ -8697,11 +8750,11 @@ void __io_uring_task_cancel(void)
s64 inflight;
/* make sure overflow events are dropped */
- tctx->in_idle = true;
+ atomic_inc(&tctx->in_idle);
do {
/* read completions before cancelations */
- inflight = percpu_counter_sum(&tctx->inflight);
+ inflight = tctx_inflight(tctx);
if (!inflight)
break;
__io_uring_files_cancel(NULL);
@@ -8712,13 +8765,13 @@ void __io_uring_task_cancel(void)
* If we've seen completions, retry. This avoids a race where
* a completion comes in before we did prepare_to_wait().
*/
- if (inflight != percpu_counter_sum(&tctx->inflight))
+ if (inflight != tctx_inflight(tctx))
continue;
schedule();
} while (1);
finish_wait(&tctx->wait, &wait);
- tctx->in_idle = false;
+ atomic_dec(&tctx->in_idle);
}
static int io_uring_flush(struct file *file, void *data)
@@ -8863,7 +8916,7 @@ SYSCALL_DEFINE6(io_uring_enter, unsigned int, fd, u32, to_submit,
io_sqpoll_wait_sq(ctx);
submitted = to_submit;
} else if (to_submit) {
- ret = io_uring_add_task_file(f.file);
+ ret = io_uring_add_task_file(ctx, f.file);
if (unlikely(ret))
goto out;
mutex_lock(&ctx->uring_lock);
@@ -9092,7 +9145,7 @@ static int io_uring_get_fd(struct io_ring_ctx *ctx)
#if defined(CONFIG_UNIX)
ctx->ring_sock->file = file;
#endif
- if (unlikely(io_uring_add_task_file(file))) {
+ if (unlikely(io_uring_add_task_file(ctx, file))) {
file = ERR_PTR(-ENOMEM);
goto err_fd;
}
diff --git a/include/linux/io_uring.h b/include/linux/io_uring.h
index 868364c..35b2d84 100644
--- a/include/linux/io_uring.h
+++ b/include/linux/io_uring.h
@@ -30,7 +30,8 @@ struct io_uring_task {
struct percpu_counter inflight;
struct io_identity __identity;
struct io_identity *identity;
- bool in_idle;
+ atomic_t in_idle;
+ bool sqpoll;
};
#if defined(CONFIG_IO_URING)