io_uring: modularize io_sqe_buffer_register

Split io_sqe_buffer_register into two routines:

- io_sqe_buffer_register() registers a single buffer
- io_sqe_buffers_register iterates over all user specified buffers

Reviewed-by: Pavel Begunkov <asml.silence@gmail.com>
Signed-off-by: Bijan Mottahedeh <bijan.mottahedeh@oracle.com>
Signed-off-by: Jens Axboe <axboe@kernel.dk>
diff --git a/fs/io_uring.c b/fs/io_uring.c
index fbc6d2f..ec70ba0 100644
--- a/fs/io_uring.c
+++ b/fs/io_uring.c
@@ -8370,7 +8370,7 @@ static unsigned long ring_pages(unsigned sq_entries, unsigned cq_entries)
 	return pages;
 }
 
-static int io_sqe_buffer_unregister(struct io_ring_ctx *ctx)
+static int io_sqe_buffers_unregister(struct io_ring_ctx *ctx)
 {
 	int i, j;
 
@@ -8488,14 +8488,103 @@ static int io_buffer_account_pin(struct io_ring_ctx *ctx, struct page **pages,
 	return ret;
 }
 
-static int io_sqe_buffer_register(struct io_ring_ctx *ctx, void __user *arg,
-				  unsigned nr_args)
+static int io_sqe_buffer_register(struct io_ring_ctx *ctx, struct iovec *iov,
+				  struct io_mapped_ubuf *imu,
+				  struct page **last_hpage)
 {
 	struct vm_area_struct **vmas = NULL;
 	struct page **pages = NULL;
+	unsigned long off, start, end, ubuf;
+	size_t size;
+	int ret, pret, nr_pages, i;
+
+	ubuf = (unsigned long) iov->iov_base;
+	end = (ubuf + iov->iov_len + PAGE_SIZE - 1) >> PAGE_SHIFT;
+	start = ubuf >> PAGE_SHIFT;
+	nr_pages = end - start;
+
+	ret = -ENOMEM;
+
+	pages = kvmalloc_array(nr_pages, sizeof(struct page *), GFP_KERNEL);
+	if (!pages)
+		goto done;
+
+	vmas = kvmalloc_array(nr_pages, sizeof(struct vm_area_struct *),
+			      GFP_KERNEL);
+	if (!vmas)
+		goto done;
+
+	imu->bvec = kvmalloc_array(nr_pages, sizeof(struct bio_vec),
+				   GFP_KERNEL);
+	if (!imu->bvec)
+		goto done;
+
+	ret = 0;
+	mmap_read_lock(current->mm);
+	pret = pin_user_pages(ubuf, nr_pages, FOLL_WRITE | FOLL_LONGTERM,
+			      pages, vmas);
+	if (pret == nr_pages) {
+		/* don't support file backed memory */
+		for (i = 0; i < nr_pages; i++) {
+			struct vm_area_struct *vma = vmas[i];
+
+			if (vma->vm_file &&
+			    !is_file_hugepages(vma->vm_file)) {
+				ret = -EOPNOTSUPP;
+				break;
+			}
+		}
+	} else {
+		ret = pret < 0 ? pret : -EFAULT;
+	}
+	mmap_read_unlock(current->mm);
+	if (ret) {
+		/*
+		 * if we did partial map, or found file backed vmas,
+		 * release any pages we did get
+		 */
+		if (pret > 0)
+			unpin_user_pages(pages, pret);
+		kvfree(imu->bvec);
+		goto done;
+	}
+
+	ret = io_buffer_account_pin(ctx, pages, pret, imu, last_hpage);
+	if (ret) {
+		unpin_user_pages(pages, pret);
+		kvfree(imu->bvec);
+		goto done;
+	}
+
+	off = ubuf & ~PAGE_MASK;
+	size = iov->iov_len;
+	for (i = 0; i < nr_pages; i++) {
+		size_t vec_len;
+
+		vec_len = min_t(size_t, size, PAGE_SIZE - off);
+		imu->bvec[i].bv_page = pages[i];
+		imu->bvec[i].bv_len = vec_len;
+		imu->bvec[i].bv_offset = off;
+		off = 0;
+		size -= vec_len;
+	}
+	/* store original address for later verification */
+	imu->ubuf = ubuf;
+	imu->len = iov->iov_len;
+	imu->nr_bvecs = nr_pages;
+	ret = 0;
+done:
+	kvfree(pages);
+	kvfree(vmas);
+	return ret;
+}
+
+static int io_sqe_buffers_register(struct io_ring_ctx *ctx, void __user *arg,
+				   unsigned int nr_args)
+{
+	int i, ret;
+	struct iovec iov;
 	struct page *last_hpage = NULL;
-	int i, j, got_pages = 0;
-	int ret = -EINVAL;
 
 	if (ctx->user_bufs)
 		return -EBUSY;
@@ -8509,14 +8598,10 @@ static int io_sqe_buffer_register(struct io_ring_ctx *ctx, void __user *arg,
 
 	for (i = 0; i < nr_args; i++) {
 		struct io_mapped_ubuf *imu = &ctx->user_bufs[i];
-		unsigned long off, start, end, ubuf;
-		int pret, nr_pages;
-		struct iovec iov;
-		size_t size;
 
 		ret = io_copy_iov(ctx, &iov, arg, i);
 		if (ret)
-			goto err;
+			break;
 
 		/*
 		 * Don't impose further limits on the size and buffer
@@ -8525,103 +8610,22 @@ static int io_sqe_buffer_register(struct io_ring_ctx *ctx, void __user *arg,
 		 */
 		ret = -EFAULT;
 		if (!iov.iov_base || !iov.iov_len)
-			goto err;
+			break;
 
 		/* arbitrary limit, but we need something */
 		if (iov.iov_len > SZ_1G)
-			goto err;
+			break;
 
-		ubuf = (unsigned long) iov.iov_base;
-		end = (ubuf + iov.iov_len + PAGE_SIZE - 1) >> PAGE_SHIFT;
-		start = ubuf >> PAGE_SHIFT;
-		nr_pages = end - start;
-
-		ret = 0;
-		if (!pages || nr_pages > got_pages) {
-			kvfree(vmas);
-			kvfree(pages);
-			pages = kvmalloc_array(nr_pages, sizeof(struct page *),
-						GFP_KERNEL);
-			vmas = kvmalloc_array(nr_pages,
-					sizeof(struct vm_area_struct *),
-					GFP_KERNEL);
-			if (!pages || !vmas) {
-				ret = -ENOMEM;
-				goto err;
-			}
-			got_pages = nr_pages;
-		}
-
-		imu->bvec = kvmalloc_array(nr_pages, sizeof(struct bio_vec),
-						GFP_KERNEL);
-		ret = -ENOMEM;
-		if (!imu->bvec)
-			goto err;
-
-		ret = 0;
-		mmap_read_lock(current->mm);
-		pret = pin_user_pages(ubuf, nr_pages,
-				      FOLL_WRITE | FOLL_LONGTERM,
-				      pages, vmas);
-		if (pret == nr_pages) {
-			/* don't support file backed memory */
-			for (j = 0; j < nr_pages; j++) {
-				struct vm_area_struct *vma = vmas[j];
-
-				if (vma->vm_file &&
-				    !is_file_hugepages(vma->vm_file)) {
-					ret = -EOPNOTSUPP;
-					break;
-				}
-			}
-		} else {
-			ret = pret < 0 ? pret : -EFAULT;
-		}
-		mmap_read_unlock(current->mm);
-		if (ret) {
-			/*
-			 * if we did partial map, or found file backed vmas,
-			 * release any pages we did get
-			 */
-			if (pret > 0)
-				unpin_user_pages(pages, pret);
-			kvfree(imu->bvec);
-			goto err;
-		}
-
-		ret = io_buffer_account_pin(ctx, pages, pret, imu, &last_hpage);
-		if (ret) {
-			unpin_user_pages(pages, pret);
-			kvfree(imu->bvec);
-			goto err;
-		}
-
-		off = ubuf & ~PAGE_MASK;
-		size = iov.iov_len;
-		for (j = 0; j < nr_pages; j++) {
-			size_t vec_len;
-
-			vec_len = min_t(size_t, size, PAGE_SIZE - off);
-			imu->bvec[j].bv_page = pages[j];
-			imu->bvec[j].bv_len = vec_len;
-			imu->bvec[j].bv_offset = off;
-			off = 0;
-			size -= vec_len;
-		}
-		/* store original address for later verification */
-		imu->ubuf = ubuf;
-		imu->len = iov.iov_len;
-		imu->nr_bvecs = nr_pages;
+		ret = io_sqe_buffer_register(ctx, &iov, imu, &last_hpage);
+		if (ret)
+			break;
 
 		ctx->nr_user_bufs++;
 	}
-	kvfree(pages);
-	kvfree(vmas);
-	return 0;
-err:
-	kvfree(pages);
-	kvfree(vmas);
-	io_sqe_buffer_unregister(ctx);
+
+	if (ret)
+		io_sqe_buffers_unregister(ctx);
+
 	return ret;
 }
 
@@ -8675,7 +8679,7 @@ static void io_destroy_buffers(struct io_ring_ctx *ctx)
 static void io_ring_ctx_free(struct io_ring_ctx *ctx)
 {
 	io_finish_async(ctx);
-	io_sqe_buffer_unregister(ctx);
+	io_sqe_buffers_unregister(ctx);
 
 	if (ctx->sqo_task) {
 		put_task_struct(ctx->sqo_task);
@@ -10057,13 +10061,13 @@ static int __io_uring_register(struct io_ring_ctx *ctx, unsigned opcode,
 
 	switch (opcode) {
 	case IORING_REGISTER_BUFFERS:
-		ret = io_sqe_buffer_register(ctx, arg, nr_args);
+		ret = io_sqe_buffers_register(ctx, arg, nr_args);
 		break;
 	case IORING_UNREGISTER_BUFFERS:
 		ret = -EINVAL;
 		if (arg || nr_args)
 			break;
-		ret = io_sqe_buffer_unregister(ctx);
+		ret = io_sqe_buffers_unregister(ctx);
 		break;
 	case IORING_REGISTER_FILES:
 		ret = io_sqe_files_register(ctx, arg, nr_args);