pkvm_shmem: Implement page refcounting at EL2
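
Replace the remaining "XXX REFCOUNT" placeholders with a working
scheme for tracking outstanding mappings of a read-only page:

 - When the host makes one of its pages read-only, it takes the
   initial reference on it via hyp_set_page_refcounted().
 - Sharing the page read-only with a guest takes an extra reference;
   CoW-unsharing and guest page reclaim drop it again.
 - The host may only make the page writeable again once its own
   initial reference is the last one left (refcount == 1), that is
   once no guest maps the page any more. Otherwise, return -EPERM.
   The reference is only dropped once the page state has been
   flipped back to PKVM_PAGE_OWNED successfully.
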
diff --git a/arch/arm64/kvm/hyp/nvhe/mem_protect.c b/arch/arm64/kvm/hyp/nvhe/mem_protect.c
index e4c259e..2d33962 100644
--- a/arch/arm64/kvm/hyp/nvhe/mem_protect.c
+++ b/arch/arm64/kvm/hyp/nvhe/mem_protect.c
@@ -1500,20 +1500,28 @@ static int do_donate(struct pkvm_mem_donation *donation)
 
 int __pkvm_host_make_page_readonly(u64 pfn)
 {
-	int ret;
 	u64 host_addr = hyp_pfn_to_phys(pfn);
+	enum kvm_pgtable_prot prot;
+	struct hyp_page *page;
+	int ret;
 
 	host_lock_component();
 
 	ret = __host_check_page_state_range(host_addr, PAGE_SIZE,
 					    PKVM_PAGE_OWNED);
-	if (!ret) {
-		enum kvm_pgtable_prot prot = pkvm_mkstate(PKVM_HOST_MEM_PROT_R,
-							  PKVM_PAGE_READONLY);
-		ret = host_stage2_idmap_locked(host_addr, PAGE_SIZE, prot);
-		/* XXX REFCOUNT */
-	}
+	if (ret)
+		goto unlock;
 
+	prot = pkvm_mkstate(PKVM_HOST_MEM_PROT_R, PKVM_PAGE_READONLY);
+	ret = host_stage2_idmap_locked(host_addr, PAGE_SIZE, prot);
+	if (ret)
+		goto unlock;
+
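+	/* Host holds the initial reference while the page is read-only. */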
+	page = hyp_phys_to_page(host_addr);
+	hyp_set_page_refcounted(page);
+
+unlock:
 	host_unlock_component();
 
 	return ret;
@@ -1521,18 +1529,30 @@ int __pkvm_host_make_page_readonly(u64 pfn)
 
 int __pkvm_host_make_page_writeable(u64 pfn)
 {
-	int ret;
+	int ret, refcount;
 	u64 host_addr = hyp_pfn_to_phys(pfn);
+	struct hyp_page *page;
 
 	host_lock_component();
 
-	/* XXX REFCOUNT */
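+	/* Refuse if any guest still holds a reference to the page. */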
+	page = hyp_phys_to_page(host_addr);
+	refcount = hyp_refcount_get(page->refcount);
+	if (refcount != 1) {
+		WARN_ON(refcount < 1);
+		ret = -EPERM;
+		goto unlock;
+	}
+
 	ret = __host_check_page_state_range(host_addr, PAGE_SIZE,
 					    PKVM_PAGE_READONLY);
 	if (!ret)
 		ret = __host_set_page_state_range(host_addr, PAGE_SIZE,
 						  PKVM_PAGE_OWNED);
+	if (!ret)
+		hyp_page_ref_dec(page);
 
+unlock:
 	host_unlock_component();
 
 	return ret;
@@ -1985,6 +2005,7 @@ int __pkvm_host_share_ro_guest(u64 pfn, u64 gfn, struct pkvm_hyp_vcpu *vcpu)
 	u64 size = PAGE_SIZE;
 	struct pkvm_hyp_vm *vm = pkvm_hyp_vcpu_to_hyp_vm(vcpu);
 	enum kvm_pgtable_prot prot;
+	struct hyp_page *page;
 
 	host_lock_component();
 	guest_lock_component(vm);
@@ -2002,6 +2023,13 @@ int __pkvm_host_share_ro_guest(u64 pfn, u64 gfn, struct pkvm_hyp_vcpu *vcpu)
 	ret = kvm_pgtable_stage2_map(
 		&vm->pgt, guest_addr, size, host_addr, prot,
 		&vcpu->vcpu.arch.pkvm_memcache, 0);
+	if (ret)
+		goto unlock;
+
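+	/* Take a reference for the guest's new read-only mapping. */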
+	page = hyp_phys_to_page(host_addr);
+	WARN_ON(hyp_refcount_get(page->refcount) < 1);
+	hyp_page_ref_inc(page);
 
 unlock:
 	guest_unlock_component(vm);
@@ -2044,6 +2072,7 @@ int __pkvm_host_share_cow_guest(u64 pfn, u64 gfn, struct pkvm_hyp_vcpu *vcpu)
 	u64 size = PAGE_SIZE;
 	struct pkvm_hyp_vm *vm = pkvm_hyp_vcpu_to_hyp_vm(vcpu);
 	enum kvm_pgtable_prot prot;
+	struct hyp_page *page;
 
 	struct guest_cow_data data = {
 		.expected_state = PKVM_PAGE_READONLY,
@@ -2083,6 +2112,13 @@ int __pkvm_host_share_cow_guest(u64 pfn, u64 gfn, struct pkvm_hyp_vcpu *vcpu)
 	ret = kvm_pgtable_stage2_map(
 		&vm->pgt, guest_addr, size, host_addr, prot,
 		&vcpu->vcpu.arch.pkvm_memcache, 0);
+	if (ret)
+		goto unlock;
+
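+	/* The guest maps its own copy now: drop its ref on the original. */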
+	page = hyp_phys_to_page(data.pa);
+	hyp_page_ref_dec(page);
+	WARN_ON(hyp_refcount_get(page->refcount) < 1);
 
 unlock:
 	guest_unlock_component(vm);
@@ -2225,6 +2261,7 @@ int __pkvm_host_reclaim_page(struct pkvm_hyp_vm *vm, u64 pfn, u64 ipa)
 	u64 hyp_va;
 	struct kvm_s2_mmu *mmu;
 	struct kvm_vmid *kvm_vmid;
+	struct hyp_page *page;
 	u64 vmid;
 
 	host_lock_component();
@@ -2307,7 +2344,10 @@ int __pkvm_host_reclaim_page(struct pkvm_hyp_vm *vm, u64 pfn, u64 ipa)
 		}
 		break;
 	case PKVM_PAGE_READONLY:
-		/* Dec refcount */
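+		/* Drop the ref the guest's RO mapping held. */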
+		page = hyp_phys_to_page(phys);
+		hyp_page_ref_dec(page);
+		WARN_ON(hyp_refcount_get(page->refcount) < 1);
 		goto unlock; /* no adjustment to host stage 2 */
 	default:
 		BUG_ON(1);