KVM: arm64: Avoid calling PSCI MEM_PROTECT on every guest page transition

Calling into EL3 every time a page transitions into or out of a guest
is expensive, not just because of the firmware call itself, but also
because of the global lock protecting it.

Instead, keep the PSCI MEM_PROTECT count elevated for the lifetime of
the VM, tracking the donated and reclaimed page counts in memory and
only adjusting the firmware view during teardown, so that the count
stays raised if the host has not reclaimed all of the guest's memory.
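
As a rough sketch of the resulting accounting (simplified; locking,
error paths and the protected-VM checks are omitted, and only the
helpers touched by this patch are shown):

  __pkvm_init_vm():              psci_mem_protect_inc(1);
  donation to guest:             vm->guest_owned_pages += nr_pages;
  guest share/relinquish:        vm->guest_owned_pages -= nr_pages;
  host reclaim of an owned page: vm->reclaimed_pages++;
  __pkvm_start_teardown_vm():    psci_mem_protect_inc(vm->guest_owned_pages);
  __pkvm_finalize_teardown_vm(): WARN_ON(vm->guest_owned_pages !=
                                         vm->reclaimed_pages);
                                 psci_mem_protect_dec(1 + vm->reclaimed_pages);

so the firmware count only drops back to its previous value once the
host has reclaimed everything that was donated to the guest.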

Signed-off-by: Will Deacon <will@kernel.org>
diff --git a/arch/arm64/kvm/hyp/include/nvhe/pkvm.h b/arch/arm64/kvm/hyp/include/nvhe/pkvm.h
index 263d7ac..991b573 100644
--- a/arch/arm64/kvm/hyp/include/nvhe/pkvm.h
+++ b/arch/arm64/kvm/hyp/include/nvhe/pkvm.h
@@ -55,6 +55,10 @@ struct pkvm_hyp_vm {
 	struct hyp_pool pool;
 	hyp_spinlock_t pgtable_lock;
 
+	/* Donated/reclaimed page counters for PSCI MEM_PROTECT */
+	u64 guest_owned_pages;
+	u64 reclaimed_pages;
+
 	/* Primary vCPU pending entry to the pvmfw */
 	struct pkvm_hyp_vcpu *pvmfw_entry_vcpu;
 
diff --git a/arch/arm64/kvm/hyp/nvhe/mem_protect.c b/arch/arm64/kvm/hyp/nvhe/mem_protect.c
index 5922584..4e66429 100644
--- a/arch/arm64/kvm/hyp/nvhe/mem_protect.c
+++ b/arch/arm64/kvm/hyp/nvhe/mem_protect.c
@@ -312,10 +312,8 @@ static int relinquish_walker(const struct kvm_pgtable_visit_ctx *ctx,
 		return -EPERM;
 
 	phys = kvm_pte_to_phys(pte);
-	if (state == PKVM_PAGE_OWNED) {
+	if (state == PKVM_PAGE_OWNED)
 		hyp_poison_page(phys);
-		psci_mem_protect_dec(1);
-	}
 
 	data->pa = phys;
 
@@ -350,6 +348,7 @@ int __pkvm_guest_relinquish_to_host(struct pkvm_hyp_vcpu *vcpu,
 
 	/* Zap the guest stage2 pte and return ownership to the host */
 	if (!ret && data.pa) {
+		vm->guest_owned_pages--;
 		WARN_ON(host_stage2_set_owner_locked(data.pa, PAGE_SIZE, PKVM_ID_HOST));
 		WARN_ON(kvm_pgtable_stage2_unmap(&vm->pgt, ipa, PAGE_SIZE));
 	}
@@ -891,8 +890,12 @@ static int host_complete_share(u64 addr, const struct pkvm_mem_transition *tx,
 	if (err)
 		return err;
 
-	if (tx->initiator.id == PKVM_ID_GUEST)
-		psci_mem_protect_dec(tx->nr_pages);
+	if (tx->initiator.id == PKVM_ID_GUEST) {
+		struct pkvm_hyp_vcpu *vcpu = tx->initiator.guest.hyp_vcpu;
+		struct pkvm_hyp_vm *vm = pkvm_hyp_vcpu_to_hyp_vm(vcpu);
+
+		vm->guest_owned_pages -= tx->nr_pages;
+	}
 
 	return 0;
 }
@@ -902,8 +905,12 @@ static int host_complete_unshare(u64 addr, const struct pkvm_mem_transition *tx)
 	u8 owner_id = tx->initiator.id;
 	u64 size = tx->nr_pages * PAGE_SIZE;
 
-	if (tx->initiator.id == PKVM_ID_GUEST)
-		psci_mem_protect_inc(tx->nr_pages);
+	if (tx->initiator.id == PKVM_ID_GUEST) {
+		struct pkvm_hyp_vcpu *vcpu = tx->initiator.guest.hyp_vcpu;
+		struct pkvm_hyp_vm *vm = pkvm_hyp_vcpu_to_hyp_vm(vcpu);
+
+		vm->guest_owned_pages += tx->nr_pages;
+	}
 
 	return host_stage2_set_owner_locked(addr, size, owner_id);
 }
@@ -1109,7 +1116,7 @@ static int guest_complete_donation(u64 addr, const struct pkvm_mem_transition *t
 	int err;
 
 	if (tx->initiator.id == PKVM_ID_HOST)
-		psci_mem_protect_inc(tx->nr_pages);
+		vm->guest_owned_pages += tx->nr_pages;
 
 	if (pkvm_ipa_range_has_pvmfw(vm, addr, addr + size)) {
 		if (WARN_ON(!pkvm_hyp_vcpu_is_protected(vcpu))) {
@@ -1132,7 +1139,7 @@ static int guest_complete_donation(u64 addr, const struct pkvm_mem_transition *t
 
 err_undo_psci:
 	if (tx->initiator.id == PKVM_ID_HOST)
-		psci_mem_protect_dec(tx->nr_pages);
+		vm->guest_owned_pages -= tx->nr_pages;
 	return err;
 }
 
@@ -2230,7 +2237,7 @@ int __pkvm_host_reclaim_page(struct pkvm_hyp_vm *vm, u64 pfn, u64 ipa)
 	case PKVM_PAGE_OWNED:
 		WARN_ON(__host_check_page_state_range(phys, PAGE_SIZE, PKVM_NOPAGE));
 		hyp_poison_page(phys);
-		psci_mem_protect_dec(1);
+		vm->reclaimed_pages++;
 		break;
 	case PKVM_PAGE_SHARED_BORROWED:
 		WARN_ON(__host_check_page_state_range(phys, PAGE_SIZE, PKVM_PAGE_SHARED_OWNED));
diff --git a/arch/arm64/kvm/hyp/nvhe/pkvm.c b/arch/arm64/kvm/hyp/nvhe/pkvm.c
index e512383..e496146 100644
--- a/arch/arm64/kvm/hyp/nvhe/pkvm.c
+++ b/arch/arm64/kvm/hyp/nvhe/pkvm.c
@@ -890,6 +890,9 @@ int __pkvm_init_vm(struct kvm *host_kvm, unsigned long pgd_hva)
 	ret = kvm_guest_prepare_stage2(hyp_vm, pgd);
 	if (ret)
 		goto err_remove_vm_table_entry;
+
+	if (kvm_vm_is_protected(&hyp_vm->kvm))
+		psci_mem_protect_inc(1);
 	hyp_write_unlock(&vm_table_lock);
 
 	return hyp_vm->kvm.arch.pkvm.handle;
@@ -981,6 +984,12 @@ int __pkvm_start_teardown_vm(pkvm_handle_t handle)
 
 	hyp_vm->is_dying = true;
 
+	/*
+	 * Update the firmware view of the PSCI MEM_PROTECT count so that
+	 * premature teardown of the VM (i.e. before all pages have been
+	 * reclaimed) leaves the count elevated.
+	 */
+	psci_mem_protect_inc(hyp_vm->guest_owned_pages);
 unlock:
 	hyp_write_unlock(&vm_table_lock);
 
@@ -1019,6 +1028,7 @@ int __pkvm_finalize_teardown_vm(pkvm_handle_t handle)
 	 * has a refcount of 0 so we're free to tear it down without
 	 * worrying about anybody else.
 	 */
+	WARN_ON(hyp_vm->guest_owned_pages != hyp_vm->reclaimed_pages);
 	mc = &host_kvm->arch.pkvm.teardown_mc;
 	destroy_hyp_vm_pgt(hyp_vm);
 	drain_hyp_pool(hyp_vm, mc);
@@ -1045,6 +1055,10 @@ int __pkvm_finalize_teardown_vm(pkvm_handle_t handle)
 
 	hyp_free(hyp_vm->kvm.arch.mmu.last_vcpu_ran);
 	vm_size = pkvm_get_hyp_vm_size(hyp_vm->kvm.created_vcpus);
+
+	/* Drop PSCI MEM_PROTECT page references + the hyp_vm reference */
+	if (kvm_vm_is_protected(&hyp_vm->kvm))
+		psci_mem_protect_dec(1 + hyp_vm->reclaimed_pages);
 	hyp_free(hyp_vm);
 	hyp_unpin_shared_mem(host_kvm, host_kvm + 1);
 	return 0;