Merge tag 'io_uring-6.9-20240322' of git://git.kernel.dk/linux

Pull more io_uring updates from Jens Axboe:
 "One patch just missed the initial pull, the rest are either fixes or
  small cleanups that make our life easier for the next kernel:

   - Fix a potential leak in error handling of pinned pages, and clean
     it up (Gabriel, Pavel)

   - Fix an issue with how multishot reads signal a retry (me)

   - Fix a problem with waitid/futex removals when we hit the case of
     needing to remove all of them at exit time (me)

   - Fix for a regression introduced in this merge window, where we
     don't always have sr->done_io initialized if the ->prep_async()
     path is used (me)

   - Fix for SQPOLL setup error handling (me)

   - Fix for a poll removal request being delayed (Pavel)

   - Rename of a struct member which had a confusing name (Pavel)"

* tag 'io_uring-6.9-20240322' of git://git.kernel.dk/linux:
  io_uring/sqpoll: early exit thread if task_context wasn't allocated
  io_uring: clear opcode specific data for an early failure
  io_uring/net: ensure async prep handlers always initialize ->done_io
  io_uring/waitid: always remove waitid entry for cancel all
  io_uring/futex: always remove futex entry for cancel all
  io_uring: fix poll_remove stalled req completion
  io_uring: Fix release of pinned pages when __io_uaddr_map fails
  io_uring/kbuf: rename is_mapped
  io_uring: simplify io_pages_free
  io_uring: clean rings on NO_MMAP alloc fail
  io_uring/rw: return IOU_ISSUE_SKIP_COMPLETE for multishot retry
  io_uring: don't save/restore iowait state
diff --git a/io_uring/futex.c b/io_uring/futex.c
index 3c35753..792a03d 100644
--- a/io_uring/futex.c
+++ b/io_uring/futex.c
@@ -159,6 +159,7 @@
 	hlist_for_each_entry_safe(req, tmp, &ctx->futex_list, hash_node) {
 		if (!io_match_task_safe(req, task, cancel_all))
 			continue;
+		hlist_del_init(&req->hash_node);
 		__io_futex_cancel(ctx, req);
 		found = true;
 	}
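
A note on this one, since the same pattern closes out the series in
waitid.c: the cancel-all loop walks ctx->futex_list with
hlist_for_each_entry_safe(), and the added hlist_del_init() unlinks each
matching entry before __io_futex_cancel() runs, so a request whose
cancelation completes asynchronously can never be found on the list a
second time. A minimal userspace sketch of the unlink-before-cancel shape
(toy singly-linked list and stand-in helpers, not the kernel's hlist API):

    #include <stdbool.h>
    #include <stdio.h>

    struct req {
        struct req *next;
        int task_id;
    };

    static bool matches(struct req *r, int task, bool cancel_all)
    {
        return cancel_all || r->task_id == task;
    }

    static void cancel_req(struct req *r)
    {
        /* stands in for __io_futex_cancel(): cancelation may complete
         * the request later, so it must already be off the list */
        printf("canceling req of task %d\n", r->task_id);
    }

    static bool cancel_matching(struct req **list, int task, bool cancel_all)
    {
        struct req **pp = list;
        bool found = false;

        while (*pp) {
            struct req *r = *pp;

            if (!matches(r, task, cancel_all)) {
                pp = &r->next;
                continue;
            }
            *pp = r->next;    /* the hlist_del_init() step: unlink first */
            r->next = NULL;
            cancel_req(r);
            found = true;
        }
        return found;
    }

    int main(void)
    {
        struct req c = { NULL, 3 }, b = { &c, 2 }, a = { &b, 2 };
        struct req *list = &a;

        printf("found=%d\n", cancel_matching(&list, 2, false));
        return 0;
    }
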
diff --git a/io_uring/io_uring.c b/io_uring/io_uring.c
index cf348c3..5d4b448 100644
--- a/io_uring/io_uring.c
+++ b/io_uring/io_uring.c
@@ -2181,6 +2181,13 @@
 	}
 }
 
+static __cold int io_init_fail_req(struct io_kiocb *req, int err)
+{
+	/* ensure per-opcode data is cleared if we fail before prep */
+	memset(&req->cmd.data, 0, sizeof(req->cmd.data));
+	return err;
+}
+
 static int io_init_req(struct io_ring_ctx *ctx, struct io_kiocb *req,
 		       const struct io_uring_sqe *sqe)
 	__must_hold(&ctx->uring_lock)
@@ -2202,29 +2209,29 @@
 
 	if (unlikely(opcode >= IORING_OP_LAST)) {
 		req->opcode = 0;
-		return -EINVAL;
+		return io_init_fail_req(req, -EINVAL);
 	}
 	def = &io_issue_defs[opcode];
 	if (unlikely(sqe_flags & ~SQE_COMMON_FLAGS)) {
 		/* enforce forwards compatibility on users */
 		if (sqe_flags & ~SQE_VALID_FLAGS)
-			return -EINVAL;
+			return io_init_fail_req(req, -EINVAL);
 		if (sqe_flags & IOSQE_BUFFER_SELECT) {
 			if (!def->buffer_select)
-				return -EOPNOTSUPP;
+				return io_init_fail_req(req, -EOPNOTSUPP);
 			req->buf_index = READ_ONCE(sqe->buf_group);
 		}
 		if (sqe_flags & IOSQE_CQE_SKIP_SUCCESS)
 			ctx->drain_disabled = true;
 		if (sqe_flags & IOSQE_IO_DRAIN) {
 			if (ctx->drain_disabled)
-				return -EOPNOTSUPP;
+				return io_init_fail_req(req, -EOPNOTSUPP);
 			io_init_req_drain(req);
 		}
 	}
 	if (unlikely(ctx->restricted || ctx->drain_active || ctx->drain_next)) {
 		if (ctx->restricted && !io_check_restriction(ctx, req, sqe_flags))
-			return -EACCES;
+			return io_init_fail_req(req, -EACCES);
 		/* knock it to the slow queue path, will be drained there */
 		if (ctx->drain_active)
 			req->flags |= REQ_F_FORCE_ASYNC;
@@ -2237,9 +2244,9 @@
 	}
 
 	if (!def->ioprio && sqe->ioprio)
-		return -EINVAL;
+		return io_init_fail_req(req, -EINVAL);
 	if (!def->iopoll && (ctx->flags & IORING_SETUP_IOPOLL))
-		return -EINVAL;
+		return io_init_fail_req(req, -EINVAL);
 
 	if (def->needs_file) {
 		struct io_submit_state *state = &ctx->submit_state;
@@ -2263,12 +2270,12 @@
 
 		req->creds = xa_load(&ctx->personalities, personality);
 		if (!req->creds)
-			return -EINVAL;
+			return io_init_fail_req(req, -EINVAL);
 		get_cred(req->creds);
 		ret = security_uring_override_creds(req->creds);
 		if (ret) {
 			put_cred(req->creds);
-			return ret;
+			return io_init_fail_req(req, ret);
 		}
 		req->flags |= REQ_F_CREDS;
 	}
@@ -2539,7 +2546,7 @@
 static inline int io_cqring_wait_schedule(struct io_ring_ctx *ctx,
 					  struct io_wait_queue *iowq)
 {
-	int io_wait, ret;
+	int ret;
 
 	if (unlikely(READ_ONCE(ctx->check_cq)))
 		return 1;
@@ -2557,7 +2564,6 @@
 	 * can take into account that the task is waiting for IO - turns out
 	 * to be important for low QD IO.
 	 */
-	io_wait = current->in_iowait;
 	if (current_pending_io())
 		current->in_iowait = 1;
 	ret = 0;
@@ -2565,7 +2571,7 @@
 		schedule();
 	else if (!schedule_hrtimeout(&iowq->timeout, HRTIMER_MODE_ABS))
 		ret = -ETIME;
-	current->in_iowait = io_wait;
+	current->in_iowait = 0;
 	return ret;
 }
 
@@ -2697,13 +2703,9 @@
 
 static void io_pages_free(struct page ***pages, int npages)
 {
-	struct page **page_array;
+	struct page **page_array = *pages;
 	int i;
 
-	if (!pages)
-		return;
-
-	page_array = *pages;
 	if (!page_array)
 		return;
 
@@ -2719,7 +2721,7 @@
 	struct page **page_array;
 	unsigned int nr_pages;
 	void *page_addr;
-	int ret, i;
+	int ret, i, pinned;
 
 	*npages = 0;
 
@@ -2733,12 +2735,12 @@
 	if (!page_array)
 		return ERR_PTR(-ENOMEM);
 
-	ret = pin_user_pages_fast(uaddr, nr_pages, FOLL_WRITE | FOLL_LONGTERM,
-					page_array);
-	if (ret != nr_pages) {
-err:
-		io_pages_free(&page_array, ret > 0 ? ret : 0);
-		return ret < 0 ? ERR_PTR(ret) : ERR_PTR(-EFAULT);
+
+	pinned = pin_user_pages_fast(uaddr, nr_pages, FOLL_WRITE | FOLL_LONGTERM,
+				     page_array);
+	if (pinned != nr_pages) {
+		ret = (pinned < 0) ? pinned : -EFAULT;
+		goto free_pages;
 	}
 
 	page_addr = page_address(page_array[0]);
@@ -2752,7 +2754,7 @@
 		 * didn't support this feature.
 		 */
 		if (PageHighMem(page_array[i]))
-			goto err;
+			goto free_pages;
 
 		/*
 		 * No support for discontig pages for now, should either be a
@@ -2761,13 +2763,17 @@
 		 * just fail them with EINVAL.
 		 */
 		if (page_address(page_array[i]) != page_addr)
-			goto err;
+			goto free_pages;
 		page_addr += PAGE_SIZE;
 	}
 
 	*pages = page_array;
 	*npages = nr_pages;
 	return page_to_virt(page_array[0]);
+
+free_pages:
+	io_pages_free(&page_array, pinned > 0 ? pinned : 0);
+	return ERR_PTR(ret);
 }
 
 static void *io_rings_map(struct io_ring_ctx *ctx, unsigned long uaddr,
@@ -2789,14 +2795,15 @@
 	if (!(ctx->flags & IORING_SETUP_NO_MMAP)) {
 		io_mem_free(ctx->rings);
 		io_mem_free(ctx->sq_sqes);
-		ctx->rings = NULL;
-		ctx->sq_sqes = NULL;
 	} else {
 		io_pages_free(&ctx->ring_pages, ctx->n_ring_pages);
 		ctx->n_ring_pages = 0;
 		io_pages_free(&ctx->sqe_pages, ctx->n_sqe_pages);
 		ctx->n_sqe_pages = 0;
 	}
+
+	ctx->rings = NULL;
+	ctx->sq_sqes = NULL;
 }
 
 void *io_mem_alloc(size_t size)
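
Three of the io_uring.c changes are straightforward: io_init_fail_req()
zeroes the per-opcode data union when request init fails before ->prep()
ever ran, so later failure handling can't act on uninitialized
opcode-specific state; io_cqring_wait_schedule() now just clears
->in_iowait instead of saving and restoring it, on the reasoning that a
task entering this wait path isn't already in iowait; and io_rings_free()
NULLs ctx->rings and ctx->sq_sqes in both mapping modes, so a failed
IORING_SETUP_NO_MMAP setup can't leave stale pointers behind for a
second free.

The pinned-pages fix is the subtle one: pin_user_pages_fast() can return
fewer pages than requested (or an error), and the old error path reused
'ret' for both the pin count and the error code, so a failure after a
successful pin could release nothing at all. A self-contained sketch of
the rule the fix enforces, with toy stand-ins for the pin/unpin API:

    #include <errno.h>
    #include <stdio.h>
    #include <stdlib.h>

    static int pages_released;

    /* toy pin: pretend only 2 of the requested pages could be pinned */
    static int pin_pages(int nr, void **pages)
    {
        int i, got = nr < 2 ? nr : 2;

        for (i = 0; i < got; i++)
            pages[i] = malloc(16);
        return got;
    }

    static void unpin_pages(void **pages, int npages)
    {
        int i;

        for (i = 0; i < npages; i++, pages_released++)
            free(pages[i]);
    }

    static int map_pages(int nr_pages)
    {
        void **pages = calloc(nr_pages, sizeof(*pages));
        int pinned, ret = 0;

        if (!pages)
            return -ENOMEM;

        pinned = pin_pages(nr_pages, pages);
        if (pinned != nr_pages) {
            /* keep the error code and the pin count separate: report
             * why we failed, but release only what was pinned */
            ret = pinned < 0 ? pinned : -EFAULT;
            unpin_pages(pages, pinned > 0 ? pinned : 0);
        }
        free(pages);    /* success path elided; this sketch always fails */
        return ret;
    }

    int main(void)
    {
        printf("ret=%d released=%d\n", map_pages(4), pages_released);
        return 0;
    }
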
diff --git a/io_uring/kbuf.c b/io_uring/kbuf.c
index 9be42bf..693c26d 100644
--- a/io_uring/kbuf.c
+++ b/io_uring/kbuf.c
@@ -199,7 +199,7 @@
 
 	bl = io_buffer_get_list(ctx, req->buf_index);
 	if (likely(bl)) {
-		if (bl->is_mapped)
+		if (bl->is_buf_ring)
 			ret = io_ring_buffer_select(req, len, bl, issue_flags);
 		else
 			ret = io_provided_buffer_select(req, len, bl);
@@ -253,7 +253,7 @@
 	if (!nbufs)
 		return 0;
 
-	if (bl->is_mapped) {
+	if (bl->is_buf_ring) {
 		i = bl->buf_ring->tail - bl->head;
 		if (bl->is_mmap) {
 			/*
@@ -274,7 +274,7 @@
 		}
 		/* make sure it's seen as empty */
 		INIT_LIST_HEAD(&bl->buf_list);
-		bl->is_mapped = 0;
+		bl->is_buf_ring = 0;
 		return i;
 	}
 
@@ -361,7 +361,7 @@
 	if (bl) {
 		ret = -EINVAL;
 		/* can't use provide/remove buffers command on mapped buffers */
-		if (!bl->is_mapped)
+		if (!bl->is_buf_ring)
 			ret = __io_remove_buffers(ctx, bl, p->nbufs);
 	}
 	io_ring_submit_unlock(ctx, issue_flags);
@@ -519,7 +519,7 @@
 		}
 	}
 	/* can't add buffers via this command for a mapped buffer ring */
-	if (bl->is_mapped) {
+	if (bl->is_buf_ring) {
 		ret = -EINVAL;
 		goto err;
 	}
@@ -575,7 +575,7 @@
 	bl->buf_pages = pages;
 	bl->buf_nr_pages = nr_pages;
 	bl->buf_ring = br;
-	bl->is_mapped = 1;
+	bl->is_buf_ring = 1;
 	bl->is_mmap = 0;
 	return 0;
 error_unpin:
@@ -642,7 +642,7 @@
 	}
 	ibf->inuse = 1;
 	bl->buf_ring = ibf->mem;
-	bl->is_mapped = 1;
+	bl->is_buf_ring = 1;
 	bl->is_mmap = 1;
 	return 0;
 }
@@ -688,7 +688,7 @@
 	bl = io_buffer_get_list(ctx, reg.bgid);
 	if (bl) {
 		/* if mapped buffer ring OR classic exists, don't allow */
-		if (bl->is_mapped || !list_empty(&bl->buf_list))
+		if (bl->is_buf_ring || !list_empty(&bl->buf_list))
 			return -EEXIST;
 	} else {
 		free_bl = bl = kzalloc(sizeof(*bl), GFP_KERNEL);
@@ -730,7 +730,7 @@
 	bl = io_buffer_get_list(ctx, reg.bgid);
 	if (!bl)
 		return -ENOENT;
-	if (!bl->is_mapped)
+	if (!bl->is_buf_ring)
 		return -EINVAL;
 
 	__io_remove_buffers(ctx, bl, -1U);
@@ -757,7 +757,7 @@
 	bl = io_buffer_get_list(ctx, buf_status.buf_group);
 	if (!bl)
 		return -ENOENT;
-	if (!bl->is_mapped)
+	if (!bl->is_buf_ring)
 		return -EINVAL;
 
 	buf_status.head = bl->head;
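
The rename is mechanical but worth spelling out: is_mapped never meant
"mmap'ed by the application" -- it marked the group as a ring-mapped
buffer ring, and is_mmap records where the ring memory came from. As I
read the surrounding code, the states line up roughly as:

    is_buf_ring=0              classic provided buffers, linked on buf_list
    is_buf_ring=1, is_mmap=0   ring in user memory, pages pinned by the kernel
    is_buf_ring=1, is_mmap=1   ring allocated by the kernel, mmap'ed by the app
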
diff --git a/io_uring/kbuf.h b/io_uring/kbuf.h
index 5218bfd..1c7b654 100644
--- a/io_uring/kbuf.h
+++ b/io_uring/kbuf.h
@@ -26,7 +26,7 @@
 	__u16 mask;
 
 	/* ring mapped provided buffers */
-	__u8 is_mapped;
+	__u8 is_buf_ring;
 	/* ring mapped provided buffers, but mmap'ed by application */
 	__u8 is_mmap;
 	/* bl is visible from an RCU point of view for lookup */
diff --git a/io_uring/net.c b/io_uring/net.c
index 19451f0..1e7665f 100644
--- a/io_uring/net.c
+++ b/io_uring/net.c
@@ -326,7 +326,10 @@
 	struct io_async_msghdr *io;
 	int ret;
 
-	if (!zc->addr || req_has_async_data(req))
+	if (req_has_async_data(req))
+		return 0;
+	zc->done_io = 0;
+	if (!zc->addr)
 		return 0;
 	io = io_msg_alloc_async_prep(req);
 	if (!io)
@@ -353,8 +356,10 @@
 
 int io_sendmsg_prep_async(struct io_kiocb *req)
 {
+	struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
 	int ret;
 
+	sr->done_io = 0;
 	if (!io_msg_alloc_async_prep(req))
 		return -ENOMEM;
 	ret = io_sendmsg_copy_hdr(req, req->async_data);
@@ -608,9 +613,11 @@
 
 int io_recvmsg_prep_async(struct io_kiocb *req)
 {
+	struct io_sr_msg *sr = io_kiocb_to_cmd(req, struct io_sr_msg);
 	struct io_async_msghdr *iomsg;
 	int ret;
 
+	sr->done_io = 0;
 	if (!io_msg_alloc_async_prep(req))
 		return -ENOMEM;
 	iomsg = req->async_data;
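
Why zeroing matters here: done_io accumulates the bytes moved by earlier,
partially completed attempts of a send or receive, and it gets folded into
the final CQE result, so a request that reaches its handler through the
->prep_async() path must start from zero like any other. A small sketch of
the accumulation, with a hypothetical finalize() modeled on how the net.c
handlers fold in partial progress:

    #include <stdio.h>

    struct sr_state {
        int done_io;    /* bytes moved by earlier, partial attempts */
    };

    static int finalize(struct sr_state *sr, int ret)
    {
        if (ret >= 0)
            ret += sr->done_io;     /* add earlier partial progress */
        else if (sr->done_io)
            ret = sr->done_io;      /* report what did complete */
        return ret;
    }

    int main(void)
    {
        struct sr_state sr = { .done_io = 0 };  /* the fix: a known start */

        sr.done_io += 512;  /* first attempt moved 512 bytes, hit -EAGAIN */
        printf("result=%d\n", finalize(&sr, 512));  /* retry moved the rest */
        return 0;
    }
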
diff --git a/io_uring/poll.c b/io_uring/poll.c
index 5f77913..6db1dcb 100644
--- a/io_uring/poll.c
+++ b/io_uring/poll.c
@@ -996,7 +996,6 @@
 	struct io_hash_bucket *bucket;
 	struct io_kiocb *preq;
 	int ret2, ret = 0;
-	struct io_tw_state ts = { .locked = true };
 
 	io_ring_submit_lock(ctx, issue_flags);
 	preq = io_poll_find(ctx, true, &cd, &ctx->cancel_table, &bucket);
@@ -1045,7 +1044,8 @@
 
 	req_set_fail(preq);
 	io_req_set_res(preq, -ECANCELED, 0);
-	io_req_task_complete(preq, &ts);
+	preq->io_task_work.func = io_req_task_complete;
+	io_req_task_work_add(preq);
 out:
 	io_ring_submit_unlock(ctx, issue_flags);
 	if (ret < 0) {
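
The stall fix changes who completes the canceled poll request: the old
code faked a struct io_tw_state with .locked = true and called
io_req_task_complete() inline, while the new code hands preq back to its
owning task via task_work, where the completion runs with the real lock
state. A toy model of that defer-to-task-work shape (plain C, not the
kernel's task_work machinery):

    #include <stdio.h>

    struct req;
    typedef void (*tw_func)(struct req *);

    struct req {
        tw_func io_task_work_func;
        struct req *tw_next;
        int res;
    };

    static struct req *tw_queue;

    /* analogue of io_req_task_work_add(): just queue, don't complete */
    static void task_work_add(struct req *req)
    {
        req->tw_next = tw_queue;
        tw_queue = req;
    }

    /* analogue of io_req_task_complete(), run later in task context */
    static void req_complete(struct req *req)
    {
        printf("completed with res=%d\n", req->res);
    }

    static void run_task_work(void)
    {
        while (tw_queue) {
            struct req *req = tw_queue;

            tw_queue = req->tw_next;
            req->io_task_work_func(req);
        }
    }

    int main(void)
    {
        struct req preq = { req_complete, NULL, -125 /* -ECANCELED */ };

        task_work_add(&preq);   /* instead of completing inline */
        run_task_work();        /* later, in the owning task */
        return 0;
    }
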
diff --git a/io_uring/rw.c b/io_uring/rw.c
index 47e097a..0585ebc 100644
--- a/io_uring/rw.c
+++ b/io_uring/rw.c
@@ -947,6 +947,8 @@
 		 */
 		if (io_kbuf_recycle(req, issue_flags))
 			rw->len = 0;
+		if (issue_flags & IO_URING_F_MULTISHOT)
+			return IOU_ISSUE_SKIP_COMPLETE;
 		return -EAGAIN;
 	}
 
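
The contract being fixed: when a multishot read hits -EAGAIN, the request
stays armed and will be re-issued by the next poll wakeup, so this attempt
must not be completed -- the issuer has to return IOU_ISSUE_SKIP_COMPLETE
so the core posts no CQE, while plain -EAGAIN remains correct for the
non-multishot case. A minimal sketch of that caller-side distinction
(illustrative constant, not the kernel's actual value):

    #include <errno.h>
    #include <stdio.h>

    enum { IOU_ISSUE_SKIP_COMPLETE = -2 };  /* illustrative value */

    static int issue_mshot_read(int multishot)
    {
        /* ... read returned -EAGAIN, buffer was recycled ... */
        if (multishot)
            return IOU_ISSUE_SKIP_COMPLETE; /* parked; post no CQE */
        return -EAGAIN;                     /* normal retry handling */
    }

    int main(void)
    {
        if (issue_mshot_read(1) == IOU_ISSUE_SKIP_COMPLETE)
            printf("no CQE posted; re-armed for the next wakeup\n");
        return 0;
    }
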
diff --git a/io_uring/sqpoll.c b/io_uring/sqpoll.c
index 363052b..3983708 100644
--- a/io_uring/sqpoll.c
+++ b/io_uring/sqpoll.c
@@ -274,6 +274,10 @@
 	char buf[TASK_COMM_LEN];
 	DEFINE_WAIT(wait);
 
+	/* offload context creation failed, just exit */
+	if (!current->io_uring)
+		goto err_out;
+
 	snprintf(buf, sizeof(buf), "iou-sqp-%d", sqd->task_pid);
 	set_task_comm(current, buf);
 
@@ -371,7 +375,7 @@
 		atomic_or(IORING_SQ_NEED_WAKEUP, &ctx->rings->sq_flags);
 	io_run_task_work();
 	mutex_unlock(&sqd->lock);
-
+err_out:
 	complete(&sqd->exited);
 	do_exit(0);
 }
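
The SQPOLL thread is created before its io_uring task context is
allocated, so if that allocation fails the thread can still start running
with current->io_uring == NULL -- and the creating side waits on
sqd->exited regardless, which is why the early-exit path must still reach
complete(). A toy pthreads model of that invariant (userspace stand-ins,
not kernel threads):

    #include <pthread.h>
    #include <stdio.h>

    struct sqd {
        void *task_context;     /* NULL: allocation failed */
        pthread_mutex_t lock;
        pthread_cond_t exited_cv;
        int exited;
    };

    static struct sqd sqd = {
        .task_context = NULL,   /* simulate the failure */
        .lock = PTHREAD_MUTEX_INITIALIZER,
        .exited_cv = PTHREAD_COND_INITIALIZER,
    };

    /* signal "exited" on every path, like complete(&sqd->exited) */
    static void signal_exited(struct sqd *s)
    {
        pthread_mutex_lock(&s->lock);
        s->exited = 1;
        pthread_cond_signal(&s->exited_cv);
        pthread_mutex_unlock(&s->lock);
    }

    static void *sq_thread(void *arg)
    {
        struct sqd *s = arg;

        if (!s->task_context)   /* context creation failed: just exit */
            goto err_out;

        /* ... the normal submission loop would run here ... */
    err_out:
        signal_exited(s);
        return NULL;
    }

    int main(void)
    {
        pthread_t t;

        pthread_create(&t, NULL, sq_thread, &sqd);
        pthread_mutex_lock(&sqd.lock);
        while (!sqd.exited)
            pthread_cond_wait(&sqd.exited_cv, &sqd.lock);
        pthread_mutex_unlock(&sqd.lock);
        pthread_join(t, NULL);
        printf("parent unblocked despite failed setup\n");
        return 0;
    }
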
diff --git a/io_uring/waitid.c b/io_uring/waitid.c
index 6f85197..77d3406 100644
--- a/io_uring/waitid.c
+++ b/io_uring/waitid.c
@@ -125,12 +125,6 @@
 
 	lockdep_assert_held(&req->ctx->uring_lock);
 
-	/*
-	 * Did cancel find it meanwhile?
-	 */
-	if (hlist_unhashed(&req->hash_node))
-		return;
-
 	hlist_del_init(&req->hash_node);
 
 	ret = io_waitid_finish(req, ret);
@@ -202,6 +196,7 @@
 	hlist_for_each_entry_safe(req, tmp, &ctx->waitid_list, hash_node) {
 		if (!io_match_task_safe(req, task, cancel_all))
 			continue;
+		hlist_del_init(&req->hash_node);
 		__io_waitid_cancel(ctx, req);
 		found = true;
 	}
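
This mirrors the futex change at the top of the series. With the
cancel-all loop unlinking entries itself, the old "did cancel find it
meanwhile?" early return in the completion path would now skip the
completion work for exactly the requests cancel-all just unlinked, so it
goes away; calling hlist_del_init() unconditionally is safe because it is
a no-op on a node that is already in the unhashed state.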