KVM: arm64: Add THP support for pKVM guests

When a pKVM guest takes a stage-2 abort, if the faulting address is
backed by a transparent huge page at stage-1, pin the entire huge page
and map its whole range at stage-2. The hypervisor in turn installs a
block mapping, which reduces TLB pressure and the number of stage-2
faults.
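
Roughly, the size decision looks like the following standalone sketch
(illustrative only, not part of the patch; stage2_map_size() and its
parameters are made up):

  #include <stdbool.h>
  #include <stdint.h>
  #include <stdio.h>

  #define PAGE_SIZE	(1UL << 12)
  #define PMD_SIZE	(1UL << 21)	/* 2MiB stage-2 block */

  /*
   * Pick the stage-2 mapping size for a faulting IPA: a whole PMD-sized
   * block when the stage-1 backing is a THP and nothing in that range
   * is pinned yet, otherwise a single page at the original fault
   * address.
   */
  static uint64_t stage2_map_size(uint64_t *fault_ipa, bool thp_backed,
  				  bool range_already_pinned)
  {
  	uint64_t pmd_offset = *fault_ipa & (PMD_SIZE - 1);

  	if (!thp_backed)
  		return PAGE_SIZE;

  	/* Align the IPA down so a block mapping can be installed. */
  	*fault_ipa &= ~(PMD_SIZE - 1);

  	if (range_already_pinned) {
  		/*
  		 * Part of the THP is already mapped: undo the alignment
  		 * and fall back to mapping only the faulting page.
  		 */
  		*fault_ipa += pmd_offset;
  		return PAGE_SIZE;
  	}

  	return PMD_SIZE;
  }

  int main(void)
  {
  	uint64_t ipa = 0x40123000;
  	uint64_t size = stage2_map_size(&ipa, true, false);

  	printf("map ipa=%#lx size=%#lx\n", (unsigned long)ipa,
  	       (unsigned long)size);
  	return 0;
  }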

Signed-off-by: Vincent Donnefort <vdonnefort@google.com>
diff --git a/arch/arm64/include/asm/kvm_mmu.h b/arch/arm64/include/asm/kvm_mmu.h
index 0955aa6..30f1e6d 100644
--- a/arch/arm64/include/asm/kvm_mmu.h
+++ b/arch/arm64/include/asm/kvm_mmu.h
@@ -182,7 +182,7 @@ int kvm_phys_addr_ioremap(struct kvm *kvm, phys_addr_t guest_ipa,
 			  phys_addr_t pa, unsigned long size, bool writable);
 
 int kvm_handle_guest_abort(struct kvm_vcpu *vcpu);
-int pkvm_mem_abort_range(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa, size_t size);
+int pkvm_mem_abort_range(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa, ssize_t size);
 
 phys_addr_t kvm_mmu_get_httbr(void);
 phys_addr_t kvm_get_idmap_vector(void);
diff --git a/arch/arm64/kvm/mmu.c b/arch/arm64/kvm/mmu.c
index 98ba27a..d27206d 100644
--- a/arch/arm64/kvm/mmu.c
+++ b/arch/arm64/kvm/mmu.c
@@ -1431,7 +1431,10 @@ static int pkvm_host_map_guest(u64 pfn, u64 gfn, size_t size)
 	/*
 	 * Getting -EPERM at this point implies that the pfn has already been
 	 * mapped. This should only ever happen when two vCPUs faulted on the
-	 * same page, and the current one lost the race to do the mapping.
+	 * same page, and the current one lost the race to do the mapping...
+	 *
+	 * ...or if we've tried to map a region containing an already mapped
+	 * entry.
 	 */
 	return (ret == -EPERM) ? -EAGAIN : ret;
 }
@@ -1458,12 +1461,14 @@ static int insert_ppage(struct kvm *kvm, struct kvm_pinned_page *ppage)
 				  ppage, GFP_KERNEL);
 }
 
-static int pkvm_mem_abort(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa,
-			  unsigned long hva)
+static int pkvm_mem_abort(struct kvm_vcpu *vcpu, phys_addr_t *fault_ipa,
+			  struct kvm_memory_slot *memslot, unsigned long hva,
+			  size_t *size)
 {
 	struct kvm_hyp_memcache *hyp_memcache = &vcpu->arch.pkvm_memcache;
-	struct mm_struct *mm = current->mm;
 	unsigned int flags = FOLL_HWPOISON | FOLL_LONGTERM | FOLL_WRITE;
+	unsigned long index, pmd_offset, page_size;
+	struct mm_struct *mm = current->mm;
 	struct kvm_pinned_page *ppage;
 	struct kvm *kvm = vcpu->kvm;
 	struct page *page;
@@ -1479,10 +1484,6 @@ static int pkvm_mem_abort(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa,
 	if (!ppage)
 		return -ENOMEM;
 
-	ret = account_locked_vm(mm, 1, true);
-	if (ret)
-		goto free_ppage;
-
 	mmap_read_lock(mm);
 	ret = pin_user_pages(hva, 1, flags, &page);
 	mmap_read_unlock(mm);
@@ -1490,10 +1491,10 @@ static int pkvm_mem_abort(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa,
 	if (ret == -EHWPOISON) {
 		kvm_send_hwpoison_signal(hva, PAGE_SHIFT);
 		ret = 0;
-		goto dec_account;
+		goto free_ppage;
 	} else if (ret != 1) {
 		ret = -EFAULT;
-		goto dec_account;
+		goto free_ppage;
 	} else if (!PageSwapBacked(page)) {
 		/*
 		 * We really can't deal with page-cache pages returned by GUP
@@ -1510,39 +1511,69 @@ static int pkvm_mem_abort(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa,
 		 * prevent try_to_unmap() from succeeding.
 		 */
 		ret = -EIO;
-		goto dec_account;
+		goto free_ppage;
 	}
 
-	write_lock(&kvm->mmu_lock);
 	pfn = page_to_pfn(page);
-	ret = pkvm_host_map_guest(pfn, fault_ipa >> PAGE_SHIFT, PAGE_SIZE);
+	pmd_offset = *fault_ipa & (PMD_SIZE - 1);
+	page_size = transparent_hugepage_adjust(kvm, memslot,
+						hva, &pfn,
+						fault_ipa);
+	page = pfn_to_page(pfn);
+
+	if (size)
+		*size = page_size;
+
+	ret = account_locked_vm(mm, page_size >> PAGE_SHIFT, true);
+	if (ret)
+		goto unpin;
+
+	write_lock(&kvm->mmu_lock);
+	/*
+	 * If we already have a mapping in the middle of the THP, we have no
+	 * choice but to enforce PAGE_SIZE for pkvm_host_map_guest() to
+	 * succeed.
+	 */
+	index = *fault_ipa;
+	if (page_size > PAGE_SIZE &&
+	    mt_find(&kvm->arch.pkvm.pinned_pages, &index, index + page_size - 1)) {
+		*fault_ipa += pmd_offset;
+		pfn += pmd_offset >> PAGE_SHIFT;
+		page = pfn_to_page(pfn);
+		account_locked_vm(mm, (page_size >> PAGE_SHIFT) - 1, false);
+		page_size = PAGE_SIZE;
+	}
+
+	ret = pkvm_host_map_guest(pfn, *fault_ipa >> PAGE_SHIFT, page_size);
 	if (ret) {
 		if (ret == -EAGAIN)
 			ret = 0;
-		goto unpin;
+
+		goto dec_account;
 	}
 
 	ppage->page = page;
-	ppage->ipa = fault_ipa;
-	ppage->order = 0;
+	ppage->ipa = *fault_ipa;
+	ppage->order = get_order(page_size);
 	ppage->pins = 1 << ppage->order;
 	WARN_ON(insert_ppage(kvm, ppage));
+
 	write_unlock(&kvm->mmu_lock);
 
 	return 0;
 
-unpin:
-	write_unlock(&kvm->mmu_lock);
-	unpin_user_pages(&page, 1);
 dec_account:
-	account_locked_vm(mm, 1, false);
+	write_unlock(&kvm->mmu_lock);
+	account_locked_vm(mm, page_size >> PAGE_SHIFT, false);
+unpin:
+	unpin_user_pages(&page, 1);
 free_ppage:
 	kfree(ppage);
 
 	return ret;
 }
 
-int pkvm_mem_abort_range(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa, size_t size)
+int pkvm_mem_abort_range(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa, ssize_t size)
 {
 	phys_addr_t ipa_end = fault_ipa + size - 1;
 	struct kvm_pinned_page *ppage;
@@ -1559,13 +1590,15 @@ int pkvm_mem_abort_range(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa, size_t si
 
 	ppage = find_ppage_or_above(vcpu->kvm, fault_ipa);
 
-	while (size) {
+	while (size > 0) {
 		gfn_t gfn = fault_ipa >> PAGE_SHIFT;
 		struct kvm_memory_slot *memslot;
 		unsigned long hva;
+		size_t page_size = PAGE_SIZE;
 		bool writable;
 
 		if (ppage && ppage->ipa == fault_ipa) {
+			page_size = PAGE_SIZE << ppage->order;
 			ppage = mt_next(&vcpu->kvm->arch.pkvm.pinned_pages,
 					ppage->ipa, ULONG_MAX);
 		} else {
@@ -1576,13 +1609,13 @@ int pkvm_mem_abort_range(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa, size_t si
 				goto end;
 			}
 
-			err = pkvm_mem_abort(vcpu, fault_ipa, hva);
+			err = pkvm_mem_abort(vcpu, &fault_ipa, memslot, hva, &page_size);
 			if (err)
 				goto end;
 		}
 
-		size -= PAGE_SIZE;
-		fault_ipa += PAGE_SIZE;
+		size -= page_size;
+		fault_ipa += page_size;
 	}
 end:
 	srcu_read_unlock(&vcpu->kvm->srcu, idx);
@@ -1956,7 +1989,7 @@ int kvm_handle_guest_abort(struct kvm_vcpu *vcpu)
 	}
 
 	if (is_protected_kvm_enabled())
-		ret = pkvm_mem_abort(vcpu, fault_ipa, hva);
+		ret = pkvm_mem_abort(vcpu, &fault_ipa, memslot, hva, NULL);
 	else
 		ret = user_mem_abort(vcpu, fault_ipa, memslot, hva, fault_status);
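
As an aside, the reworked range walk now advances by whatever each
fault actually mapped, which is also why the size parameter becomes
signed. A standalone model of that loop (map_one() and map_range() are
made-up stand-ins, not code from this patch):

  #include <stdint.h>
  #include <stdio.h>

  #define PAGE_SIZE	(1UL << 12)
  #define PMD_SIZE	(1UL << 21)

  /* Stand-in for the per-fault handler: maps one block, reports its size. */
  static uint64_t map_one(uint64_t *ipa)
  {
  	/* Pretend every fault is THP-backed: align down, map a full block. */
  	*ipa &= ~(PMD_SIZE - 1);
  	return PMD_SIZE;
  }

  /* Walk [ipa, ipa + size), advancing by the size each fault mapped. */
  static void map_range(uint64_t ipa, int64_t size)
  {
  	while (size > 0) {
  		uint64_t mapped = map_one(&ipa);

  		/*
  		 * A block can overshoot the requested range, so the
  		 * remaining size may go negative: hence a signed type.
  		 */
  		size -= mapped;
  		ipa += mapped;
  	}
  }

  int main(void)
  {
  	map_range(0x40123000, 4 * PAGE_SIZE);
  	printf("range mapped\n");
  	return 0;
  }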