KVM: arm64: iommu: Donate IDMAPPED domain memory through a memcache

The memory backing the IDMAPPED domain page tables used to be carved
out of CMA as a single physically contiguous buffer, donated to the
hypervisor wholesale and wrapped in a hyp_pool.

Instead, have the SMMUv3 driver size and top up a kvm_hyp_memcache and
hand it to kvm_iommu_init_hyp(), which passes the memcache head and
size through the __pkvm_iommu_init hypercall. The hypervisor then
admits the donated pages one by one into the idmap pool, so the memory
no longer needs to be physically contiguous, and the idmapped_pg_size()
driver callback can go away since the driver now computes the size
itself.
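
On the kernel side the probe path becomes, roughly (a sketch distilled
from the hunks below):

	struct kvm_hyp_memcache mc = {0, 0};

	ret = alloc_idmapped_mc(&mc);	/* size and fill the memcache */
	/* Top up the hyp allocator so the driver can allocate domains. */
	__pkvm_topup_hyp_alloc(1);
	ret = kvm_iommu_init_hyp(ksym_ref_addr_nvhe(smmu_ops), &mc, 0);
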
Change-Id: I2d996b2c069efb451b0530cfb2f221856d40c980
Signed-off-by: Mostafa Saleh <smostafa@google.com>
---
diff --git a/arch/arm64/include/asm/kvm_host.h b/arch/arm64/include/asm/kvm_host.h
index 54d1381..b2550e8a 100644
--- a/arch/arm64/include/asm/kvm_host.h
+++ b/arch/arm64/include/asm/kvm_host.h
@@ -1211,6 +1211,7 @@ struct kvm_iommu_ops;
 struct kvm_hyp_iommu_memcache;
 
 int kvm_iommu_init_hyp(struct kvm_iommu_ops *hyp_ops,
+		       struct kvm_hyp_memcache *mc,
 		       unsigned long init_arg);
 
 int kvm_iommu_register_driver(struct kvm_iommu_driver *kern_ops);
diff --git a/arch/arm64/kvm/hyp/include/nvhe/iommu.h b/arch/arm64/kvm/hyp/include/nvhe/iommu.h
index 5ae41a2..2489bf3 100644
--- a/arch/arm64/kvm/hyp/include/nvhe/iommu.h
+++ b/arch/arm64/kvm/hyp/include/nvhe/iommu.h
@@ -28,6 +28,7 @@ size_t kvm_arm_io_pgtable_size(struct io_pgtable *iopt);
 #endif /* CONFIG_ARM_SMMU_V3_PKVM */
 
 int kvm_iommu_init(struct kvm_iommu_ops *ops,
+		   struct kvm_hyp_memcache *mc,
 		   unsigned long init_arg);
 int kvm_iommu_init_device(struct kvm_hyp_iommu *iommu);
 void *kvm_iommu_donate_pages(u8 order, bool fill_req);
diff --git a/arch/arm64/kvm/hyp/nvhe/hyp-main.c b/arch/arm64/kvm/hyp/nvhe/hyp-main.c
index ffa5454..cc6cbff 100644
--- a/arch/arm64/kvm/hyp/nvhe/hyp-main.c
+++ b/arch/arm64/kvm/hyp/nvhe/hyp-main.c
@@ -1227,9 +1227,13 @@ static void handle___pkvm_hyp_alloc_mgt_reclaim(struct kvm_cpu_context *host_ctx
 static void handle___pkvm_iommu_init(struct kvm_cpu_context *host_ctxt)
 {
 	DECLARE_REG(struct kvm_iommu_ops *, ops, host_ctxt, 1);
-	DECLARE_REG(unsigned long, init_arg, host_ctxt, 3);
+	DECLARE_REG(unsigned long, mc_head, host_ctxt, 2);
+	DECLARE_REG(unsigned long, nr_pages, host_ctxt, 3);
+	DECLARE_REG(unsigned long, init_arg, host_ctxt, 4);
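+	/* Reassemble the host's memcache from the scalar arguments. */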
+	struct kvm_hyp_memcache mc = {.head = mc_head, .nr_pages = nr_pages};
 
-	cpu_reg(host_ctxt, 1) = kvm_iommu_init(ops, init_arg);
+	cpu_reg(host_ctxt, 1) = kvm_iommu_init(ops, &mc, init_arg);
 }
 
 static void handle___pkvm_devices_init(struct kvm_cpu_context *host_ctxt)
diff --git a/arch/arm64/kvm/hyp/nvhe/iommu/iommu.c b/arch/arm64/kvm/hyp/nvhe/iommu/iommu.c
index c4b3e3b..2c9c8e9 100644
--- a/arch/arm64/kvm/hyp/nvhe/iommu/iommu.c
+++ b/arch/arm64/kvm/hyp/nvhe/iommu/iommu.c
@@ -25,8 +25,6 @@ enum {
 static atomic_t kvm_iommu_initialized;
 
 void **kvm_hyp_iommu_domains;
-phys_addr_t iommu_idmap_mem;
-size_t iommu_idmap_mem_size;
 static struct hyp_pool iommu_idmap_pool;
 static struct hyp_pool iommu_host_pool;
 static int snapshot_host_stage2(void);
@@ -863,11 +861,13 @@ static int snapshot_host_stage2(void)
 	return kvm_pgtable_walk(pgt, 0, BIT(pgt->ia_bits), &walker);
 }
 
-int kvm_iommu_init(struct kvm_iommu_ops *ops, unsigned long init_arg)
+int kvm_iommu_init(struct kvm_iommu_ops *ops, struct kvm_hyp_memcache *mc,
+		   unsigned long init_arg)
 {
 	int ret;
 	void *idmap_pgd;
-	size_t idmap_pgd_sz, idmap_nr_pages;
+	size_t idmap_pgd_sz;
+	void *p;
 
 	BUILD_BUG_ON(sizeof(hyp_spinlock_t) != HYP_SPINLOCK_SIZE);
 
@@ -893,17 +893,28 @@ int kvm_iommu_init(struct kvm_iommu_ops *ops, unsigned long init_arg)
 
 	ret = hyp_pool_init(&iommu_host_pool, 0, 16 /* order = 4*/, 0, true);
 	/* Init IDMAPPED page tables. */
-	if (iommu_idmap_mem) {
-		idmap_nr_pages = PAGE_ALIGN(iommu_idmap_mem_size) >> PAGE_SHIFT;
-
-		ret = __pkvm_host_donate_hyp(iommu_idmap_mem >> PAGE_SHIFT, idmap_nr_pages);
+	if (mc->head) {
+		u8 order;
+
+		ret = hyp_pool_init(&iommu_idmap_pool, 0,
+				    16 /* order = 4 */, 0, true);
 		if (ret)
 			return ret;
 
-		ret = hyp_pool_init(&iommu_idmap_pool, iommu_idmap_mem >> PAGE_SHIFT,
-				    idmap_nr_pages, 0, false);
-		if (ret)
-			return ret;
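+		/*
+		 * Each memcache entry encodes its buddy order in the low
+		 * bits of the link (pages are page aligned): admit the page,
+		 * restore its order and release it into the idmap pool.
+		 */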
+		while (mc->nr_pages) {
+			order = mc->head & (PAGE_SIZE - 1);
+			p = pkvm_admit_host_page(mc, order);
+			if (!p)
+				return -EINVAL;
+			hyp_set_page_refcounted(hyp_virt_to_page(p));
+			hyp_virt_to_page(p)->order = order;
+			hyp_put_page(&iommu_idmap_pool, p);
+		}
 
 		idmap_pgd_sz = kvm_iommu_ops->pgd_size(DOMAIN_IDMAPPED_TYPE);
 		idmap_pgd = hyp_alloc_pages(&iommu_idmap_pool, get_order(idmap_pgd_sz));
@@ -914,9 +925,6 @@ int kvm_iommu_init(struct kvm_iommu_ops *ops, unsigned long init_arg)
 		kvm_hyp_iommu_domains[0] = hyp_alloc_pages(&iommu_idmap_pool, 0);
 		ret = kvm_iommu_alloc_domain_nolock(KVM_IOMMU_IDMAPPED_DOMAIN, (u64)idmap_pgd,
 						    idmap_pgd_sz, DOMAIN_IDMAPPED_TYPE);
-		if (ret)
-			return ret;
-		iommu_idmap_mem = (u64)hyp_phys_to_virt(iommu_idmap_mem);
 	}
 
 	return ret;
diff --git a/arch/arm64/kvm/iommu.c b/arch/arm64/kvm/iommu.c
index d8c1bd1..ec72e19 100644
--- a/arch/arm64/kvm/iommu.c
+++ b/arch/arm64/kvm/iommu.c
@@ -24,13 +24,15 @@ int kvm_iommu_register_driver(struct kvm_iommu_driver *kern_ops)
 EXPORT_SYMBOL(kvm_iommu_register_driver);
 
 int kvm_iommu_init_hyp(struct kvm_iommu_ops *hyp_ops,
+		       struct kvm_hyp_memcache *mc,
 		       unsigned long init_arg)
 {
 	int ret = 0;
 
 	BUG_ON(!hyp_ops);
 
-	ret = kvm_call_hyp_nvhe(__pkvm_iommu_init, hyp_ops, init_arg);
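+	/* Pass the memcache by value: head and nr_pages as hypercall args. */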
+	ret = kvm_call_hyp_nvhe(__pkvm_iommu_init, hyp_ops, mc->head, mc->nr_pages, init_arg);
 	if (ret)
 		return ret;
 
@@ -40,8 +42,6 @@ EXPORT_SYMBOL(kvm_iommu_init_hyp);
 
 int __init kvm_iommu_init_driver(void)
 {
-	struct page *pg;
-
 	if (WARN_ON(!iommu_driver))
 		return -ENODEV;
 	/*
@@ -58,21 +58,6 @@ int __init kvm_iommu_init_driver(void)
 		return -ENOMEM;
 	}
 
-	/* Drivers may not support idmapped domains for memory optimization. */
-	if (iommu_driver->idmapped_pg_size) {
-		/* Extra page size for domain metadata. */
-		iommu_idmap_mem_size = iommu_driver->idmapped_pg_size() + PAGE_SIZE;
-		pg = dma_alloc_from_contiguous(NULL, iommu_idmap_mem_size >> PAGE_SHIFT,
-					       get_order(PAGE_SIZE), false);
-		if (!pg)
-			kvm_err("Couldn't allocate 0x%lx for IOMMU idmap\n",
-				iommu_idmap_mem_size);
-		else
-			iommu_idmap_mem = page_to_phys(pg);
-	}
-	/* Topup hyp alloc so IOMMU driver can allocate domains. */
-	__pkvm_topup_hyp_alloc(1);
-
 	return iommu_driver->init_driver();
 }
 
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-kvm.c b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-kvm.c
index b8c06b3..fe0ae21 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-kvm.c
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-kvm.c
@@ -760,6 +760,30 @@ int smmu_finalise_device(struct device *dev, void *data)
 	return arm_smmu_register_iommu(smmu, &kvm_arm_smmu_ops, mmio_addr);
 }
 
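+/*
+ * Size a memcache large enough for the hypervisor to identity-map all of
+ * memory: page-table pages covering every memblock range plus a generous
+ * MMIO allowance, and one higher-order allocation for the PGD.
+ */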
+static int alloc_idmapped_mc(struct kvm_hyp_memcache *mc)
+{
+	u64 i, total = 0;
+	phys_addr_t start, end;
+	int ret;
+
+	for_each_mem_range(i, &start, &end)
+		total += __hyp_pgtable_max_pages((end - start) >> PAGE_SHIFT);
+	/* We don't know how much we'll need for MMIO; 1GB is very generous. */
+	total += __hyp_pgtable_max_pages(SZ_1G >> PAGE_SHIFT);
+
+	/* For the PGD. */
+	ret = topup_hyp_memcache(mc, 1, 0, 3);
+	if (ret)
+		return ret;
+
+	return topup_hyp_memcache(mc, total, 0, 0);
+}
+
 static int kvm_arm_smmu_probe(struct platform_device *pdev)
 {
 	int ret, i;
@@ -774,6 +798,7 @@ static int kvm_arm_smmu_probe(struct platform_device *pdev)
 	struct hyp_arm_smmu_v3_device *hyp_smmu;
 	struct kvm_power_domain power_domain = {};
 	unsigned long ias;
+	struct kvm_hyp_memcache mc = {0, 0};
 
 	if (kvm_arm_smmu_cur >= kvm_arm_smmu_count)
 		return -ENOSPC;
@@ -915,9 +940,15 @@ static int kvm_arm_smmu_probe(struct platform_device *pdev)
 	pm_runtime_resume_and_get(dev);
 
 	if (kvm_arm_smmu_cur == kvm_arm_smmu_count) {
-		/* Go go go. */
-		ret = kvm_iommu_init_hyp(ksym_ref_addr_nvhe(smmu_ops), 0);
+		ret = alloc_idmapped_mc(&mc);
+		if (ret)
+			pr_warn("No SMMUv3 IDMAPPED domain support, memcache alloc failed: %d\n", ret);
 
+		/* Top up the hyp allocator so the IOMMU driver can allocate domains. */
+		__pkvm_topup_hyp_alloc(1);
+
+		/* Go go go. */
+		ret = kvm_iommu_init_hyp(ksym_ref_addr_nvhe(smmu_ops), &mc, 0);
 
 		for (i = 0 ; i < kvm_arm_smmu_cur; ++i)
 			smmu_finalise_device(smmus_arr[i], NULL);
@@ -1062,26 +1093,11 @@ static void kvm_arm_smmu_v3_remove(void)
 	platform_driver_unregister(&kvm_arm_smmu_driver);
 }
 
-size_t kvm_arm_smmu_v3_pg_size(void)
-{
-	u64 i, total = 0;
-	phys_addr_t start, end;
-
-	for_each_mem_range(i, &start, &end) {
-		total += __hyp_pgtable_max_pages((end - start) >> PAGE_SHIFT);
-	}
-	/* We don't know how much for MMIO we need, 1GB is very generous. */
-	total += __hyp_pgtable_max_pages(SZ_1G >> PAGE_SHIFT);
-
-	return total << PAGE_SHIFT;
-}
-
 struct kvm_iommu_driver kvm_smmu_v3_ops = {
 	.init_driver = kvm_arm_smmu_v3_init,
 	.remove_driver = kvm_arm_smmu_v3_remove,
 	.get_iommu_id = kvm_arm_smmu_v3_id,
 	.get_iommu_id_by_of = kvm_arm_v3_id_by_of,
-	.idmapped_pg_size = kvm_arm_smmu_v3_pg_size,
 };
 
 static int kvm_arm_smmu_v3_register(void)