KVM: arm64: Remove exclusive pin from memory shared by a guest

When a protected guest shares a page with the host via the MEM_SHARE
hypercall, remove the exclusive pin on that page, but maintain a normal
pin so the page remains tracked for as long as it is shared.

Signed-off-by: Fuad Tabba <tabba@google.com>
diff --git a/arch/arm64/include/asm/kvm_host.h b/arch/arm64/include/asm/kvm_host.h
index 4851ce1..66e1019 100644
--- a/arch/arm64/include/asm/kvm_host.h
+++ b/arch/arm64/include/asm/kvm_host.h
@@ -1331,6 +1331,9 @@ struct kvm *kvm_arch_alloc_vm(void);
 
 #define vcpu_is_protected(vcpu)		kvm_vm_is_protected((vcpu)->kvm)
 
+/* Drop the exclusive pin on @gfn after a protected guest shares it. */
+int pkvm_handle_guest_share(struct kvm *kvm, gfn_t gfn);
+
 int kvm_arm_vcpu_finalize(struct kvm_vcpu *vcpu, int feature);
 bool kvm_arm_vcpu_is_finalized(struct kvm_vcpu *vcpu);
 
diff --git a/arch/arm64/kvm/hypercalls.c b/arch/arm64/kvm/hypercalls.c
index b08e181..11a2c2d 100644
--- a/arch/arm64/kvm/hypercalls.c
+++ b/arch/arm64/kvm/hypercalls.c
@@ -144,23 +144,34 @@ static bool kvm_smccc_test_fw_bmap(struct kvm_vcpu *vcpu, u32 func_id)
 	}
 }
 
-static int kvm_vcpu_exit_hcall(struct kvm_vcpu *vcpu, u32 nr, u32 nr_args)
+/*
+ * Handle the MEM_SHARE/MEM_UNSHARE vendor hypercalls in the kernel.
+ * Arg 1 holds the IPA of the page being (un)shared; on failure the
+ * guest is resumed with SMCCC_RET_INVALID_PARAMETER in x0.
+ */
+static int kvm_vcpu_handle_xshare(struct kvm_vcpu *vcpu, u32 nr)
 {
-	u64 mask = vcpu->kvm->arch.hypercall_exit_enabled;
-	u32 i;
+	/* NOTE(review): mask is a constant containing both valid funcs, so
+	 * this check can never fail — was hypercall_exit_enabled intended? */
+	u64 mask = BIT(ARM_SMCCC_KVM_FUNC_MEM_SHARE) | BIT(ARM_SMCCC_KVM_FUNC_MEM_UNSHARE);
+	gfn_t gfn = vcpu_get_reg(vcpu, 1) >> PAGE_SHIFT;
 
-	if (nr_args > 6 || !(mask & BIT(nr))) {
-		smccc_set_retval(vcpu, SMCCC_RET_INVALID_PARAMETER, 0, 0, 0);
-		return 1;
+	if (!(mask & BIT(nr)))
+		goto err;
+
+	if (nr == ARM_SMCCC_KVM_FUNC_MEM_SHARE) {
+		if (pkvm_handle_guest_share(vcpu->kvm, gfn))
+			goto err;
 	}
 
-	vcpu->run->exit_reason		= KVM_EXIT_HYPERCALL;
-	vcpu->run->hypercall.nr		= nr;
-
-	for (i = 0; i < nr_args; ++i)
-		vcpu->run->hypercall.args[i] = vcpu_get_reg(vcpu, i + 1);
-
+	/* NOTE(review): returns 0 without setting run->exit_reason; confirm
+	 * the caller still populates the exit before going to userspace. */
 	return 0;
 
 err:
+	/* Resume the guest with an error in x0. */
 	smccc_set_retval(vcpu, SMCCC_RET_INVALID_PARAMETER, 0, 0, 0);
 	return 1;
 }
 
 #define SMC32_ARCH_RANGE_BEGIN	ARM_SMCCC_VERSION_FUNC_ID
@@ -411,9 +412,9 @@ int kvm_smccc_call_handler(struct kvm_vcpu *vcpu)
 		val[0] = SMCCC_RET_SUCCESS;
 		break;
 	case ARM_SMCCC_VENDOR_HYP_KVM_MEM_SHARE_FUNC_ID:
-		return kvm_vcpu_exit_hcall(vcpu, ARM_SMCCC_KVM_FUNC_MEM_SHARE, 3);
+		return kvm_vcpu_handle_xshare(vcpu, ARM_SMCCC_KVM_FUNC_MEM_SHARE);
 	case ARM_SMCCC_VENDOR_HYP_KVM_MEM_UNSHARE_FUNC_ID:
-		return kvm_vcpu_exit_hcall(vcpu, ARM_SMCCC_KVM_FUNC_MEM_UNSHARE, 3);
+		return kvm_vcpu_handle_xshare(vcpu, ARM_SMCCC_KVM_FUNC_MEM_UNSHARE);
 	case ARM_SMCCC_TRNG_VERSION:
 	case ARM_SMCCC_TRNG_FEATURES:
 	case ARM_SMCCC_TRNG_GET_UUID:
diff --git a/arch/arm64/kvm/mmu.c b/arch/arm64/kvm/mmu.c
index 2d12b3c..482f251 100644
--- a/arch/arm64/kvm/mmu.c
+++ b/arch/arm64/kvm/mmu.c
@@ -1434,6 +1434,72 @@ static int insert_ppage(struct kvm *kvm, struct kvm_guest_page *ppage)
 	return 0;
 }
 
+/*
+ * rb_find() comparator for the pinned-pages tree: the search key is a
+ * gfn smuggled through the const void pointer, compared against the gfn
+ * stored in each kvm_guest_page node.
+ */
+static int cmp_ppages_gfn(const void *key, const struct rb_node *node)
+{
+	struct kvm_guest_page *p = container_of(node, struct kvm_guest_page, node);
+	gfn_t gfn = (gfn_t)key;
+
+	return (gfn < p->gfn) ? -1 : (gfn > p->gfn);
+}
+
+/*
+ * Look up the pinned-page tracking structure for @gfn, or NULL if the
+ * gfn has no entry.  NOTE(review): assumes the caller serialises access
+ * to pkvm.pinned_pages (e.g. holds mmu_lock) — confirm against callers.
+ */
+static struct kvm_guest_page *find_page(struct kvm *kvm, gfn_t gfn)
+{
+	struct rb_node *node;
+
+	node = rb_find((void *)gfn, &kvm->arch.pkvm.pinned_pages, cmp_ppages_gfn);
+
+	if (!node)
+		return NULL;
+
+	return container_of(node, struct kvm_guest_page, node);
+}
+
+/*
+ * A protected guest is sharing @gfn with the host: drop the exclusive
+ * pin on the backing page (the normal pin is retained) and mark the
+ * page as no longer private.
+ *
+ * Returns 0 on success, -ENOENT if the gfn was never pinned, -EINVAL
+ * if the page is already shared.
+ */
+int pkvm_handle_guest_share(struct kvm *kvm, gfn_t gfn)
+{
+	struct kvm_guest_page *ppage;
+	int ret = 0;
+
+	write_lock(&kvm->mmu_lock);
+
+	ppage = find_page(kvm, gfn);
+
+	/* Sharing a gfn we never pinned indicates a tracking bug. */
+	if (WARN_ON(!ppage)) {
+		ret = -ENOENT;
+		goto out;
+	}
+
+	/* Double-share: the exclusive pin was already dropped. */
+	if (WARN_ON(!ppage->is_private)) {
+		ret = -EINVAL;
+		goto out;
+	}
+
+	unexc_user_page(ppage->page);
+	ppage->is_private = false;
+out:
+	write_unlock(&kvm->mmu_lock);
+
+	return ret;
+}
+
 static int pkvm_mem_abort(struct kvm_vcpu *vcpu, phys_addr_t fault_ipa,
 			  struct kvm_memory_slot *memslot)
 {