KVM: arm64: pkvm: Account pinned pages a la SEV instead of account_locked_vm()

account_locked_vm() charges an mm's locked_vm, and the pKVM paths do
not agree on which mm to charge: the stage-2 fault path accounts
against kvm->mm, while the teardown and reclaim paths account against
current->mm. Whenever a page is pinned in one context and released in
another, the two counters end up unbalanced.

Do what SEV does instead: track the number of pinned pages in the
per-VM state, check it against RLIMIT_MEMLOCK (unless the caller has
CAP_IPC_LOCK) when pinning, and update it under kvm->mmu_lock. The
accounting is also moved after the pin_user_pages() call, so the
counter is only bumped once the page is actually pinned, and is
unwound if the page cannot be mapped into the guest.
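
For reference, the SEV equivalent in sev_pin_memory() looks roughly
like this (paraphrased from arch/x86/kvm/svm/sev.c; the exact code
varies between kernel versions):

	locked = sev->pages_locked + npages;
	lock_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
	if (locked > lock_limit && !capable(CAP_IPC_LOCK))
		return ERR_PTR(-ENOMEM);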

Change-Id: Id1f5378ba21e49a8c3765de59fbe70780851b18c
diff --git a/arch/arm64/include/asm/kvm_host.h b/arch/arm64/include/asm/kvm_host.h
index 72d4197..4d4d358 100644
--- a/arch/arm64/include/asm/kvm_host.h
+++ b/arch/arm64/include/asm/kvm_host.h
@@ -184,10 +184,14 @@ struct kvm_protected_vm {
 	pkvm_handle_t handle;
 	struct kvm_hyp_memcache teardown_mc;
 	struct rb_root pinned_pages;
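+	/* Pinned pages charged against RLIMIT_MEMLOCK; guarded by mmu_lock */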
+	unsigned long pages_locked;
 	gpa_t pvmfw_load_addr;
 	bool enabled;
 };
 
+int pkvm_account_locked_pages(struct kvm *kvm, unsigned long npages, bool inc);
+
 struct kvm_arch {
 	struct kvm_s2_mmu mmu;
 
diff --git a/arch/arm64/kvm/mmu.c b/arch/arm64/kvm/mmu.c
index 347280f..b9276ae 100644
--- a/arch/arm64/kvm/mmu.c
+++ b/arch/arm64/kvm/mmu.c
@@ -141,7 +141,6 @@ static void invalidate_icache_guest_page(void *va, size_t size)
 
 static int pkvm_unmap_guest(struct kvm *kvm, struct kvm_pinned_page *ppage)
 {
-	struct mm_struct *mm = kvm->mm;
 	int ret;
 
 	ret = kvm_call_hyp_nvhe(__pkvm_host_unmap_guest,
@@ -151,7 +150,7 @@ static int pkvm_unmap_guest(struct kvm *kvm, struct kvm_pinned_page *ppage)
 	if (ret)
 		return ret;
 
-	account_locked_vm(mm, 1, false);
+	pkvm_account_locked_pages(kvm, 1, false);
 	unpin_user_pages_dirty_lock(&ppage->page, 1, true);
 	rb_erase(&ppage->node, &kvm->arch.pkvm.pinned_pages);
 	kfree(ppage);
@@ -1284,6 +1283,35 @@ static int insert_ppage(struct kvm *kvm, struct kvm_pinned_page *ppage)
 	return 0;
 }
 
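+/*
+ * Charge or uncharge pages pinned into a protected VM against
+ * RLIMIT_MEMLOCK, using a per-VM counter as SEV does rather than the
+ * VMM's mm->locked_vm. Callers must hold kvm->mmu_lock.
+ */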
+int pkvm_account_locked_pages(struct kvm *kvm, unsigned long npages, bool inc)
+{
+	unsigned long locked, lock_limit;
+
+	lockdep_assert_held(&kvm->mmu_lock);
+	locked = kvm->arch.pkvm.pages_locked;
+	if (inc) {
+		locked += npages;
+		lock_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
+		if (locked > lock_limit && !capable(CAP_IPC_LOCK))
+			return -ENOMEM;
+	} else {
+		/*
+		 * Never fail a decrement: the limit may have been lowered
+		 * since the pages were charged, and callers cannot unwind.
+		 */
+		locked -= npages;
+	}
+
+	kvm->arch.pkvm.pages_locked = locked;
+
+	return 0;
+}
+
 static int pkvm_mem_abort(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa,
 			  unsigned long hva)
 {
@@ -1303,10 +1321,6 @@ static int pkvm_mem_abort(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa,
 	if (!ppage)
 		return -ENOMEM;
 
-	ret = account_locked_vm(mm, 1, true);
-	if (ret)
-		goto free_ppage;
-
 	mmap_read_lock(mm);
 	ret = pin_user_pages(hva, 1, flags, &page, NULL);
 	mmap_read_unlock(mm);
@@ -1314,10 +1328,10 @@ static int pkvm_mem_abort(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa,
 	if (ret == -EHWPOISON) {
 		kvm_send_hwpoison_signal(hva, PAGE_SHIFT);
 		ret = 0;
-		goto dec_account;
+		goto free_ppage;
 	} else if (ret != 1) {
 		ret = -EFAULT;
-		goto dec_account;
+		goto free_ppage;
 	} else if (!PageSwapBacked(page)) {
 		/*
 		 * We really can't deal with page-cache pages returned by GUP
@@ -1334,16 +1348,23 @@ static int pkvm_mem_abort(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa,
 		 * prevent try_to_unmap() from succeeding.
 		 */
 		ret = -EIO;
-		goto dec_account;
+		goto free_ppage;
 	}
 
 	write_lock(&kvm->mmu_lock);
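+	/*
+	 * Charge the page under mmu_lock before mapping it into the guest,
+	 * so that a mapping failure simply unwinds the charge.
+	 */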
+	ret = pkvm_account_locked_pages(kvm, 1, true);
+	if (ret)
+		goto unpin;
 	pfn = page_to_pfn(page);
 	ret = pkvm_host_map_guest(pfn, fault_ipa >> PAGE_SHIFT);
 	if (ret) {
 		if (ret == -EAGAIN)
 			ret = 0;
-		goto unpin;
+		goto dec_account;
 	}
 
 	ppage->page = page;
@@ -1353,11 +1370,11 @@ static int pkvm_mem_abort(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa,
 
 	return 0;
 
+dec_account:
+	pkvm_account_locked_pages(kvm, 1, false);
 unpin:
 	write_unlock(&kvm->mmu_lock);
 	unpin_user_pages(&page, 1);
-dec_account:
-	account_locked_vm(mm, 1, false);
 free_ppage:
 	kfree(ppage);
 
diff --git a/arch/arm64/kvm/pkvm.c b/arch/arm64/kvm/pkvm.c
index ea2ea5b..2df50ec 100644
--- a/arch/arm64/kvm/pkvm.c
+++ b/arch/arm64/kvm/pkvm.c
@@ -218,7 +218,6 @@ int pkvm_create_hyp_vm(struct kvm *host_kvm)
 void pkvm_destroy_hyp_vm(struct kvm *host_kvm)
 {
 	struct kvm_pinned_page *ppage;
-	struct mm_struct *mm = current->mm;
 	struct rb_node *node;
 
 	if (host_kvm->arch.pkvm.handle) {
@@ -236,7 +235,10 @@ void pkvm_destroy_hyp_vm(struct kvm *host_kvm)
 					  page_to_pfn(ppage->page)));
 		cond_resched();
 
-	account_locked_vm(mm, 1, false);
+	/* Teardown is single-threaded, but the helper asserts mmu_lock */
+	write_lock(&host_kvm->mmu_lock);
+	pkvm_account_locked_pages(host_kvm, 1, false);
+	write_unlock(&host_kvm->mmu_lock);
 		unpin_user_pages_dirty_lock(&ppage->page, 1, true);
 		node = rb_next(node);
 		rb_erase(&ppage->node, &host_kvm->arch.pkvm.pinned_pages);
@@ -270,7 +270,6 @@ static int rb_ppage_cmp(const void *key, const struct rb_node *node)
 void pkvm_host_reclaim_page(struct kvm *host_kvm, phys_addr_t ipa)
 {
 	struct kvm_pinned_page *ppage;
-	struct mm_struct *mm = current->mm;
 	struct rb_node *node;
 
 	write_lock(&host_kvm->mmu_lock);
@@ -278,6 +277,8 @@ void pkvm_host_reclaim_page(struct kvm *host_kvm, phys_addr_t ipa)
 		       rb_ppage_cmp);
-	if (node)
+	if (node) {
 		rb_erase(node, &host_kvm->arch.pkvm.pinned_pages);
+		pkvm_account_locked_pages(host_kvm, 1, false);
+	}
 	write_unlock(&host_kvm->mmu_lock);
 
 	WARN_ON(!node);
@@ -289,7 +289,6 @@ void pkvm_host_reclaim_page(struct kvm *host_kvm, phys_addr_t ipa)
 	WARN_ON(kvm_call_hyp_nvhe(__pkvm_host_reclaim_page,
 				  page_to_pfn(ppage->page)));
 
-	account_locked_vm(mm, 1, false);
 	unpin_user_pages_dirty_lock(&ppage->page, 1, true);
 	kfree(ppage);
 }