xsk: simplified umem setup

As suggested by Daniel Borkmann, the umem setup code was too
defensive and complex. Here, we reduce the number of checks. Also, the
memory pinning is now folded into the umem creation, and the xs->umem
checks in xsk_setsockopt() are now done while holding xs->mutex, so
the locking is correct.

Signed-off-by: Björn Töpel <bjorn.topel@intel.com>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
diff --git a/net/xdp/xdp_umem.c b/net/xdp/xdp_umem.c
index c47909c..faa6ffb 100644
--- a/net/xdp/xdp_umem.c
+++ b/net/xdp/xdp_umem.c
@@ -16,39 +16,30 @@
 
 #define XDP_UMEM_MIN_FRAME_SIZE 2048
 
-int xdp_umem_create(struct xdp_umem **umem)
-{
-	*umem = kzalloc(sizeof(**umem), GFP_KERNEL);
-
-	if (!*umem)
-		return -ENOMEM;
-
-	return 0;
-}
-
 static void xdp_umem_unpin_pages(struct xdp_umem *umem)
 {
 	unsigned int i;
 
-	if (umem->pgs) {
-		for (i = 0; i < umem->npgs; i++) {
-			struct page *page = umem->pgs[i];
+	for (i = 0; i < umem->npgs; i++) {
+		struct page *page = umem->pgs[i];
 
-			set_page_dirty_lock(page);
-			put_page(page);
-		}
-
-		kfree(umem->pgs);
-		umem->pgs = NULL;
+		set_page_dirty_lock(page);
+		put_page(page);
 	}
+
+	kfree(umem->pgs);
+	umem->pgs = NULL;
 }
 
 static void xdp_umem_unaccount_pages(struct xdp_umem *umem)
 {
+	/* umem->user may be NULL (no pages were charged to a user's
+	 * locked_vm), in which case there is nothing to undo.
+	 */
 	if (umem->user) {
 		atomic_long_sub(umem->npgs, &umem->user->locked_vm);
 		free_uid(umem->user);
 	}
 }
 
 static void xdp_umem_release(struct xdp_umem *umem)
@@ -66,22 +52,18 @@ static void xdp_umem_release(struct xdp_umem *umem)
 		umem->cq = NULL;
 	}
 
-	if (umem->pgs) {
-		xdp_umem_unpin_pages(umem);
+	xdp_umem_unpin_pages(umem);
 
-		task = get_pid_task(umem->pid, PIDTYPE_PID);
-		put_pid(umem->pid);
-		if (!task)
-			goto out;
-		mm = get_task_mm(task);
-		put_task_struct(task);
-		if (!mm)
-			goto out;
+	task = get_pid_task(umem->pid, PIDTYPE_PID);
+	put_pid(umem->pid);
+	if (!task)
+		goto out;
+	mm = get_task_mm(task);
+	put_task_struct(task);
+	if (!mm)
+		goto out;
 
-		mmput(mm);
-		umem->pgs = NULL;
-	}
-
+	mmput(mm);
 	xdp_umem_unaccount_pages(umem);
 out:
 	kfree(umem);
@@ -167,16 +149,13 @@ static int xdp_umem_account_pages(struct xdp_umem *umem)
 	return 0;
 }
 
-int xdp_umem_reg(struct xdp_umem *umem, struct xdp_umem_reg *mr)
+static int xdp_umem_reg(struct xdp_umem *umem, struct xdp_umem_reg *mr)
 {
 	u32 frame_size = mr->frame_size, frame_headroom = mr->frame_headroom;
 	u64 addr = mr->addr, size = mr->len;
 	unsigned int nframes, nfpp;
 	int size_chk, err;
 
-	if (!umem)
-		return -EINVAL;
-
 	if (frame_size < XDP_UMEM_MIN_FRAME_SIZE || frame_size > PAGE_SIZE) {
 		/* Strictly speaking we could support this, if:
 		 * - huge pages, or*
@@ -245,6 +224,24 @@ int xdp_umem_reg(struct xdp_umem *umem, struct xdp_umem_reg *mr)
 	return err;
 }
 
+struct xdp_umem *xdp_umem_create(struct xdp_umem_reg *mr)
+{
+	struct xdp_umem *umem;
+	int err;
+
+	umem = kzalloc(sizeof(*umem), GFP_KERNEL);
+	if (!umem)
+		return ERR_PTR(-ENOMEM);
+
+	err = xdp_umem_reg(umem, mr);
+	if (err) {
+		kfree(umem);
+		return ERR_PTR(err);
+	}
+
+	return umem;
+}
+
 bool xdp_umem_validate_queues(struct xdp_umem *umem)
 {
 	return umem->fq && umem->cq;
diff --git a/net/xdp/xdp_umem.h b/net/xdp/xdp_umem.h
index 70fe225..9802287 100644
--- a/net/xdp/xdp_umem.h
+++ b/net/xdp/xdp_umem.h
@@ -50,9 +50,8 @@ static inline char *xdp_umem_get_data_with_headroom(struct xdp_umem *umem,
 }
 
 bool xdp_umem_validate_queues(struct xdp_umem *umem);
-int xdp_umem_reg(struct xdp_umem *umem, struct xdp_umem_reg *mr);
 void xdp_get_umem(struct xdp_umem *umem);
 void xdp_put_umem(struct xdp_umem *umem);
-int xdp_umem_create(struct xdp_umem **umem);
+struct xdp_umem *xdp_umem_create(struct xdp_umem_reg *mr);
 
 #endif /* XDP_UMEM_H_ */
diff --git a/net/xdp/xsk.c b/net/xdp/xsk.c
index 01f010e..cce0e4f 100644
--- a/net/xdp/xsk.c
+++ b/net/xdp/xsk.c
@@ -406,25 +406,23 @@ static int xsk_setsockopt(struct socket *sock, int level, int optname,
 		struct xdp_umem_reg mr;
 		struct xdp_umem *umem;
 
-		if (xs->umem)
-			return -EBUSY;
-
 		if (copy_from_user(&mr, optval, sizeof(mr)))
 			return -EFAULT;
 
 		mutex_lock(&xs->mutex);
-		err = xdp_umem_create(&umem);
-
-		err = xdp_umem_reg(umem, &mr);
-		if (err) {
-			kfree(umem);
+		if (xs->umem) {
 			mutex_unlock(&xs->mutex);
-			return err;
+			return -EBUSY;
+		}
+
+		umem = xdp_umem_create(&mr);
+		if (IS_ERR(umem)) {
+			mutex_unlock(&xs->mutex);
+			return PTR_ERR(umem);
 		}
 
 		/* Make sure umem is ready before it can be seen by others */
 		smp_wmb();
-
 		xs->umem = umem;
 		mutex_unlock(&xs->mutex);
 		return 0;
@@ -435,13 +433,15 @@ static int xsk_setsockopt(struct socket *sock, int level, int optname,
 		struct xsk_queue **q;
 		int entries;
 
-		if (!xs->umem)
-			return -EINVAL;
-
 		if (copy_from_user(&entries, optval, sizeof(entries)))
 			return -EFAULT;
 
 		mutex_lock(&xs->mutex);
+		if (!xs->umem) {
+			mutex_unlock(&xs->mutex);
+			return -EINVAL;
+		}
+
 		q = (optname == XDP_UMEM_FILL_RING) ? &xs->umem->fq :
 			&xs->umem->cq;
 		err = xsk_init_queue(entries, q, true);