vhost: Allow device specific fields per vq

This is useful for any device that wants device-specific fields per vq.
For example, tcm_vhost wants a per-vq field to track requests that are
in flight on the vq. On top of this we can also add patches to move
things like ubufs from vhost.h out to net.c.
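
A rough sketch of the pattern (hypothetical vhost_foo names below,
not the actual tcm_vhost code): the device embeds the generic vq in
its own struct and gets back to the wrapper with container_of(), the
same way vhost_net_virtqueue does in this patch:

	struct vhost_foo_virtqueue {
		struct vhost_virtqueue vq;
		/* Device specific per-vq state, e.g. in-flight tracking. */
		unsigned int inflight;
	};

	static void handle_foo_kick(struct vhost_work *work)
	{
		/* vhost core hands us the work embedded in vq->poll... */
		struct vhost_virtqueue *vq = container_of(work,
				struct vhost_virtqueue, poll.work);
		/* ...and container_of() recovers the device wrapper. */
		struct vhost_foo_virtqueue *fvq =
			container_of(vq, struct vhost_foo_virtqueue, vq);

		fvq->inflight++;
	}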

Signed-off-by: Asias He <asias@redhat.com>
Signed-off-by: Michael S. Tsirkin <mst@redhat.com>

diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index 87c216c..176aa03 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -64,9 +64,13 @@
 	VHOST_NET_VQ_MAX = 2,
 };
 
+struct vhost_net_virtqueue {
+	struct vhost_virtqueue vq;
+};
+
 struct vhost_net {
 	struct vhost_dev dev;
-	struct vhost_virtqueue vqs[VHOST_NET_VQ_MAX];
+	struct vhost_net_virtqueue vqs[VHOST_NET_VQ_MAX];
 	struct vhost_poll poll[VHOST_NET_VQ_MAX];
 	/* Number of TX recently submitted.
 	 * Protected by tx vq lock. */
@@ -198,7 +202,7 @@
  * read-size critical section for our kind of RCU. */
 static void handle_tx(struct vhost_net *net)
 {
-	struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_TX];
+	struct vhost_virtqueue *vq = &net->vqs[VHOST_NET_VQ_TX].vq;
 	unsigned out, in, s;
 	int head;
 	struct msghdr msg = {
@@ -417,7 +421,7 @@
  * read-size critical section for our kind of RCU. */
 static void handle_rx(struct vhost_net *net)
 {
-	struct vhost_virtqueue *vq = &net->dev.vqs[VHOST_NET_VQ_RX];
+	struct vhost_virtqueue *vq = &net->vqs[VHOST_NET_VQ_RX].vq;
 	unsigned uninitialized_var(in), log;
 	struct vhost_log *vq_log;
 	struct msghdr msg = {
@@ -559,17 +563,26 @@
 {
 	struct vhost_net *n = kmalloc(sizeof *n, GFP_KERNEL);
 	struct vhost_dev *dev;
+	struct vhost_virtqueue **vqs;
 	int r;
 
 	if (!n)
 		return -ENOMEM;
+	vqs = kmalloc(VHOST_NET_VQ_MAX * sizeof(*vqs), GFP_KERNEL);
+	if (!vqs) {
+		kfree(n);
+		return -ENOMEM;
+	}
 
 	dev = &n->dev;
-	n->vqs[VHOST_NET_VQ_TX].handle_kick = handle_tx_kick;
-	n->vqs[VHOST_NET_VQ_RX].handle_kick = handle_rx_kick;
-	r = vhost_dev_init(dev, n->vqs, VHOST_NET_VQ_MAX);
+	vqs[VHOST_NET_VQ_TX] = &n->vqs[VHOST_NET_VQ_TX].vq;
+	vqs[VHOST_NET_VQ_RX] = &n->vqs[VHOST_NET_VQ_RX].vq;
+	n->vqs[VHOST_NET_VQ_TX].vq.handle_kick = handle_tx_kick;
+	n->vqs[VHOST_NET_VQ_RX].vq.handle_kick = handle_rx_kick;
+	r = vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX);
 	if (r < 0) {
 		kfree(n);
+		kfree(vqs);
 		return r;
 	}
 
@@ -584,7 +597,9 @@
 static void vhost_net_disable_vq(struct vhost_net *n,
 				 struct vhost_virtqueue *vq)
 {
-	struct vhost_poll *poll = n->poll + (vq - n->vqs);
+	struct vhost_net_virtqueue *nvq =
+		container_of(vq, struct vhost_net_virtqueue, vq);
+	struct vhost_poll *poll = n->poll + (nvq - n->vqs);
 	if (!vq->private_data)
 		return;
 	vhost_poll_stop(poll);
@@ -593,7 +608,9 @@
 static int vhost_net_enable_vq(struct vhost_net *n,
 				struct vhost_virtqueue *vq)
 {
-	struct vhost_poll *poll = n->poll + (vq - n->vqs);
+	struct vhost_net_virtqueue *nvq =
+		container_of(vq, struct vhost_net_virtqueue, vq);
+	struct vhost_poll *poll = n->poll + (nvq - n->vqs);
 	struct socket *sock;
 
 	sock = rcu_dereference_protected(vq->private_data,
@@ -621,30 +638,30 @@
 static void vhost_net_stop(struct vhost_net *n, struct socket **tx_sock,
 			   struct socket **rx_sock)
 {
-	*tx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_TX);
-	*rx_sock = vhost_net_stop_vq(n, n->vqs + VHOST_NET_VQ_RX);
+	*tx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_TX].vq);
+	*rx_sock = vhost_net_stop_vq(n, &n->vqs[VHOST_NET_VQ_RX].vq);
 }
 
 static void vhost_net_flush_vq(struct vhost_net *n, int index)
 {
 	vhost_poll_flush(n->poll + index);
-	vhost_poll_flush(&n->dev.vqs[index].poll);
+	vhost_poll_flush(&n->vqs[index].vq.poll);
 }
 
 static void vhost_net_flush(struct vhost_net *n)
 {
 	vhost_net_flush_vq(n, VHOST_NET_VQ_TX);
 	vhost_net_flush_vq(n, VHOST_NET_VQ_RX);
-	if (n->dev.vqs[VHOST_NET_VQ_TX].ubufs) {
-		mutex_lock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex);
+	if (n->vqs[VHOST_NET_VQ_TX].vq.ubufs) {
+		mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
 		n->tx_flush = true;
-		mutex_unlock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex);
+		mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
 		/* Wait for all lower device DMAs done. */
-		vhost_ubuf_put_and_wait(n->dev.vqs[VHOST_NET_VQ_TX].ubufs);
-		mutex_lock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex);
+		vhost_ubuf_put_and_wait(n->vqs[VHOST_NET_VQ_TX].vq.ubufs);
+		mutex_lock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
 		n->tx_flush = false;
-		kref_init(&n->dev.vqs[VHOST_NET_VQ_TX].ubufs->kref);
-		mutex_unlock(&n->dev.vqs[VHOST_NET_VQ_TX].mutex);
+		kref_init(&n->vqs[VHOST_NET_VQ_TX].vq.ubufs->kref);
+		mutex_unlock(&n->vqs[VHOST_NET_VQ_TX].vq.mutex);
 	}
 }
 
@@ -665,6 +682,7 @@
 	/* We do an extra flush before freeing memory,
 	 * since jobs can re-queue themselves. */
 	vhost_net_flush(n);
+	kfree(n->dev.vqs);
 	kfree(n);
 	return 0;
 }
@@ -750,7 +768,7 @@
 		r = -ENOBUFS;
 		goto err;
 	}
-	vq = n->vqs + index;
+	vq = &n->vqs[index].vq;
 	mutex_lock(&vq->mutex);
 
 	/* Verify that ring has been setup correctly. */
@@ -870,10 +888,10 @@
 	n->dev.acked_features = features;
 	smp_wmb();
 	for (i = 0; i < VHOST_NET_VQ_MAX; ++i) {
-		mutex_lock(&n->vqs[i].mutex);
-		n->vqs[i].vhost_hlen = vhost_hlen;
-		n->vqs[i].sock_hlen = sock_hlen;
-		mutex_unlock(&n->vqs[i].mutex);
+		mutex_lock(&n->vqs[i].vq.mutex);
+		n->vqs[i].vq.vhost_hlen = vhost_hlen;
+		n->vqs[i].vq.sock_hlen = sock_hlen;
+		mutex_unlock(&n->vqs[i].vq.mutex);
 	}
 	vhost_net_flush(n);
 	mutex_unlock(&n->dev.mutex);