io_uring: allow updating linked timeouts

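We already allow updating normal timeouts; add support for adjusting the
timings of linked timeouts as well. A linked timeout is updated by setting
IORING_LINK_TIMEOUT_UPDATE in the timeout_flags of a timeout removal
request, with the linked timeout's user_data in sqe->addr and the new
timespec pointed to by sqe->addr2.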

Reported-by: Victor Stewart <v@nametag.social>
Signed-off-by: Pavel Begunkov <asml.silence@gmail.com>
Signed-off-by: Jens Axboe <axboe@kernel.dk>
diff --git a/fs/io_uring.c b/fs/io_uring.c
index aa97829..7cc458e 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -552,6 +552,7 @@ struct io_timeout_rem {
 	/* timeout update */
 	struct timespec64		ts;
 	u32				flags;
+	bool				ltimeout;
 };
 
 struct io_rw {
@@ -1069,6 +1070,7 @@ static int io_req_prep_async(struct io_kiocb *req);
 
 static int io_install_fixed_file(struct io_kiocb *req, struct file *file,
 				 unsigned int issue_flags, u32 slot_index);
+static enum hrtimer_restart io_link_timeout_fn(struct hrtimer *timer);
 
 static struct kmem_cache *req_cachep;
 
@@ -5732,6 +5734,47 @@ static clockid_t io_timeout_get_clock(struct io_timeout_data *data)
 	}
 }
 
+/*
+ * Re-arm the timer of a linked timeout found on ctx->ltimeout_list.
+ *
+ * @user_data: user_data of the linked timeout request to update
+ * @ts:        new expiry, interpreted according to @mode
+ * @mode:      hrtimer mode used to re-initialise and restart the timer
+ *
+ * The timer is re-initialised on the clock io_timeout_get_clock() returns
+ * for this timeout, and its callback is reset to io_link_timeout_fn().
+ *
+ * Returns 0 on success, -ENOENT if no linked timeout with @user_data is on
+ * the list, or -EALREADY if the timer callback is already executing and the
+ * timer can no longer be cancelled.
+ */
+static int io_linked_timeout_update(struct io_ring_ctx *ctx, __u64 user_data,
+				    struct timespec64 *ts, enum hrtimer_mode mode)
+	__must_hold(&ctx->timeout_lock)
+{
+	struct io_kiocb *req = NULL, *iter;
+	struct io_timeout_data *io;
+
+	/* Keep the match in req; don't use the iterator after the loop. */
+	list_for_each_entry(iter, &ctx->ltimeout_list, timeout.list) {
+		if (iter->user_data == user_data) {
+			req = iter;
+			break;
+		}
+	}
+	if (!req)
+		return -ENOENT;
+
+	io = req->async_data;
+	/* -1: the callback is already running, too late to re-arm. */
+	if (hrtimer_try_to_cancel(&io->timer) == -1)
+		return -EALREADY;
+	hrtimer_init(&io->timer, io_timeout_get_clock(io), mode);
+	io->timer.function = io_link_timeout_fn;
+	hrtimer_start(&io->timer, timespec64_to_ktime(*ts), mode);
+	return 0;
+}
+
 static int io_timeout_update(struct io_ring_ctx *ctx, __u64 user_data,
 			     struct timespec64 *ts, enum hrtimer_mode mode)
 	__must_hold(&ctx->timeout_lock)
@@ -5763,10 +5790,15 @@ static int io_timeout_remove_prep(struct io_kiocb *req,
 	if (sqe->ioprio || sqe->buf_index || sqe->len || sqe->splice_fd_in)
 		return -EINVAL;
 
+	tr->ltimeout = false;
 	tr->addr = READ_ONCE(sqe->addr);
 	tr->flags = READ_ONCE(sqe->timeout_flags);
-	if (tr->flags & IORING_TIMEOUT_UPDATE) {
-		if (tr->flags & ~(IORING_TIMEOUT_UPDATE|IORING_TIMEOUT_ABS))
+	if (tr->flags & IORING_TIMEOUT_UPDATE_MASK) {
+		if (hweight32(tr->flags & IORING_TIMEOUT_CLOCK_MASK) > 1)
+			return -EINVAL;
+		if (tr->flags & IORING_LINK_TIMEOUT_UPDATE)
+			tr->ltimeout = true;
+		if (tr->flags & ~(IORING_TIMEOUT_UPDATE_MASK|IORING_TIMEOUT_ABS))
 			return -EINVAL;
 		if (get_timespec64(&tr->ts, u64_to_user_ptr(sqe->addr2)))
 			return -EFAULT;
@@ -5800,9 +5832,13 @@ static int io_timeout_remove(struct io_kiocb *req, unsigned int issue_flags)
 		spin_unlock_irq(&ctx->timeout_lock);
 		spin_unlock(&ctx->completion_lock);
 	} else {
+		enum hrtimer_mode mode = io_translate_timeout_mode(tr->flags);
+
 		spin_lock_irq(&ctx->timeout_lock);
-		ret = io_timeout_update(ctx, tr->addr, &tr->ts,
-					io_translate_timeout_mode(tr->flags));
+		if (tr->ltimeout)
+			ret = io_linked_timeout_update(ctx, tr->addr, &tr->ts, mode);
+		else
+			ret = io_timeout_update(ctx, tr->addr, &tr->ts, mode);
 		spin_unlock_irq(&ctx->timeout_lock);
 	}