Merge branches 'v5.20/vfio/migration-enhancements-v3', 'v5.20/vfio/simplify-bus_type-determination-v3', 'v5.20/vfio/check-vfio_register_iommu_driver-return-v2', 'v5.20/vfio/check-iommu_group_set_name_return-v1', 'v5.20/vfio/clear-caps-buf-v3', 'v5.20/vfio/remove-useless-judgement-v1' and 'v5.20/vfio/move-device_open-count-v2' into v5.20/vfio/next
diff --git a/drivers/vfio/vfio.c b/drivers/vfio/vfio.c
index aac9213..bd84ca7 100644
--- a/drivers/vfio/vfio.c
+++ b/drivers/vfio/vfio.c
@@ -504,7 +504,9 @@ static struct vfio_group *vfio_noiommu_group_alloc(struct device *dev,
 	if (IS_ERR(iommu_group))
 		return ERR_CAST(iommu_group);
 
-	iommu_group_set_name(iommu_group, "vfio-noiommu");
+	ret = iommu_group_set_name(iommu_group, "vfio-noiommu");
+	if (ret)
+		goto out_put_group;
 	ret = iommu_group_add_device(iommu_group, dev);
 	if (ret)
 		goto out_put_group;
@@ -605,7 +607,7 @@ int vfio_register_group_dev(struct vfio_device *device)
 	 * VFIO always sets IOMMU_CACHE because we offer no way for userspace to
 	 * restore cache coherency.
 	 */
-	if (!iommu_capable(device->dev->bus, IOMMU_CAP_CACHE_COHERENCY))
+	if (!device_iommu_capable(device->dev, IOMMU_CAP_CACHE_COHERENCY))
 		return -EINVAL;
 
 	return __vfio_register_dev(device,
@@ -1146,10 +1148,10 @@ static struct file *vfio_device_open(struct vfio_device *device)
 	if (device->open_count == 1 && device->ops->close_device)
 		device->ops->close_device(device);
 err_undo_count:
+	up_read(&device->group->group_rwsem);
 	device->open_count--;
 	if (device->open_count == 0 && device->kvm)
 		device->kvm = NULL;
-	up_read(&device->group->group_rwsem);
 	mutex_unlock(&device->dev_set->lock);
 	module_put(device->dev->driver->owner);
 err_unassign_container:
@@ -1811,6 +1813,7 @@ struct vfio_info_cap_header *vfio_info_cap_add(struct vfio_info_cap *caps,
 	buf = krealloc(caps->buf, caps->size + size, GFP_KERNEL);
 	if (!buf) {
 		kfree(caps->buf);
+		caps->buf = NULL;
 		caps->size = 0;
 		return ERR_PTR(-ENOMEM);
 	}
@@ -2155,13 +2158,17 @@ static int __init vfio_init(void)
 	if (ret)
 		goto err_alloc_chrdev;
 
-	pr_info(DRIVER_DESC " version: " DRIVER_VERSION "\n");
-
 #ifdef CONFIG_VFIO_NOIOMMU
-	vfio_register_iommu_driver(&vfio_noiommu_ops);
+	ret = vfio_register_iommu_driver(&vfio_noiommu_ops);
 #endif
+	if (ret)
+		goto err_driver_register;
+
+	pr_info(DRIVER_DESC " version: " DRIVER_VERSION "\n");
 	return 0;
 
+err_driver_register:
+	unregister_chrdev_region(vfio.group_devt, MINORMASK + 1);
 err_alloc_chrdev:
 	class_destroy(vfio.class);
 	vfio.class = NULL;
diff --git a/drivers/vfio/vfio_iommu_type1.c b/drivers/vfio/vfio_iommu_type1.c
index c13b929..db24062 100644
--- a/drivers/vfio/vfio_iommu_type1.c
+++ b/drivers/vfio/vfio_iommu_type1.c
@@ -1377,12 +1377,6 @@ static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
 
 		if (!iommu->v2 && iova > dma->iova)
 			break;
-		/*
-		 * Task with same address space who mapped this iova range is
-		 * allowed to unmap the iova range.
-		 */
-		if (dma->task->mm != current->mm)
-			break;
 
 		if (invalidate_vaddr) {
 			if (dma->vaddr_invalid) {
@@ -1679,18 +1673,6 @@ static int vfio_dma_do_map(struct vfio_iommu *iommu,
 	return ret;
 }
 
-static int vfio_bus_type(struct device *dev, void *data)
-{
-	struct bus_type **bus = data;
-
-	if (*bus && *bus != dev->bus)
-		return -EINVAL;
-
-	*bus = dev->bus;
-
-	return 0;
-}
-
 static int vfio_iommu_replay(struct vfio_iommu *iommu,
 			     struct vfio_domain *domain)
 {
@@ -2153,13 +2135,26 @@ static void vfio_iommu_iova_insert_copy(struct vfio_iommu *iommu,
 	list_splice_tail(iova_copy, iova);
 }
 
+/* Group-walk callback: nonzero if @dev has cap @data, which stops iommu_group_for_each_dev() */
+static int vfio_iommu_device_capable(struct device *dev, void *data)
+{
+	return device_iommu_capable(dev, (enum iommu_cap)data);
+}
+
+static int vfio_iommu_domain_alloc(struct device *dev, void *data)
+{
+	struct iommu_domain **domain = data;
+
+	*domain = iommu_domain_alloc(dev->bus);
+	return 1; /* Nonzero stops the group walk: only the first device is needed */
+}
+
 static int vfio_iommu_type1_attach_group(void *iommu_data,
 		struct iommu_group *iommu_group, enum vfio_group_type type)
 {
 	struct vfio_iommu *iommu = iommu_data;
 	struct vfio_iommu_group *group;
 	struct vfio_domain *domain, *d;
-	struct bus_type *bus = NULL;
 	bool resv_msi, msi_remap;
 	phys_addr_t resv_msi_base = 0;
 	struct iommu_domain_geometry *geo;
@@ -2192,18 +2187,19 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
 		goto out_unlock;
 	}
 
-	/* Determine bus_type in order to allocate a domain */
-	ret = iommu_group_for_each_dev(iommu_group, &bus, vfio_bus_type);
-	if (ret)
-		goto out_free_group;
-
 	ret = -ENOMEM;
 	domain = kzalloc(sizeof(*domain), GFP_KERNEL);
 	if (!domain)
 		goto out_free_group;
 
+	/*
+	 * Going via the iommu_group iterator avoids races, and trivially gives
+	 * us a representative device for the IOMMU API call. We don't actually
+	 * want to iterate beyond the first device (if any).
+	 */
 	ret = -EIO;
-	domain->domain = iommu_domain_alloc(bus);
+	iommu_group_for_each_dev(iommu_group, &domain->domain,
+				 vfio_iommu_domain_alloc);
 	if (!domain->domain)
 		goto out_free_domain;
 
@@ -2258,7 +2254,8 @@ static int vfio_iommu_type1_attach_group(void *iommu_data,
 	list_add(&group->next, &domain->group_list);
 
 	msi_remap = irq_domain_check_msi_remap() ||
-		    iommu_capable(bus, IOMMU_CAP_INTR_REMAP);
+		    iommu_group_for_each_dev(iommu_group, (void *)IOMMU_CAP_INTR_REMAP,
+					     vfio_iommu_device_capable);
 
 	if (!allow_unsafe_interrupts && !msi_remap) {
 		pr_warn("%s: No interrupt remapping support.  Use the module param \"allow_unsafe_interrupts\" to enable VFIO IOMMU support on this platform\n",