ANDROID: drivers/arm-smmu-v3-kvm: Save last bad index during coalescing

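Avoid repeatedly rescanning last-level page-table pages from the start
when checking whether they can be coalesced into a block mapping. Instead,
record the index of the last "bad" (empty or MMIO) entry found in a table
in bits [7:2] of the table descriptor that points to it, and start the
next scan from there. The field is only 6 bits wide, so the index is
stored at reduced precision (rounded down).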

Signed-off-by: Will Deacon <willdeacon@google.com>
diff --git a/drivers/iommu/arm/arm-smmu-v3/pkvm/io-pgtable-arm.c b/drivers/iommu/arm/arm-smmu-v3/pkvm/io-pgtable-arm.c
index 6a23bf7..5f49bb8 100644
--- a/drivers/iommu/arm/arm-smmu-v3/pkvm/io-pgtable-arm.c
+++ b/drivers/iommu/arm/arm-smmu-v3/pkvm/io-pgtable-arm.c
@@ -166,32 +166,91 @@ static bool arm_lpae_iopte_is_mmio(struct arm_lpae_io_pgtable *data,
 	return (pte & (0xf << 2)) == ARM_LPAE_PTE_MEMATTR_DEV;
 }
 
+/*
+ * Bits [7:2] of a table descriptor that points at a last-level table hold a
+ * hint used by arm_lpae_scan_last_level(): the index of the last entry
+ * found to prevent the table from being coalesced. The field is narrower
+ * than a full index, so only the top ARM_LPAE_TABLE_LAST_IDX_BITS bits of
+ * the index are kept and the recovered value is rounded down. That is
+ * harmless: the hint only picks where a scan starts, and every scan still
+ * covers the whole table.
+ */
+#define ARM_LPAE_TABLE_LAST_IDX_BITS	6
+#define ARM_LPAE_TABLE_LAST_IDX		GENMASK(ARM_LPAE_TABLE_LAST_IDX_BITS + 1, 2)
+
+/* Return the scan-start hint stored in @table, rounded down. */
+static u32 arm_lpae_table_get_last_idx(struct arm_lpae_io_pgtable *data,
+				       arm_lpae_iopte table)
+{
+	u32 val = FIELD_GET(ARM_LPAE_TABLE_LAST_IDX, table);
+
+	return val << (data->bits_per_level - ARM_LPAE_TABLE_LAST_IDX_BITS);
+}
+
+/*
+ * Store @idx as the scan-start hint in the table descriptor at @tablep. The
+ * descriptor may be live, so it is rewritten with a single store, and only
+ * when the encoded hint actually changes.
+ */
+static void arm_lpae_table_set_last_idx(struct arm_lpae_io_pgtable *data,
+					arm_lpae_iopte *tablep, u32 idx)
+{
+	u32 val = idx >> (data->bits_per_level - ARM_LPAE_TABLE_LAST_IDX_BITS);
+	arm_lpae_iopte old = *tablep;
+	arm_lpae_iopte pte = u64_replace_bits(old, val, ARM_LPAE_TABLE_LAST_IDX);
+
+	if (pte != old)
+		WRITE_ONCE(*tablep, pte);
+}
+
+/*
+ * Return true if every entry of the last-level table referenced by *@tablep
+ * outside the range [@iova, @iova + @size) is non-empty and not MMIO. The
+ * scan starts at the hint stored in *@tablep and wraps around the table; on
+ * failure the hint is updated to the offending entry, so the next scan
+ * starts there.
+ */
 static bool arm_lpae_scan_last_level(struct arm_lpae_io_pgtable *data,
 				     unsigned long iova, size_t size,
 				     arm_lpae_iopte *tablep)
 {
-	u32 idx, nentries, map_idx_start, map_idx_end;
+	u32 n, idx, nentries, map_idx_start, map_idx_end;
 	arm_lpae_iopte table = *tablep, *cptep = iopte_deref(table, data);
 
 	nentries = ARM_LPAE_PTES_PER_TABLE(data);
+	idx = arm_lpae_table_get_last_idx(data, table);
 	map_idx_start = ARM_LPAE_LVL_IDX(iova, ARM_LPAE_MAX_LEVELS - 1, data);
 	map_idx_end = min_t(u32,
 			    map_idx_start + (size / ARM_LPAE_GRANULE(data)),
 			    nentries) - 1;
 
-	for (idx = 0; idx < nentries; ++idx) {
+	/* @n counts entries visited; @idx walks the table circularly. */
+	for (n = 0; n < nentries; ++n) {
 		arm_lpae_iopte pte = cptep[idx];
 
 		if (idx >= map_idx_start && idx <= map_idx_end) {
-			idx = map_idx_end;
+			/*
+			 * Entries in the range being mapped need no check.
+			 * Credit the whole range as visited even if the scan
+			 * started part-way into it: entries of the range before
+			 * the starting index never need visiting either.
+			 */
+			n += map_idx_end - map_idx_start;
+			idx = (map_idx_end + 1) % nentries;
 			continue;
 		}
 
 		if (!pte || arm_lpae_iopte_is_mmio(data, pte))
 			break;
+
+		idx = (idx + 1) % nentries;
 	}
 
-	return idx == nentries;
+	/* On failure @idx is the offending entry: start there next time. */
+	if (n != nentries)
+		arm_lpae_table_set_last_idx(data, tablep, idx);
+
+	return n == nentries;
 }
 
 bool arm_lpae_use_block_mapping(struct arm_lpae_io_pgtable *data,