KVM: arm64: iommu: specify domain type

Add a type argument to alloc_domain so the SMMU driver can use it to
select between stage-1 and stage-2 configuration.

This allows the kernel to choose which domain type it needs.

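For example, the host driver can now pass the wanted page-table format
to the domain allocation hypercall (illustrative sketch based on the
call added in this patch; "want_stage1" stands in for whatever policy
the driver uses to pick the format):

  int type = want_stage1 ? ARM_64_LPAE_S1 : ARM_64_LPAE_S2;

  ret = kvm_call_hyp_nvhe_mc(smmu, __pkvm_host_iommu_alloc_domain,
                             host_smmu->id, kvm_smmu_domain->id, pgd, type);
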
Signed-off-by: Mostafa Saleh <smostafa@google.com>
diff --git a/arch/arm64/kvm/hyp/include/nvhe/iommu.h b/arch/arm64/kvm/hyp/include/nvhe/iommu.h
index e913b80..4cc97a2 100644
--- a/arch/arm64/kvm/hyp/include/nvhe/iommu.h
+++ b/arch/arm64/kvm/hyp/include/nvhe/iommu.h
@@ -9,8 +9,10 @@
 #if IS_ENABLED(CONFIG_ARM_SMMU_V3_PKVM)
 #include <linux/io-pgtable-arm.h>
 
-int kvm_arm_io_pgtable_init(struct io_pgtable_cfg *cfg,
-			    struct arm_lpae_io_pgtable *data);
+int kvm_arm_io_pgtable_init_s1(struct io_pgtable_cfg *cfg,
+			       struct arm_lpae_io_pgtable *data);
+int kvm_arm_io_pgtable_init_s2(struct io_pgtable_cfg *cfg,
+			       struct arm_lpae_io_pgtable *data);
 int kvm_arm_io_pgtable_alloc(struct io_pgtable *iop, unsigned long pgd_hva);
 int kvm_arm_io_pgtable_free(struct io_pgtable *iop);
 size_t kvm_arm_io_pgtable_size(struct io_pgtable *iopt);
@@ -24,7 +26,8 @@ void kvm_iommu_reclaim_page(void *p);
 
 /* Hypercall handlers */
 int kvm_iommu_alloc_domain(pkvm_handle_t iommu_id, pkvm_handle_t domain_id,
-			   unsigned long pgd_hva, struct pkvm_hyp_vcpu *ctxt);
+			   unsigned long pgd_hva, int type,
+			   struct pkvm_hyp_vcpu *ctxt);
 int kvm_iommu_free_domain(pkvm_handle_t iommu_id, pkvm_handle_t domain_id,
 			  struct pkvm_hyp_vcpu *ctxt);
 int kvm_iommu_attach_dev(pkvm_handle_t iommu_id, pkvm_handle_t domain_id,
@@ -47,7 +50,8 @@ int kvm_iommu_force_detach_dev(pkvm_handle_t iommu_id, u32 endpoint_id, u32 pasi
 #else /* !CONFIG_KVM_IOMMU */
 static inline int kvm_iommu_alloc_domain(pkvm_handle_t iommu_id,
 					 pkvm_handle_t domain_id,
-					 unsigned long pgd_hva, struct pkvm_hyp_vcpu *ctxt)
+					 unsigned long pgd_hva, int type,
+					 struct pkvm_hyp_vcpu *ctxt)
 {
 	return -ENODEV;
 }
@@ -112,6 +116,7 @@ int kvm_iommu_force_detach_dev(pkvm_handle_t iommu_id, u32 endpoint_id)
 struct kvm_iommu_tlb_cookie {
 	struct kvm_hyp_iommu	*iommu;
 	pkvm_handle_t		domain_id;
+	struct kvm_hyp_iommu_domain *domain;
 };
 
 struct kvm_iommu_ops {
@@ -125,7 +130,7 @@ struct kvm_iommu_ops {
 			  u32 pasid_bits);
 	int (*detach_dev)(struct kvm_hyp_iommu *iommu, pkvm_handle_t domain_id,
 			  struct kvm_hyp_iommu_domain *domain, u32 endpoint_id, u32 pasid);
-	int (*alloc_domain)(pkvm_handle_t iommu_id, struct io_pgtable_params **pgtable);
+	int (*alloc_domain)(pkvm_handle_t iommu_id, struct io_pgtable_params **pgtable, int type);
 };
 
 extern struct kvm_iommu_ops *kvm_iommu_ops;
diff --git a/arch/arm64/kvm/hyp/nvhe/hyp-main.c b/arch/arm64/kvm/hyp/nvhe/hyp-main.c
index 02a083a..b5d0c5b 100644
--- a/arch/arm64/kvm/hyp/nvhe/hyp-main.c
+++ b/arch/arm64/kvm/hyp/nvhe/hyp-main.c
@@ -1191,8 +1191,9 @@ static void handle___pkvm_host_iommu_alloc_domain(struct kvm_cpu_context *host_c
 	DECLARE_REG(pkvm_handle_t, iommu, host_ctxt, 1);
 	DECLARE_REG(pkvm_handle_t, domain, host_ctxt, 2);
 	DECLARE_REG(unsigned long, pgd_hva, host_ctxt, 3);
+	DECLARE_REG(int, type, host_ctxt, 4);
 
-	cpu_reg(host_ctxt, 1) = kvm_iommu_alloc_domain(iommu, domain, pgd_hva, NULL);
+	cpu_reg(host_ctxt, 1) = kvm_iommu_alloc_domain(iommu, domain, pgd_hva, type, NULL);
 }
 
 static void handle___pkvm_host_iommu_free_domain(struct kvm_cpu_context *host_ctxt)
diff --git a/arch/arm64/kvm/hyp/nvhe/iommu/arm-smmu-v3.c b/arch/arm64/kvm/hyp/nvhe/iommu/arm-smmu-v3.c
index a039cb0..21a3a5b 100644
--- a/arch/arm64/kvm/hyp/nvhe/iommu/arm-smmu-v3.c
+++ b/arch/arm64/kvm/hyp/nvhe/iommu/arm-smmu-v3.c
@@ -449,8 +449,11 @@ static void smmu_tlb_flush_all(void *cookie)
 	struct kvm_iommu_tlb_cookie *data = cookie;
 	struct hyp_arm_smmu_v3_device *smmu = to_smmu(data->iommu);
 	struct arm_smmu_cmdq_ent cmd;
+	struct kvm_hyp_iommu_domain *domain = data->domain;
+	struct arm_lpae_io_pgtable *pgtable = container_of(domain->pgtable,
+							   struct arm_lpae_io_pgtable, iop);
 
-	if (smmu->pgtable.iop.cfg.fmt == ARM_64_LPAE_S2) {
+	if (pgtable->iop.cfg.fmt == ARM_64_LPAE_S2) {
 		cmd.opcode = CMDQ_OP_TLBI_S12_VMALL;
 		cmd.tlbi.vmid = data->domain_id;
 	} else {
@@ -477,11 +480,14 @@ static void smmu_tlb_inv_range(struct kvm_iommu_tlb_cookie *data,
 		.tlbi.vmid = data->domain_id,
 		.tlbi.leaf = leaf,
 	};
+	struct kvm_hyp_iommu_domain *domain = data->domain;
+	struct arm_lpae_io_pgtable *pgtable = container_of(domain->pgtable,
+							   struct arm_lpae_io_pgtable, iop);
 
 	if (smmu->iommu.power_is_off && smmu->caches_clean_on_power_on)
 		return;
 
-	if (smmu->pgtable.iop.cfg.fmt == ARM_64_LPAE_S1) {
+	if (pgtable->iop.cfg.fmt == ARM_64_LPAE_S1) {
 		cmd.opcode = CMDQ_OP_TLBI_NH_VA;
 		cmd.tlbi.asid = data->domain_id;
 		cmd.tlbi.vmid = 0;
@@ -616,7 +622,7 @@ static int smmu_finalise_s1(u64 *ent, struct hyp_arm_smmu_v3_device *smmu,
 	u64 *cd_entry;
 	struct io_pgtable_cfg *cfg;
 
-	cfg = &smmu->pgtable.iop.cfg;
+	cfg = &smmu->pgtable_s1.iop.cfg;
 	/* Check if we already have CD for this SID. */
 	ste = smmu_get_ste_ptr(smmu, sid);
 	val = le64_to_cpu(ste[0]);
@@ -685,7 +691,7 @@ static int smmu_finalise_s2(u64 *ent, struct hyp_arm_smmu_v3_device *smmu,
 	struct io_pgtable_cfg *cfg;
 	u64 ts, sl, ic, oc, sh, tg, ps;
 
-	cfg = &smmu->pgtable.iop.cfg;
+	cfg = &smmu->pgtable_s2.iop.cfg;
 	ps = cfg->arm_lpae_s2_cfg.vtcr.ps;
 	tg = cfg->arm_lpae_s2_cfg.vtcr.tg;
 	sh = cfg->arm_lpae_s2_cfg.vtcr.sh;
@@ -784,26 +790,43 @@ static int smmu_detach_dev(struct kvm_hyp_iommu *iommu, pkvm_handle_t domain_id,
 	return smmu_sync_ste(smmu, sid);
 }
 
-int smmu_alloc_domain(pkvm_handle_t iommu_id, struct io_pgtable_params **pgtable)
+int smmu_alloc_domain(pkvm_handle_t iommu_id, struct io_pgtable_params **pgtable, int type)
 {
 	struct kvm_hyp_iommu *iommu = smmu_id_to_iommu(iommu_id);
 	struct hyp_arm_smmu_v3_device *smmu = to_smmu(iommu);
 	int ret;
-	struct io_pgtable_cfg pgtable_cfg  = (struct io_pgtable_cfg) {
-		.fmt = ARM_64_LPAE_S2,
-		.pgsize_bitmap = smmu->pgsize_bitmap,
-		.ias = smmu->ias,
-		.oas = smmu->oas,
-		.coherent_walk = smmu->features & ARM_SMMU_FEAT_COHERENCY,
-		.tlb = &smmu_tlb_ops
-	};
+	unsigned long ias = (smmu->features & ARM_SMMU_FEAT_VAX) ? 52 : 48;
+	struct io_pgtable_cfg pgtable_cfg;
 
-	ret = kvm_arm_io_pgtable_init(&pgtable_cfg, &smmu->pgtable);
+	if (type == ARM_64_LPAE_S1) {
+		pgtable_cfg = (struct io_pgtable_cfg) {
+			      .fmt = ARM_64_LPAE_S1,
+			      .pgsize_bitmap = smmu->pgsize_bitmap,
+			      .ias = min_t(unsigned long, ias, VA_BITS),
+			      .oas = smmu->ias,
+			      .coherent_walk = smmu->features & ARM_SMMU_FEAT_COHERENCY,
+			      .tlb = &smmu_tlb_ops
+		};
+		ret = kvm_arm_io_pgtable_init_s1(&pgtable_cfg, &smmu->pgtable_s1);
+		*pgtable = &smmu->pgtable_s1.iop;
+	} else if (type == ARM_64_LPAE_S2) {
+		pgtable_cfg = (struct io_pgtable_cfg) {
+			      .fmt = ARM_64_LPAE_S2,
+			      .pgsize_bitmap = smmu->pgsize_bitmap,
+			      .ias = smmu->ias,
+			      .oas = smmu->oas,
+			      .coherent_walk = smmu->features & ARM_SMMU_FEAT_COHERENCY,
+			      .tlb = &smmu_tlb_ops
+		};
+		ret = kvm_arm_io_pgtable_init_s2(&pgtable_cfg, &smmu->pgtable_s2);
+		*pgtable = &smmu->pgtable_s2.iop;
+	} else {
+		BUG();
+	}
+
 	if (ret)
 		return ret;
 
-	*pgtable = &smmu->pgtable.iop;
-
 	return 0;
 }
 
diff --git a/arch/arm64/kvm/hyp/nvhe/iommu/io-pgtable-arm.c b/arch/arm64/kvm/hyp/nvhe/iommu/io-pgtable-arm.c
index c9b8646..410df04 100644
--- a/arch/arm64/kvm/hyp/nvhe/iommu/io-pgtable-arm.c
+++ b/arch/arm64/kvm/hyp/nvhe/iommu/io-pgtable-arm.c
@@ -43,8 +43,26 @@ void __arm_lpae_sync_pte(arm_lpae_iopte *ptep, int num_entries,
 		kvm_flush_dcache_to_poc(ptep, sizeof(*ptep) * num_entries);
 }
 
-int kvm_arm_io_pgtable_init(struct io_pgtable_cfg *cfg,
-			    struct arm_lpae_io_pgtable *data)
+int kvm_arm_io_pgtable_init_s1(struct io_pgtable_cfg *cfg,
+			       struct arm_lpae_io_pgtable *data)
+{
+	size_t pgd_size;
+	int ret = arm_64_lpae_configure_s1(cfg, &pgd_size);
+
+	if (ret)
+		return ret;
+
+	ret = arm_lpae_init_pgtable_s1(cfg, data);
+
+	if (ret)
+		return ret;
+
+	data->iop.cfg = *cfg;
+	return 0;
+}
+
+int kvm_arm_io_pgtable_init_s2(struct io_pgtable_cfg *cfg,
+			       struct arm_lpae_io_pgtable *data)
 {
 	size_t pgd_size;
 	int ret = arm_64_lpae_configure_s2(cfg, &pgd_size);
diff --git a/arch/arm64/kvm/hyp/nvhe/iommu/iommu.c b/arch/arm64/kvm/hyp/nvhe/iommu/iommu.c
index 78ebee0..ee9ac0c 100644
--- a/arch/arm64/kvm/hyp/nvhe/iommu/iommu.c
+++ b/arch/arm64/kvm/hyp/nvhe/iommu/iommu.c
@@ -25,6 +25,7 @@ struct kvm_hyp_iommu_memcache __ro_after_init *kvm_hyp_iommu_memcaches;
 		.cookie = &(struct kvm_iommu_tlb_cookie) {	\
 			.iommu		= (_iommu),		\
 			.domain_id	= (_domain_id),		\
+			.domain		= (_domain),		\
 		},						\
 	}
 
@@ -113,7 +114,7 @@ handle_to_domain(pkvm_handle_t iommu_id, pkvm_handle_t domain_id,
 }
 
 int kvm_iommu_alloc_domain_nolock(pkvm_handle_t iommu_id, pkvm_handle_t domain_id,
-				  unsigned long pgd_hva, struct pkvm_hyp_vcpu *ctxt)
+				  unsigned long pgd_hva, int type, struct pkvm_hyp_vcpu *ctxt)
 {
 	int ret = -EINVAL;
 	struct io_pgtable iopt;
@@ -128,7 +129,7 @@ int kvm_iommu_alloc_domain_nolock(pkvm_handle_t iommu_id, pkvm_handle_t domain_i
 	if (domain->refs)
 		return ret;
 
-	ret = kvm_iommu_ops->alloc_domain(iommu_id, &domain->pgtable);
+	ret = kvm_iommu_ops->alloc_domain(iommu_id, &domain->pgtable, type);
 	if (ret)
 		return ret;
 
@@ -157,13 +158,13 @@ int kvm_iommu_alloc_domain_nolock(pkvm_handle_t iommu_id, pkvm_handle_t domain_i
 }
 
 int kvm_iommu_alloc_domain(pkvm_handle_t iommu_id, pkvm_handle_t domain_id,
-			   unsigned long pgd_hva, struct pkvm_hyp_vcpu *ctxt)
+			   unsigned long pgd_hva, int type, struct pkvm_hyp_vcpu *ctxt)
 {
 	int ret;
 	struct kvm_hyp_iommu *iommu = kvm_iommu_ops->get_iommu_by_id(iommu_id);
 
 	hyp_spin_lock(&iommu->iommu_lock);
-	ret = kvm_iommu_alloc_domain_nolock(iommu_id, domain_id, pgd_hva, ctxt);
+	ret = kvm_iommu_alloc_domain_nolock(iommu_id, domain_id, pgd_hva, type, ctxt);
 	hyp_spin_unlock(&iommu->iommu_lock);
 
 	return ret;
@@ -513,7 +514,7 @@ int kvm_iommu_alloc_guest_domain(pkvm_handle_t iommu_id, struct pkvm_hyp_vcpu *c
 		goto out_unlock;
 	}
 
-	ret = kvm_iommu_alloc_domain_nolock(iommu_id, domain_id, pgd_hva, ctxt);
+	ret = kvm_iommu_alloc_domain_nolock(iommu_id, domain_id, pgd_hva, ARM_64_LPAE_S2, ctxt);
 	*ret_domain = domain_id;
 out_unlock:
 	cur_context = ctxt;
diff --git a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-kvm.c b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-kvm.c
index eb2a847..c47ed8b 100644
--- a/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-kvm.c
+++ b/drivers/iommu/arm/arm-smmu-v3/arm-smmu-v3-kvm.c
@@ -18,7 +18,8 @@ struct host_arm_smmu_device {
 	struct arm_smmu_device		smmu;
 	pkvm_handle_t			id;
 	u32				boot_gbpa;
-	unsigned int			pgd_order;
+	unsigned int			pgd_order_s1;
+	unsigned int			pgd_order_s2;
 	atomic_t			initialized;
 };
 
@@ -256,15 +257,16 @@ static int kvm_arm_smmu_domain_finalize(struct kvm_arm_smmu_domain *kvm_smmu_dom
 	 * order when concatenated.
 	 */
 	p = alloc_pages_node(dev_to_node(smmu->dev), GFP_KERNEL | __GFP_ZERO,
-			     host_smmu->pgd_order);
+			     host_smmu->pgd_order_s1);
 	if (!p)
 		return -ENOMEM;
 
 	pgd = (unsigned long)page_to_virt(p);
 
+	/* TODO: choose s1 or s2 based on master->pasid_bits? */
 	local_lock_irq(&memcache_lock);
 	ret = kvm_call_hyp_nvhe_mc(smmu, __pkvm_host_iommu_alloc_domain,
-				   host_smmu->id, kvm_smmu_domain->id, pgd);
+				   host_smmu->id, kvm_smmu_domain->id, pgd, ARM_64_LPAE_S1);
 	local_unlock_irq(&memcache_lock);
 	if (ret)
 		goto err_free;
@@ -278,7 +280,7 @@ static int kvm_arm_smmu_domain_finalize(struct kvm_arm_smmu_domain *kvm_smmu_dom
 	return 0;
 
 err_free:
-	free_pages(pgd, host_smmu->pgd_order);
+	free_pages(pgd, host_smmu->pgd_order_s1);
 	ida_free(&kvm_arm_smmu_domain_ida, kvm_smmu_domain->id);
 	return ret;
 }
@@ -299,7 +301,7 @@ static void kvm_arm_smmu_domain_free(struct iommu_domain *domain)
 		 * reclaimed by the host.
 		 */
 		if (!WARN_ON(ret))
-			free_pages(kvm_smmu_domain->pgd, host_smmu->pgd_order);
+			free_pages(kvm_smmu_domain->pgd, host_smmu->pgd_order_s1);
 		ida_free(&kvm_arm_smmu_domain_ida, kvm_smmu_domain->id);
 	}
 	kfree(kvm_smmu_domain);
@@ -454,7 +456,6 @@ static bool kvm_arm_smmu_validate_features(struct arm_smmu_device *smmu)
 {
 	unsigned long oas;
 	unsigned int required_features =
-		ARM_SMMU_FEAT_TRANS_S2 |
 		ARM_SMMU_FEAT_TT_LE;
 	unsigned int forbidden_features =
 		ARM_SMMU_FEAT_STALL_FORCE;
@@ -616,13 +617,14 @@ static int kvm_arm_smmu_probe(struct platform_device *pdev)
 	bool bypass;
 	struct resource *res;
 	phys_addr_t mmio_addr;
-	struct io_pgtable_cfg cfg;
+	struct io_pgtable_cfg cfg_s1, cfg_s2;
 	size_t mmio_size, pgd_size;
 	struct arm_smmu_device *smmu;
 	struct device *dev = &pdev->dev;
 	struct host_arm_smmu_device *host_smmu;
 	struct hyp_arm_smmu_v3_device *hyp_smmu;
 	struct kvm_power_domain power_domain = {};
+	unsigned long ias;
 
 	if (kvm_arm_smmu_cur >= kvm_arm_smmu_count)
 		return -ENOSPC;
@@ -664,11 +666,21 @@ static int kvm_arm_smmu_probe(struct platform_device *pdev)
 	if (!kvm_arm_smmu_validate_features(smmu))
 		return -ENODEV;
 
+	ias = (smmu->features & ARM_SMMU_FEAT_VAX) ? 52 : 48;
+
 	/*
-	 * Stage-1 should be easy to support, though we do need to allocate a
-	 * context descriptor table.
+	 * The SMMU holds a configuration for both stage-1 and stage-2, as
+	 * either can be chosen when a device is attached.
 	 */
-	cfg = (struct io_pgtable_cfg) {
+	cfg_s1 = (struct io_pgtable_cfg) {
+		.fmt = ARM_64_LPAE_S1,
+		.pgsize_bitmap = smmu->pgsize_bitmap,
+		.ias = min_t(unsigned long, ias, VA_BITS),
+		.oas = smmu->ias,
+		.coherent_walk = smmu->features & ARM_SMMU_FEAT_COHERENCY,
+	};
+
+	cfg_s2 = (struct io_pgtable_cfg) {
 		.fmt = ARM_64_LPAE_S2,
 		.pgsize_bitmap = smmu->pgsize_bitmap,
 		.ias = smmu->ias,
@@ -680,14 +692,17 @@ static int kvm_arm_smmu_probe(struct platform_device *pdev)
 	 * Choose the page and address size. Compute the PGD size as well, so we
 	 * know how much memory to pre-allocate.
 	 */
-	ret = io_pgtable_configure(&cfg, &pgd_size);
+	ret = io_pgtable_configure(&cfg_s1, &pgd_size);
 	if (ret)
 		return ret;
+	host_smmu->pgd_order_s1 = get_order(pgd_size);
 
-	host_smmu->pgd_order = get_order(pgd_size);
-	smmu->pgsize_bitmap = cfg.pgsize_bitmap;
-	smmu->ias = cfg.ias;
-	smmu->oas = cfg.oas;
+	ret = io_pgtable_configure(&cfg_s2, &pgd_size);
+	if (ret)
+		return ret;
+	host_smmu->pgd_order_s2 = get_order(pgd_size);
+
+	smmu->pgsize_bitmap = cfg_s1.pgsize_bitmap;
 
 	ret = arm_smmu_init_one_queue(smmu, &smmu->cmdq.q, smmu->base,
 				      ARM_SMMU_CMDQ_PROD, ARM_SMMU_CMDQ_CONS,
diff --git a/include/kvm/arm_smmu_v3.h b/include/kvm/arm_smmu_v3.h
index 5ed70ba..3fee2a8 100644
--- a/include/kvm/arm_smmu_v3.h
+++ b/include/kvm/arm_smmu_v3.h
@@ -33,7 +33,8 @@ struct hyp_arm_smmu_v3_device {
 	size_t			strtab_num_entries;
 	size_t			strtab_num_l1_entries;
 	u8			strtab_split;
-	struct arm_lpae_io_pgtable pgtable;
+	struct arm_lpae_io_pgtable pgtable_s1;
+	struct arm_lpae_io_pgtable pgtable_s2;
 	unsigned long		ias; /* IPA */
 	unsigned long		oas; /* PA */
 	unsigned long		pgsize_bitmap;