KVM: arm64: Use PSCI MEM_PROTECT to zap guest pages on reset

If a malicious or compromised host issues a PSCI SYSTEM_RESET call while
guest-owned pages are present, then the contents of those pages may be
recoverable via a cold-reboot attack.

Use the PSCI MEM_PROTECT call to ensure that firmware wipes volatile
memory if a SYSTEM_RESET occurs while unpoisoned guest pages exist in
the system. Since MEM_PROTECT offers no such guarantee across a "warm"
reset initiated by SYSTEM_RESET2, detect this case in the PSCI relay and
upgrade the call to a standard SYSTEM_RESET instead.
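
For reference, the firmware interface being relied upon is a simple
toggle: PSCI 1.1 MEM_PROTECT takes a single "enable" argument and
returns the previous enable state (or NOT_SUPPORTED). The sketch below
is purely illustrative of that interface, assuming the
PSCI_1_1_FN_MEM_PROTECT definition used by this series and a standard
SMCCC conduit; the hyp relay in this patch instead issues the call via
its own psci_call() helper and reference-counts it per guest-owned page:

  #include <linux/arm-smccc.h>
  #include <uapi/linux/psci.h>

  /* Toggle MEM_PROTECT; returns the previous enable state or a PSCI error. */
  static long mem_protect_set(bool enable)
  {
  	struct arm_smccc_res res;

  	arm_smccc_1_1_smc(PSCI_1_1_FN_MEM_PROTECT, enable, &res);
  	return res.a0;
  }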

Signed-off-by: Will Deacon <will@kernel.org>
diff --git a/arch/arm64/kvm/hyp/include/nvhe/mem_protect.h b/arch/arm64/kvm/hyp/include/nvhe/mem_protect.h
index 1be1d92..a95f113 100644
--- a/arch/arm64/kvm/hyp/include/nvhe/mem_protect.h
+++ b/arch/arm64/kvm/hyp/include/nvhe/mem_protect.h
@@ -94,6 +94,9 @@ void reclaim_guest_pages(struct pkvm_hyp_vm *vm, struct kvm_hyp_memcache *mc);
 int refill_memcache(struct kvm_hyp_memcache *mc, unsigned long min_pages,
 		    struct kvm_hyp_memcache *host_mc);
 
+void psci_mem_protect_inc(u64 n);
+void psci_mem_protect_dec(u64 n);
+
 static __always_inline void __load_host_stage2(void)
 {
 	if (static_branch_likely(&kvm_protected_mode_initialized))
diff --git a/arch/arm64/kvm/hyp/nvhe/mem_protect.c b/arch/arm64/kvm/hyp/nvhe/mem_protect.c
index b02f226..287560c 100644
--- a/arch/arm64/kvm/hyp/nvhe/mem_protect.c
+++ b/arch/arm64/kvm/hyp/nvhe/mem_protect.c
@@ -901,8 +901,16 @@ static int host_complete_share(u64 addr, const struct pkvm_mem_transition *tx,
 			       enum kvm_pgtable_prot perms)
 {
 	u64 size = tx->nr_pages * PAGE_SIZE;
+	int err;
 
-	return __host_set_page_state_range(addr, size, PKVM_PAGE_SHARED_BORROWED);
+	err = __host_set_page_state_range(addr, size, PKVM_PAGE_SHARED_BORROWED);
+	if (err)
+		return err;
+
+	if (tx->initiator.id == PKVM_ID_GUEST)
+		psci_mem_protect_dec(tx->nr_pages);
+
+	return 0;
 }
 
 static int host_complete_unshare(u64 addr, const struct pkvm_mem_transition *tx)
@@ -910,6 +918,9 @@ static int host_complete_unshare(u64 addr, const struct pkvm_mem_transition *tx)
 	u8 owner_id = tx->initiator.id;
 	u64 size = tx->nr_pages * PAGE_SIZE;
 
+	if (tx->initiator.id == PKVM_ID_GUEST)
+		psci_mem_protect_inc(tx->nr_pages);
+
 	return host_stage2_set_owner_locked(addr, size, owner_id);
 }
 
@@ -1096,18 +1107,32 @@ static int guest_complete_donation(u64 addr, const struct pkvm_mem_transition *tx)
 	u64 size = tx->nr_pages * PAGE_SIZE;
 	int err;
 
+	if (tx->initiator.id == PKVM_ID_HOST)
+		psci_mem_protect_inc(tx->nr_pages);
+
 	if (pkvm_ipa_range_has_pvmfw(vm, addr, addr + size)) {
-		if (WARN_ON(!pkvm_hyp_vcpu_is_protected(vcpu)))
-			return -EPERM;
+		if (WARN_ON(!pkvm_hyp_vcpu_is_protected(vcpu))) {
+			err = -EPERM;
+			goto err_undo_psci;
+		}
 
 		WARN_ON(tx->initiator.id != PKVM_ID_HOST);
 		err = pkvm_load_pvmfw_pages(vm, addr, phys, size);
 		if (err)
-			return err;
+			goto err_undo_psci;
 	}
 
+	/*
+	 * If this fails, we effectively leak the pages since they're now
+	 * owned by the guest but not mapped into its stage-2 page-table.
+	 */
 	return kvm_pgtable_stage2_map(&vm->pgt, addr, size, phys, prot,
 				      &vcpu->vcpu.arch.pkvm_memcache, 0);
+
+err_undo_psci:
+	if (tx->initiator.id == PKVM_ID_HOST)
+		psci_mem_protect_dec(tx->nr_pages);
+	return err;
 }
 
 static int __guest_get_completer_addr(u64 *completer_addr, phys_addr_t phys,
@@ -1923,6 +1948,7 @@ int __pkvm_host_reclaim_page(u64 pfn)
 	if (page->flags & HOST_PAGE_NEED_POISONING) {
 		hyp_poison_page(addr);
 		page->flags &= ~HOST_PAGE_NEED_POISONING;
+		psci_mem_protect_dec(1);
 	}
 
 	ret = host_stage2_set_owner_locked(addr, PAGE_SIZE, PKVM_ID_HOST);
diff --git a/arch/arm64/kvm/hyp/nvhe/psci-relay.c b/arch/arm64/kvm/hyp/nvhe/psci-relay.c
index 93de8af..b62ca51 100644
--- a/arch/arm64/kvm/hyp/nvhe/psci-relay.c
+++ b/arch/arm64/kvm/hyp/nvhe/psci-relay.c
@@ -11,6 +11,7 @@
 #include <linux/kvm_host.h>
 #include <uapi/linux/psci.h>
 
+#include <nvhe/mem_protect.h>
 #include <nvhe/memory.h>
 #include <nvhe/pkvm.h>
 #include <nvhe/trap_handler.h>
@@ -222,6 +223,49 @@ asmlinkage void __noreturn __kvm_host_psci_cpu_entry(bool is_cpu_on)
 	__host_enter(host_ctxt);
 }
 
+static DEFINE_HYP_SPINLOCK(mem_protect_lock);
+
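+/*
+ * Running count of guest-owned pages whose contents must be wiped by
+ * firmware on a cold reset. MEM_PROTECT is enabled on the 0 -> 1
+ * transition and disabled again when the count drops back to zero.
+ */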
+static u64 psci_mem_protect(s64 offset)
+{
+	static u64 cnt;
+	u64 new = cnt + offset;
+
+	hyp_assert_lock_held(&mem_protect_lock);
+
+	if (!offset || kvm_host_psci_config.version < PSCI_VERSION(1, 1))
+		return cnt;
+
+	if (!cnt || !new)
+		psci_call(PSCI_1_1_FN_MEM_PROTECT, offset < 0 ? 0 : 1, 0, 0);
+
+	cnt = new;
+	return cnt;
+}
+
+static bool psci_mem_protect_active(void)
+{
+	return psci_mem_protect(0);
+}
+
+void psci_mem_protect_inc(u64 n)
+{
+	hyp_spin_lock(&mem_protect_lock);
+	psci_mem_protect(n);
+	hyp_spin_unlock(&mem_protect_lock);
+}
+
+void psci_mem_protect_dec(u64 n)
+{
+	hyp_spin_lock(&mem_protect_lock);
+	psci_mem_protect(-n);
+	hyp_spin_unlock(&mem_protect_lock);
+}
+
 static unsigned long psci_0_1_handler(u64 func_id, struct kvm_cpu_context *host_ctxt)
 {
 	if (is_psci_0_1(cpu_off, func_id) || is_psci_0_1(migrate, func_id))
@@ -251,6 +295,8 @@ static unsigned long psci_0_2_handler(u64 func_id, struct kvm_cpu_context *host_ctxt)
 	case PSCI_0_2_FN_SYSTEM_OFF:
 	case PSCI_0_2_FN_SYSTEM_RESET:
 		pkvm_poison_pvmfw_pages();
+		/* Avoid racing with a MEM_PROTECT call. */
+		hyp_spin_lock(&mem_protect_lock);
 		return psci_forward(host_ctxt);
 	case PSCI_0_2_FN64_CPU_SUSPEND:
 		return psci_cpu_suspend(func_id, host_ctxt);
@@ -266,6 +312,13 @@ static unsigned long psci_1_0_handler(u64 func_id, struct kvm_cpu_context *host_ctxt)
 	switch (func_id) {
 	case PSCI_1_1_FN64_SYSTEM_RESET2:
 		pkvm_poison_pvmfw_pages();
+		hyp_spin_lock(&mem_protect_lock);
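+		/*
+		 * MEM_PROTECT only covers cold resets, so upgrade a pending
+		 * SYSTEM_RESET2 to a SYSTEM_RESET while it is active.
+		 */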
+		if (psci_mem_protect_active())
+			cpu_reg(host_ctxt, 0) = PSCI_0_2_FN_SYSTEM_RESET;
 		fallthrough;
 	case PSCI_1_0_FN_PSCI_FEATURES:
 	case PSCI_1_0_FN_SET_SUSPEND_MODE: