ANDROID: KVM: arm64: Skip prefaulting ptes which will be modified later

Block mappings can be split as part of a page table update. When
prefaulting the new child table entries during the split, it is
pointless to install valid ptes that will later be modified by the
same walk. Only prefault the entries that lie outside the range being
updated.

At the same time, push the check for pte_is_counted into the
prefault handler, where it logically belongs.

Bug: 278749606
Bug: 308373293
Bug: 357781595
Change-Id: If4599b2860aa62d82ce8db019a8410c2d883de71
Signed-off-by: Keir Fraser <keirf@google.com>
diff --git a/arch/arm64/kvm/hyp/pgtable.c b/arch/arm64/kvm/hyp/pgtable.c
index 1d807a6..5fb1167 100644
--- a/arch/arm64/kvm/hyp/pgtable.c
+++ b/arch/arm64/kvm/hyp/pgtable.c
@@ -1011,19 +1011,29 @@ static int stage2_map_walk_table_pre(const struct kvm_pgtable_visit_ctx *ctx,
 	return 0;
 }
 
-static void stage2_map_prefault_idmap(u64 addr, u32 level, kvm_pte_t *ptep,
-				      kvm_pte_t attr)
+static void stage2_map_prefault_idmap(struct kvm_pgtable_pte_ops *pte_ops,
+				      const struct kvm_pgtable_visit_ctx *ctx,
+				      kvm_pte_t *ptep)
 {
-	u64 granule = kvm_granule_size(level);
+	kvm_pte_t block_pte = ctx->old;
+	u64 pa, granule;
 	int i;
 
-	if (!kvm_pte_valid(attr))
+	WARN_ON(pte_ops->pte_is_counted_cb(block_pte, ctx->level));
+
+	if (!kvm_pte_valid(block_pte))
 		return;
 
-	for (i = 0; i < PTRS_PER_PTE; ++i, ++ptep, addr += granule) {
-		kvm_pte_t pte = kvm_init_valid_leaf_pte(addr, attr, level);
-		/* We can write non-atomically: ptep isn't yet live. */
-		*ptep = pte;
+	pa = ALIGN_DOWN(ctx->addr, kvm_granule_size(ctx->level));
+	granule = kvm_granule_size(ctx->level + 1);
+	for (i = 0; i < PTRS_PER_PTE; ++i, ++ptep, pa += granule) {
+		kvm_pte_t pte = kvm_init_valid_leaf_pte(
+			pa, block_pte, ctx->level + 1);
+		/* Skip ptes in the range being modified by the caller. */
+		if ((pa < ctx->addr) || (pa >= ctx->end)) {
+			/* We can write non-atomically: ptep isn't yet live. */
+			*ptep = pte;
+		}
 	}
 }
 
@@ -1051,10 +1061,7 @@ static int stage2_map_walk_leaf(const struct kvm_pgtable_visit_ctx *ctx,
 		return -ENOMEM;
 
 	if (pgt->flags & KVM_PGTABLE_S2_IDMAP) {
-		WARN_ON(pte_ops->pte_is_counted_cb(ctx->old, ctx->level));
-		stage2_map_prefault_idmap(
-			ALIGN_DOWN(ctx->addr, kvm_granule_size(ctx->level)),
-			ctx->level + 1, childp, ctx->old);
+		stage2_map_prefault_idmap(pte_ops, ctx, childp);
 	}
 
 	if (!stage2_try_break_pte(ctx, data->mmu)) {