pkvm_shmem: Do not overflow page refcount for PAGE_READONLY pages
diff --git a/arch/arm64/kvm/hyp/nvhe/mem_protect.c b/arch/arm64/kvm/hyp/nvhe/mem_protect.c
index 2d33962..dd0a7a9 100644
--- a/arch/arm64/kvm/hyp/nvhe/mem_protect.c
+++ b/arch/arm64/kvm/hyp/nvhe/mem_protect.c
@@ -1997,7 +1997,7 @@ int __pkvm_host_share_guest(u64 pfn, u64 gfn, struct pkvm_hyp_vcpu *vcpu)
 
 int __pkvm_host_share_ro_guest(u64 pfn, u64 gfn, struct pkvm_hyp_vcpu *vcpu)
 {
-	int ret;
+	int ret, refcount;
 	u64 host_addr = hyp_pfn_to_phys(pfn);
 	u64 guest_addr = hyp_pfn_to_phys(gfn);
 	u64 size = PAGE_SIZE;
@@ -2017,6 +2017,15 @@ int __pkvm_host_share_ro_guest(u64 pfn, u64 gfn, struct pkvm_hyp_vcpu *vcpu)
 	if (ret)
 		goto unlock;
 
+	/* Check that we would not overflow the page refcount. */
+	page = hyp_phys_to_page(host_addr);
+	refcount = hyp_refcount_get(page->refcount);
+	WARN_ON(refcount < 1);
+	if (refcount == (unsigned short)-1) {
+		ret = -ERANGE;
+		goto unlock;
+	}
+
 	prot = pkvm_mkstate(KVM_PGTABLE_PROT_R, PKVM_PAGE_READONLY);
 	ret = kvm_pgtable_stage2_map(
 		&vm->pgt, guest_addr, size, host_addr, prot,
@@ -2024,8 +2033,6 @@ int __pkvm_host_share_ro_guest(u64 pfn, u64 gfn, struct pkvm_hyp_vcpu *vcpu)
 	if (ret)
 		goto unlock;
 
-	page = hyp_phys_to_page(host_addr);
-	WARN_ON(hyp_refcount_get(page->refcount) < 1);
 	hyp_page_ref_inc(page);
 
 unlock: