KVM: arm64: Clean up CoW handling for pkvm shmem mappings
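
PagePkvmReadonly() was stubbed out to 1, so every page that was not
swap-backed took the R/O path, and the kvm_cow_page() helper it
guarded was dead code behind an if (0). Replace both with a real
check: a new vma_is_pkvm_shmem() helper identifies pkvm shmem
mappings by their vm_ops, which pkvm_shmem_area_mmap() now installs.

In pkvm_mem_abort(), only faults on pkvm shmem VMAs take the
non-swap-backed path; anything else fails with -EIO as the original
comment intended. The pinned_pages lookup deciding whether a write
fault must break CoW now runs under mmu_lock, and is_cow/is_cow_write
are renamed to is_ro_shmem/is_cow_break to match. Finally, reject
KVM_MEM_READONLY memslots alongside dirty logging in
kvm_arch_prepare_memory_region().
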
diff --git a/arch/arm64/include/asm/kvm_host.h b/arch/arm64/include/asm/kvm_host.h
index 43b5414..3088382 100644
--- a/arch/arm64/include/asm/kvm_host.h
+++ b/arch/arm64/include/asm/kvm_host.h
@@ -1252,4 +1252,6 @@ static inline void kvm_hyp_reserve(void) { }
 void kvm_arm_vcpu_power_off(struct kvm_vcpu *vcpu);
 bool kvm_arm_vcpu_stopped(struct kvm_vcpu *vcpu);
 
+bool vma_is_pkvm_shmem(struct vm_area_struct *vma);
+
 #endif /* __ARM64_KVM_HOST_H__ */
diff --git a/arch/arm64/kvm/mmu.c b/arch/arm64/kvm/mmu.c
index 1062a31..34864da 100644
--- a/arch/arm64/kvm/mmu.c
+++ b/arch/arm64/kvm/mmu.c
@@ -1482,30 +1482,6 @@ static int update_ppage(struct kvm *kvm, struct kvm_pinned_page *ppage)
 	return 0;
 }
 
-#define PagePkvmReadonly(page) (1)
-
-static struct page *kvm_cow_page(struct page *page, unsigned long hva,
-				 unsigned long flags)
-{
-	struct mm_struct *mm = current->mm;
-	struct vm_area_struct *vma;
-	struct page *newpage = NULL;
-	int ret = -EPERM;
-
-	mmap_read_lock(mm);
-
-	vma = find_vma(mm, hva);
-	if (!vma)
-		goto unlock;
-
-	zap_page_range_single(vma, hva, PAGE_SIZE, NULL);
-	ret = pin_user_pages(hva, 1, flags, &newpage);
-
-unlock:
-	mmap_read_unlock(mm);
-	return (ret == 1) ? newpage : NULL;
-}
-
 static int pkvm_mem_abort(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa,
 			  struct kvm_memory_slot *memslot, unsigned long hva)
 {
@@ -1517,7 +1493,7 @@ static int pkvm_mem_abort(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa,
 	int ret, nr_pages;
 	struct page *page, *oldpage = NULL;
 	u64 pfn;
-	bool is_cow = false, is_cow_write = false;
+	bool is_ro_shmem = false, is_cow_break = false;
 
 	nr_pages = hyp_memcache->nr_pages;
 	ret = topup_hyp_memcache(hyp_memcache, kvm_mmu_cache_min_pages(kvm),
@@ -1550,16 +1526,52 @@ static int pkvm_mem_abort(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa,
 	} else if (ret != 1) {
 		ret = -EFAULT;
 		goto dec_account;
-	} else if (PageSwapBacked(page)) {
-		/* Ok */
-	} else if (PagePkvmReadonly(page)) {
-		/* Ok */
-		if (kvm_is_write_fault(vcpu) &&
-		    !!rb_find((void *)fault_ipa, &kvm->arch.pkvm.pinned_pages,
-			      rb_ppage_cmp)) /* XXX */ {
-			struct page *newpage;
-			if (0)newpage = kvm_cow_page(page, hva, flags);
-			newpage = alloc_page(GFP_KERNEL); /* XXX account to user */
+	} else if (!PageSwapBacked(page)) {
+		/* Special handling if the page is in a pkvm shmem VMA. */
+		struct vm_area_struct *vma;
+
+		mmap_read_lock(mm);
+		vma = find_vma(mm, hva);
+		is_ro_shmem = vma && vma_is_pkvm_shmem(vma);
+		mmap_read_unlock(mm);
+
+		if (!is_ro_shmem) {
+			/*
+			 * We really can't deal with page-cache pages returned
+			 * by GUP because (a) we may trigger writeback of a
+			 * page for which we no longer have access and (b)
+			 * page_mkclean() won't find the stage-2 mapping in the
+			 * rmap so we can get out-of-whack with the filesystem
+			 * when marking the page dirty during unpinning (see
+			 * cc5095747edf ("ext4: don't BUG if someone dirty
+			 * pages without asking ext4 first")).
+			 *
+			 * Ideally we'd just restrict ourselves to anonymous
+			 * pages, but we also want to allow memfd (i.e. shmem)
+			 * pages, so check for pages backed by swap in the
+			 * knowledge that the GUP pin will prevent
+			 * try_to_unmap() from succeeding.
+			 */
+			ret = -EIO;
+			goto unpin;
+		}
+
+		/*
+		 * Perform R/O CoW break if this is a write fault on an
+		 * *existing* R/O mapping. Otherwise we must fault in the
+		 * R/O mapping first.
+		 */
+		if (kvm_is_write_fault(vcpu)) {
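+			/*
+			 * pinned_pages is protected by mmu_lock; a hit means
+			 * the R/O mapping already exists, so this write must
+			 * break CoW.
+			 */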
+			read_lock(&kvm->mmu_lock);
+			is_cow_break = !!rb_find((void *)fault_ipa,
+						 &kvm->arch.pkvm.pinned_pages,
+						 rb_ppage_cmp);
+			read_unlock(&kvm->mmu_lock);
+		}
+
+		if (is_cow_break) {
+			/* XXX account to user */
+			struct page *newpage = alloc_page(GFP_KERNEL);
 			if (newpage == NULL) {
 				ret = -EFAULT;
 				goto unpin;
@@ -1569,35 +1581,17 @@ static int pkvm_mem_abort(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa,
 			printk(" - %s: Got CoW page %lx count=%d\n",
 			       __FUNCTION__,
 			       page_to_pfn(page), page_ref_count(page));
-			is_cow_write = true;
 		}
-		is_cow = true;
-	} else {
-		/*
-		 * We really can't deal with page-cache pages returned by GUP
-		 * because (a) we may trigger writeback of a page for which we
-		 * no longer have access and (b) page_mkclean() won't find the
-		 * stage-2 mapping in the rmap so we can get out-of-whack with
-		 * the filesystem when marking the page dirty during unpinning
-		 * (see cc5095747edf ("ext4: don't BUG if someone dirty pages
-		 * without asking ext4 first")).
-		 *
-		 * Ideally we'd just restrict ourselves to anonymous pages, but
-		 * we also want to allow memfd (i.e. shmem) pages, so check for
-		 * pages backed by swap in the knowledge that the GUP pin will
-		 * prevent try_to_unmap() from succeeding.
-		 */
-		ret = -EIO;
-		goto unpin;
 	}
 
 	write_lock(&kvm->mmu_lock);
 	pfn = page_to_pfn(page);
-	if (!is_cow) {
+	if (!is_ro_shmem) {
 		ret = pkvm_host_map_guest(pfn, fault_ipa >> PAGE_SHIFT);
 	} else {
 		ret = pkvm_host_map_guest_ro(pfn, fault_ipa >> PAGE_SHIFT,
-					     is_cow_write);
+					     is_cow_break);
 	}
 	if (ret) {
 		if (ret == -EAGAIN)
@@ -2274,7 +2268,7 @@ int kvm_arch_prepare_memory_region(struct kvm *kvm,
 		}
 
 		if (new &&
-		    new->flags & KVM_MEM_LOG_DIRTY_PAGES) {
+		    new->flags & (KVM_MEM_LOG_DIRTY_PAGES | KVM_MEM_READONLY)) {
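+			/* Neither dirty logging nor R/O slots are supported here. */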
 			return -EPERM;
 		}
 	}
diff --git a/arch/arm64/kvm/pkvm_shmem.c b/arch/arm64/kvm/pkvm_shmem.c
index cdf351d..8d1505d 100644
--- a/arch/arm64/kvm/pkvm_shmem.c
+++ b/arch/arm64/kvm/pkvm_shmem.c
@@ -102,6 +102,8 @@ static to_page_fn_t get_to_page_fn(struct pkvm_shmem_area *area)
 	return NULL;
 }
 
+static const struct vm_operations_struct pkvm_shmem_ops;
+
 static int pkvm_shmem_area_mmap(struct file *filep,
 				 struct vm_area_struct *vma)
 {
@@ -117,6 +119,7 @@ static int pkvm_shmem_area_mmap(struct file *filep,
 	BUG_ON(!area->kaddr);
 
 	vm_flags_set(vma, VM_DONTEXPAND);
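+	/* Install vm_ops so vma_is_pkvm_shmem() can identify this VMA. */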
+	vma->vm_ops = &pkvm_shmem_ops;
 
 	to_page = get_to_page_fn(area);
 
@@ -183,6 +186,11 @@ static const struct file_operations pkvm_shmem_area_fops = {
 	.mmap = pkvm_shmem_area_mmap,
 };
 
+bool vma_is_pkvm_shmem(struct vm_area_struct *vma)
+{
+	return vma->vm_ops == &pkvm_shmem_ops;
+}
+
 static int pkvm_shmem_open(struct inode *inode, struct file *filep)
 {
 	printk("%s\n", __FUNCTION__);