pkvm_shmem: Do not overflow page refcount for PAGE_READONLY pages
diff --git a/arch/arm64/kvm/hyp/nvhe/mem_protect.c b/arch/arm64/kvm/hyp/nvhe/mem_protect.c
index 2d33962..dd0a7a9 100644
--- a/arch/arm64/kvm/hyp/nvhe/mem_protect.c
+++ b/arch/arm64/kvm/hyp/nvhe/mem_protect.c
@@ -1997,7 +1997,7 @@ int __pkvm_host_share_guest(u64 pfn, u64 gfn, struct pkvm_hyp_vcpu *vcpu)
int __pkvm_host_share_ro_guest(u64 pfn, u64 gfn, struct pkvm_hyp_vcpu *vcpu)
{
- int ret;
+ int ret, refcount;
u64 host_addr = hyp_pfn_to_phys(pfn);
u64 guest_addr = hyp_pfn_to_phys(gfn);
u64 size = PAGE_SIZE;
@@ -2017,6 +2017,15 @@ int __pkvm_host_share_ro_guest(u64 pfn, u64 gfn, struct pkvm_hyp_vcpu *vcpu)
if (ret)
goto unlock;
+ /* Check that we would not overflow the page refcount. */
+ page = hyp_phys_to_page(host_addr);
+ refcount = hyp_refcount_get(page->refcount);
+ WARN_ON(refcount < 1);
+ if (refcount == (unsigned short)-1) {
+ ret = -ERANGE;
+ goto unlock;
+ }
+
prot = pkvm_mkstate(KVM_PGTABLE_PROT_R, PKVM_PAGE_READONLY);
ret = kvm_pgtable_stage2_map(
&vm->pgt, guest_addr, size, host_addr, prot,
@@ -2024,8 +2033,6 @@ int __pkvm_host_share_ro_guest(u64 pfn, u64 gfn, struct pkvm_hyp_vcpu *vcpu)
if (ret)
goto unlock;
- page = hyp_phys_to_page(host_addr);
- WARN_ON(hyp_refcount_get(page->refcount) < 1);
hyp_page_ref_inc(page);
unlock: