virtio: Check for overflows in QUEUE_NOTIFY and QUEUE_SEL
This patch checks for overflows in QUEUE_NOTIFY and QUEUE_SEL in
the PCI and MMIO operation handling paths. Further, the return
value type of get_vq_count is changed from int to uint since negative
doesn't carry any semantic meaning.
Reviewed-by: Alexandru Elisei <alexandru.elisei@arm.com>
Signed-off-by: Martin Radev <martin.b.radev@gmail.com>
Link: https://lore.kernel.org/r/20220509203940.754644-6-martin.b.radev@gmail.com
Signed-off-by: Will Deacon <will@kernel.org>
diff --git a/include/kvm/virtio.h b/include/kvm/virtio.h
index 3880e74..ad274ac 100644
--- a/include/kvm/virtio.h
+++ b/include/kvm/virtio.h
@@ -187,7 +187,7 @@
size_t (*get_config_size)(struct kvm *kvm, void *dev);
u32 (*get_host_features)(struct kvm *kvm, void *dev);
void (*set_guest_features)(struct kvm *kvm, void *dev, u32 features);
- int (*get_vq_count)(struct kvm *kvm, void *dev);
+ unsigned int (*get_vq_count)(struct kvm *kvm, void *dev);
int (*init_vq)(struct kvm *kvm, void *dev, u32 vq, u32 page_size,
u32 align, u32 pfn);
void (*exit_vq)(struct kvm *kvm, void *dev, u32 vq);
diff --git a/virtio/9p.c b/virtio/9p.c
index 57cd6d0..7c9d792 100644
--- a/virtio/9p.c
+++ b/virtio/9p.c
@@ -1469,7 +1469,7 @@
return size;
}
-static int get_vq_count(struct kvm *kvm, void *dev)
+static unsigned int get_vq_count(struct kvm *kvm, void *dev)
{
return NUM_VIRT_QUEUES;
}
diff --git a/virtio/balloon.c b/virtio/balloon.c
index 655a661..f398ce4 100644
--- a/virtio/balloon.c
+++ b/virtio/balloon.c
@@ -256,7 +256,7 @@
return size;
}
-static int get_vq_count(struct kvm *kvm, void *dev)
+static unsigned int get_vq_count(struct kvm *kvm, void *dev)
{
return NUM_VIRT_QUEUES;
}
diff --git a/virtio/blk.c b/virtio/blk.c
index af71c0c..46ee028 100644
--- a/virtio/blk.c
+++ b/virtio/blk.c
@@ -291,7 +291,7 @@
return size;
}
-static int get_vq_count(struct kvm *kvm, void *dev)
+static unsigned int get_vq_count(struct kvm *kvm, void *dev)
{
return NUM_VIRT_QUEUES;
}
diff --git a/virtio/console.c b/virtio/console.c
index dae6034..8315808 100644
--- a/virtio/console.c
+++ b/virtio/console.c
@@ -216,7 +216,7 @@
return size;
}
-static int get_vq_count(struct kvm *kvm, void *dev)
+static unsigned int get_vq_count(struct kvm *kvm, void *dev)
{
return VIRTIO_CONSOLE_NUM_QUEUES;
}
diff --git a/virtio/mmio.c b/virtio/mmio.c
index 53519c1..d08da1e 100644
--- a/virtio/mmio.c
+++ b/virtio/mmio.c
@@ -169,13 +169,22 @@
{
struct virtio_mmio *vmmio = vdev->virtio;
struct kvm *kvm = vmmio->kvm;
+ unsigned int vq_count = vdev->ops->get_vq_count(kvm, vmmio->dev);
u32 val = 0;
switch (addr) {
case VIRTIO_MMIO_HOST_FEATURES_SEL:
case VIRTIO_MMIO_GUEST_FEATURES_SEL:
+ val = ioport__read32(data);
+ *(u32 *)(((void *)&vmmio->hdr) + addr) = val;
+ break;
case VIRTIO_MMIO_QUEUE_SEL:
val = ioport__read32(data);
+ if (val >= vq_count) {
+ WARN_ONCE(1, "QUEUE_SEL value (%u) is larger than VQ count (%u)\n",
+ val, vq_count);
+ break;
+ }
*(u32 *)(((void *)&vmmio->hdr) + addr) = val;
break;
case VIRTIO_MMIO_STATUS:
@@ -221,6 +230,11 @@
break;
case VIRTIO_MMIO_QUEUE_NOTIFY:
val = ioport__read32(data);
+ if (val >= vq_count) {
+ WARN_ONCE(1, "QUEUE_NOTIFY value (%u) is larger than VQ count (%u)\n",
+ val, vq_count);
+ break;
+ }
vdev->ops->notify_vq(vmmio->kvm, vmmio->dev, val);
break;
case VIRTIO_MMIO_INTERRUPT_ACK:
@@ -340,7 +354,7 @@
int virtio_mmio_reset(struct kvm *kvm, struct virtio_device *vdev)
{
- int vq;
+ unsigned int vq;
struct virtio_mmio *vmmio = vdev->virtio;
for (vq = 0; vq < vdev->ops->get_vq_count(kvm, vmmio->dev); vq++)
diff --git a/virtio/net.c b/virtio/net.c
index ec5dc1f..67070d6 100644
--- a/virtio/net.c
+++ b/virtio/net.c
@@ -755,7 +755,7 @@
return size;
}
-static int get_vq_count(struct kvm *kvm, void *dev)
+static unsigned int get_vq_count(struct kvm *kvm, void *dev)
{
struct net_dev *ndev = dev;
diff --git a/virtio/pci.c b/virtio/pci.c
index 050cfea..23831d5 100644
--- a/virtio/pci.c
+++ b/virtio/pci.c
@@ -320,9 +320,11 @@
struct virtio_pci *vpci;
struct kvm *kvm;
u32 val;
+ unsigned int vq_count;
kvm = vcpu->kvm;
vpci = vdev->virtio;
+ vq_count = vdev->ops->get_vq_count(kvm, vpci->dev);
switch (offset) {
case VIRTIO_PCI_GUEST_FEATURES:
@@ -342,10 +344,21 @@
}
break;
case VIRTIO_PCI_QUEUE_SEL:
- vpci->queue_selector = ioport__read16(data);
+ val = ioport__read16(data);
+ if (val >= vq_count) {
+ WARN_ONCE(1, "QUEUE_SEL value (%u) is larger than VQ count (%u)\n",
+ val, vq_count);
+ return false;
+ }
+ vpci->queue_selector = val;
break;
case VIRTIO_PCI_QUEUE_NOTIFY:
val = ioport__read16(data);
+ if (val >= vq_count) {
+ WARN_ONCE(1, "QUEUE_SEL value (%u) is larger than VQ count (%u)\n",
+ val, vq_count);
+ return false;
+ }
vdev->ops->notify_vq(kvm, vpci->dev, val);
break;
case VIRTIO_PCI_STATUS:
@@ -638,7 +651,7 @@
int virtio_pci__reset(struct kvm *kvm, struct virtio_device *vdev)
{
- int vq;
+ unsigned int vq;
struct virtio_pci *vpci = vdev->virtio;
for (vq = 0; vq < vdev->ops->get_vq_count(kvm, vpci->dev); vq++)
diff --git a/virtio/rng.c b/virtio/rng.c
index c7835a0..75b682e 100644
--- a/virtio/rng.c
+++ b/virtio/rng.c
@@ -147,7 +147,7 @@
return size;
}
-static int get_vq_count(struct kvm *kvm, void *dev)
+static unsigned int get_vq_count(struct kvm *kvm, void *dev)
{
return NUM_VIRT_QUEUES;
}
diff --git a/virtio/scsi.c b/virtio/scsi.c
index 8f1c348..60432cc 100644
--- a/virtio/scsi.c
+++ b/virtio/scsi.c
@@ -176,7 +176,7 @@
return size;
}
-static int get_vq_count(struct kvm *kvm, void *dev)
+static unsigned int get_vq_count(struct kvm *kvm, void *dev)
{
return NUM_VIRT_QUEUES;
}
diff --git a/virtio/vsock.c b/virtio/vsock.c
index 34397b6..64b4e95 100644
--- a/virtio/vsock.c
+++ b/virtio/vsock.c
@@ -204,7 +204,7 @@
die_perror("VHOST_SET_VRING_CALL failed");
}
-static int get_vq_count(struct kvm *kvm, void *dev)
+static unsigned int get_vq_count(struct kvm *kvm, void *dev)
{
return VSOCK_VQ_MAX;
}