virtio: Support modern virtqueue addresses

Modern virtio devices can use separate buffers for the descriptor table,
the available ring and the used ring. They can also use 64-bit addresses
instead of 44-bit. Rework the virtqueue initialization function to support modern virtio.

Signed-off-by: Jean-Philippe Brucker <jean-philippe.brucker@arm.com>
Link: https://lore.kernel.org/r/20220607170239.120084-6-jean-philippe.brucker@arm.com
Signed-off-by: Will Deacon <will@kernel.org>
diff --git a/include/kvm/virtio.h b/include/kvm/virtio.h
index f0b7933..24179ec 100644
--- a/include/kvm/virtio.h
+++ b/include/kvm/virtio.h
@@ -44,9 +44,30 @@
 /* Stop the device */
 #define VIRTIO__STATUS_STOP		(1 << 9)
 
+struct vring_addr {
+	bool			legacy;	/* selects which union member is valid */
+	union {
+		/* Legacy layout: one ring located at pfn * pgsize */
+		struct {
+			u32	pfn;
+			u32	align;	/* alignment passed to vring_init() */
+			u32	pgsize;
+		};
+		/* Modern layout: 64-bit addresses split into lo/hi halves */
+		struct {
+			u32	desc_lo;
+			u32	desc_hi;
+			u32	avail_lo;
+			u32	avail_hi;
+			u32	used_lo;
+			u32	used_hi;
+		};
+	};
+};
+
 struct virt_queue {
 	struct vring	vring;
-	u32		pfn;
+	struct vring_addr vring_addr;
 	/* The last_avail_idx field is an index to ->ring of struct vring_avail.
 	   It's where we assume the next request index is at.  */
 	u16		last_avail_idx;
@@ -189,8 +210,7 @@
 	u32 (*get_host_features)(struct kvm *kvm, void *dev);
 	void (*set_guest_features)(struct kvm *kvm, void *dev, u32 features);
 	unsigned int (*get_vq_count)(struct kvm *kvm, void *dev);
-	int (*init_vq)(struct kvm *kvm, void *dev, u32 vq, u32 page_size,
-		       u32 align, u32 pfn);
+	int (*init_vq)(struct kvm *kvm, void *dev, u32 vq);
 	void (*exit_vq)(struct kvm *kvm, void *dev, u32 vq);
 	int (*notify_vq)(struct kvm *kvm, void *dev, u32 vq);
 	struct virt_queue *(*get_vq)(struct kvm *kvm, void *dev, u32 vq);
@@ -213,8 +233,7 @@
 int virtio_compat_add_message(const char *device, const char *config);
 const char* virtio_trans_name(enum virtio_trans trans);
 void virtio_init_device_vq(struct kvm *kvm, struct virtio_device *vdev,
-			   struct virt_queue *vq, size_t nr_descs,
-			   u32 page_size, u32 align, u32 pfn);
+			   struct virt_queue *vq, size_t nr_descs);
 void virtio_exit_vq(struct kvm *kvm, struct virtio_device *vdev, void *dev,
 		    int num);
 void virtio_set_guest_features(struct kvm *kvm, struct virtio_device *vdev,
diff --git a/virtio/9p.c b/virtio/9p.c
index d9a7737..a3f9666 100644
--- a/virtio/9p.c
+++ b/virtio/9p.c
@@ -1408,8 +1408,7 @@
 		close_fid(p9dev, pfid->fid);
 }
 
-static int init_vq(struct kvm *kvm, void *dev, u32 vq, u32 page_size, u32 align,
-		   u32 pfn)
+static int init_vq(struct kvm *kvm, void *dev, u32 vq)
 {
 	struct p9_dev *p9dev = dev;
 	struct p9_dev_job *job;
@@ -1420,8 +1419,7 @@
 	queue		= &p9dev->vqs[vq];
 	job		= &p9dev->jobs[vq];
 
-	virtio_init_device_vq(kvm, &p9dev->vdev, queue, VIRTQUEUE_NUM,
-			      page_size, align, pfn);
+	virtio_init_device_vq(kvm, &p9dev->vdev, queue, VIRTQUEUE_NUM);
 
 	*job		= (struct p9_dev_job) {
 		.vq		= queue,
diff --git a/virtio/balloon.c b/virtio/balloon.c
index 720073d..ffeeb29 100644
--- a/virtio/balloon.c
+++ b/virtio/balloon.c
@@ -130,7 +130,7 @@
 	u64 tmp;
 
 	/* Exit if the queue is not set up. */
-	if (!vq->pfn)
+	if (!vq->enabled)
 		return -ENODEV;
 
 	virt_queue__set_used_elem(vq, bdev.cur_stat_head,
@@ -209,8 +209,7 @@
 {
 }
 
-static int init_vq(struct kvm *kvm, void *dev, u32 vq, u32 page_size, u32 align,
-		   u32 pfn)
+static int init_vq(struct kvm *kvm, void *dev, u32 vq)
 {
 	struct bln_dev *bdev = dev;
 	struct virt_queue *queue;
@@ -219,8 +218,7 @@
 
 	queue		= &bdev->vqs[vq];
 
-	virtio_init_device_vq(kvm, &bdev->vdev, queue, VIRTIO_BLN_QUEUE_SIZE,
-			      page_size, align, pfn);
+	virtio_init_device_vq(kvm, &bdev->vdev, queue, VIRTIO_BLN_QUEUE_SIZE);
 
 	thread_pool__init_job(&bdev->jobs[vq], kvm, virtio_bln_do_io, queue);
 
diff --git a/virtio/blk.c b/virtio/blk.c
index af8c62f..2479e00 100644
--- a/virtio/blk.c
+++ b/virtio/blk.c
@@ -207,8 +207,7 @@
 	return NULL;
 }
 
-static int init_vq(struct kvm *kvm, void *dev, u32 vq, u32 page_size, u32 align,
-		   u32 pfn)
+static int init_vq(struct kvm *kvm, void *dev, u32 vq)
 {
 	unsigned int i;
 	struct blk_dev *bdev = dev;
@@ -216,7 +215,7 @@
 	compat__remove_message(compat_id);
 
 	virtio_init_device_vq(kvm, &bdev->vdev, &bdev->vqs[vq],
-			      VIRTIO_BLK_QUEUE_SIZE, page_size, align, pfn);
+			      VIRTIO_BLK_QUEUE_SIZE);
 
 	if (vq != 0)
 		return 0;
diff --git a/virtio/console.c b/virtio/console.c
index 9fbd101..5263b8e 100644
--- a/virtio/console.c
+++ b/virtio/console.c
@@ -147,8 +147,7 @@
 {
 }
 
-static int init_vq(struct kvm *kvm, void *dev, u32 vq, u32 page_size, u32 align,
-		   u32 pfn)
+static int init_vq(struct kvm *kvm, void *dev, u32 vq)
 {
 	struct virt_queue *queue;
 
@@ -158,8 +157,7 @@
 
 	queue		= &cdev.vqs[vq];
 
-	virtio_init_device_vq(kvm, &cdev.vdev, queue, VIRTIO_CONSOLE_QUEUE_SIZE,
-			      page_size, align, pfn);
+	virtio_init_device_vq(kvm, &cdev.vdev, queue, VIRTIO_CONSOLE_QUEUE_SIZE);
 
 	if (vq == VIRTIO_CONSOLE_TX_QUEUE) {
 		thread_pool__init_job(&cdev.jobs[vq], kvm, virtio_console_handle_callback, queue);
diff --git a/virtio/core.c b/virtio/core.c
index a5125fe..d6f2c68 100644
--- a/virtio/core.c
+++ b/virtio/core.c
@@ -160,17 +160,31 @@
 }
 
 void virtio_init_device_vq(struct kvm *kvm, struct virtio_device *vdev,
-			   struct virt_queue *vq, size_t nr_descs,
-			   u32 page_size, u32 align, u32 pfn)
+			   struct virt_queue *vq, size_t nr_descs)
 {
-	void *p = guest_flat_to_host(kvm, (u64)pfn * page_size);
+	struct vring_addr *addr = &vq->vring_addr; /* set by the transport before init_vq */

 	vq->endian		= vdev->endian;
-	vq->pfn			= pfn;
 	vq->use_event_idx	= (vdev->features & VIRTIO_RING_F_EVENT_IDX);
 	vq->enabled		= true;

-	vring_init(&vq->vring, nr_descs, p, align);
+	if (addr->legacy) {
+		u64 base = (u64)addr->pfn * addr->pgsize; /* u64: no truncation where long is 32-bit */
+		void *p = guest_flat_to_host(kvm, base);
+
+		vring_init(&vq->vring, nr_descs, p, addr->align);
+	} else {
+		u64 desc = (u64)addr->desc_hi << 32 | addr->desc_lo;
+		u64 avail = (u64)addr->avail_hi << 32 | addr->avail_lo;
+		u64 used = (u64)addr->used_hi << 32 | addr->used_lo;
+
+		vq->vring = (struct vring) {
+			.desc	= guest_flat_to_host(kvm, desc),
+			.used	= guest_flat_to_host(kvm, used),
+			.avail	= guest_flat_to_host(kvm, avail),
+			.num	= nr_descs,
+		};
+	}
 }
 
 void virtio_exit_vq(struct kvm *kvm, struct virtio_device *vdev,
diff --git a/virtio/mmio.c b/virtio/mmio.c
index 3782d55..77289e2 100644
--- a/virtio/mmio.c
+++ b/virtio/mmio.c
@@ -125,6 +125,9 @@
 	}
 }
 
+#define vmmio_selected_vq(vdev, vmmio) \
+	(vdev)->ops->get_vq((vmmio)->kvm, (vmmio)->dev, (vmmio)->hdr.queue_sel)
+
 static void virtio_mmio_config_in(struct kvm_cpu *vcpu,
 				  u64 addr, void *data, u32 len,
 				  struct virtio_device *vdev)
@@ -149,9 +152,8 @@
 		ioport__write32(data, val);
 		break;
 	case VIRTIO_MMIO_QUEUE_PFN:
-		vq = vdev->ops->get_vq(vmmio->kvm, vmmio->dev,
-				       vmmio->hdr.queue_sel);
-		ioport__write32(data, vq->pfn);
+		vq = vmmio_selected_vq(vdev, vmmio);
+		ioport__write32(data, vq->vring_addr.pfn);
 		break;
 	case VIRTIO_MMIO_QUEUE_NUM_MAX:
 		val = vdev->ops->get_size_vq(vmmio->kvm, vmmio->dev,
@@ -170,6 +172,7 @@
 	struct virtio_mmio *vmmio = vdev->virtio;
 	struct kvm *kvm = vmmio->kvm;
 	unsigned int vq_count = vdev->ops->get_vq_count(kvm, vmmio->dev);
+	struct virt_queue *vq;
 	u32 val = 0;
 
 	switch (addr) {
@@ -217,13 +220,17 @@
 	case VIRTIO_MMIO_QUEUE_PFN:
 		val = ioport__read32(data);
 		if (val) {
+			vq = vmmio_selected_vq(vdev, vmmio);
+			vq->vring_addr = (struct vring_addr) {
+				.legacy	= true,
+				.pfn	= val,
+				.align	= vmmio->hdr.queue_align,
+				.pgsize	= vmmio->hdr.guest_page_size,
+			};
 			virtio_mmio_init_ioeventfd(vmmio->kvm, vdev,
 						   vmmio->hdr.queue_sel);
 			vdev->ops->init_vq(vmmio->kvm, vmmio->dev,
-					   vmmio->hdr.queue_sel,
-					   vmmio->hdr.guest_page_size,
-					   vmmio->hdr.queue_align,
-					   val);
+					   vmmio->hdr.queue_sel);
 		} else {
 			virtio_mmio_exit_vq(kvm, vdev, vmmio->hdr.queue_sel);
 		}
diff --git a/virtio/net.c b/virtio/net.c
index de5ae7b..7c7970a 100644
--- a/virtio/net.c
+++ b/virtio/net.c
@@ -582,8 +582,7 @@
 	return vq == (u32)(ndev->queue_pairs * 2);
 }
 
-static int init_vq(struct kvm *kvm, void *dev, u32 vq, u32 page_size, u32 align,
-		   u32 pfn)
+static int init_vq(struct kvm *kvm, void *dev, u32 vq)
 {
 	struct vhost_vring_state state = { .index = vq };
 	struct net_dev_queue *net_queue;
@@ -598,8 +597,7 @@
 	net_queue->id	= vq;
 	net_queue->ndev	= ndev;
 	queue		= &net_queue->vq;
-	virtio_init_device_vq(kvm, &ndev->vdev, queue, VIRTIO_NET_QUEUE_SIZE,
-			      page_size, align, pfn);
+	virtio_init_device_vq(kvm, &ndev->vdev, queue, VIRTIO_NET_QUEUE_SIZE);
 
 	mutex_init(&net_queue->lock);
 	pthread_cond_init(&net_queue->cond, NULL);
diff --git a/virtio/pci.c b/virtio/pci.c
index 23831d5..20b1622 100644
--- a/virtio/pci.c
+++ b/virtio/pci.c
@@ -178,7 +178,7 @@
 		break;
 	case VIRTIO_PCI_QUEUE_PFN:
 		vq = vdev->ops->get_vq(kvm, vpci->dev, vpci->queue_selector);
-		ioport__write32(data, vq->pfn);
+		ioport__write32(data, vq->vring_addr.pfn);
 		break;
 	case VIRTIO_PCI_QUEUE_NUM:
 		val = vdev->ops->get_size_vq(kvm, vpci->dev, vpci->queue_selector);
@@ -318,6 +318,7 @@
 {
 	bool ret = true;
 	struct virtio_pci *vpci;
+	struct virt_queue *vq;
 	struct kvm *kvm;
 	u32 val;
 	unsigned int vq_count;
@@ -334,11 +335,18 @@
 	case VIRTIO_PCI_QUEUE_PFN:
 		val = ioport__read32(data);
 		if (val) {
+			vq = vdev->ops->get_vq(kvm, vpci->dev,
+					       vpci->queue_selector);
+			vq->vring_addr = (struct vring_addr) {
+				.legacy	= true,
+				.pfn	= val,
+				.align	= VIRTIO_PCI_VRING_ALIGN,
+				.pgsize	= 1 << VIRTIO_PCI_QUEUE_ADDR_SHIFT,
+			};
 			virtio_pci__init_ioeventfd(kvm, vdev,
 						   vpci->queue_selector);
-			vdev->ops->init_vq(kvm, vpci->dev, vpci->queue_selector,
-					   1 << VIRTIO_PCI_QUEUE_ADDR_SHIFT,
-					   VIRTIO_PCI_VRING_ALIGN, val);
+			vdev->ops->init_vq(kvm, vpci->dev,
+					   vpci->queue_selector);
 		} else {
 			virtio_pci_exit_vq(kvm, vdev, vpci->queue_selector);
 		}
diff --git a/virtio/rng.c b/virtio/rng.c
index 5bcd05a..840da0e 100644
--- a/virtio/rng.c
+++ b/virtio/rng.c
@@ -91,8 +91,7 @@
 	rdev->vdev.ops->signal_vq(kvm, &rdev->vdev, vq - rdev->vqs);
 }
 
-static int init_vq(struct kvm *kvm, void *dev, u32 vq, u32 page_size, u32 align,
-		   u32 pfn)
+static int init_vq(struct kvm *kvm, void *dev, u32 vq)
 {
 	struct rng_dev *rdev = dev;
 	struct virt_queue *queue;
@@ -104,8 +103,7 @@
 
 	job = &rdev->jobs[vq];
 
-	virtio_init_device_vq(kvm, &rdev->vdev, queue, VIRTIO_RNG_QUEUE_SIZE,
-			      page_size, align, pfn);
+	virtio_init_device_vq(kvm, &rdev->vdev, queue, VIRTIO_RNG_QUEUE_SIZE);
 
 	*job = (struct rng_dev_job) {
 		.vq	= queue,
diff --git a/virtio/scsi.c b/virtio/scsi.c
index 9dd9e9a..507cf3f 100644
--- a/virtio/scsi.c
+++ b/virtio/scsi.c
@@ -62,8 +62,7 @@
 {
 }
 
-static int init_vq(struct kvm *kvm, void *dev, u32 vq, u32 page_size, u32 align,
-		   u32 pfn)
+static int init_vq(struct kvm *kvm, void *dev, u32 vq)
 {
 	struct vhost_vring_state state = { .index = vq };
 	struct vhost_vring_addr addr;
@@ -75,8 +74,7 @@
 
 	queue		= &sdev->vqs[vq];
 
-	virtio_init_device_vq(kvm, &sdev->vdev, queue, VIRTIO_SCSI_QUEUE_SIZE,
-			      page_size, align, pfn);
+	virtio_init_device_vq(kvm, &sdev->vdev, queue, VIRTIO_SCSI_QUEUE_SIZE);
 
 	if (sdev->vhost_fd == 0)
 		return 0;
diff --git a/virtio/vsock.c b/virtio/vsock.c
index 79a672f..dfd6211 100644
--- a/virtio/vsock.c
+++ b/virtio/vsock.c
@@ -65,8 +65,7 @@
 	return vq == VSOCK_VQ_EVENT;
 }
 
-static int init_vq(struct kvm *kvm, void *dev, u32 vq, u32 page_size, u32 align,
-		   u32 pfn)
+static int init_vq(struct kvm *kvm, void *dev, u32 vq)
 {
 	struct vhost_vring_state state = { .index = vq };
 	struct vhost_vring_addr addr;
@@ -77,8 +76,7 @@
 	compat__remove_message(compat_id);
 
 	queue		= &vdev->vqs[vq];
-	virtio_init_device_vq(kvm, &vdev->vdev, queue, VIRTIO_VSOCK_QUEUE_SIZE,
-			      page_size, align, pfn);
+	virtio_init_device_vq(kvm, &vdev->vdev, queue, VIRTIO_VSOCK_QUEUE_SIZE);
 
 	if (vdev->vhost_fd == -1)
 		return 0;