WiP: ANDROID: drivers: virt: Add pkvm_module_ops test

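Add a KUnit test module exercising the pkvm_module_ops API: it loads an
EL2 module and calls the module ops from the host through HVC handlers.
The module breaks the pKVM security guarantees and must only be used for
testing.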

Signed-off-by: Quentin Perret <qperret@google.com>
diff --git a/drivers/virt/Kconfig b/drivers/virt/Kconfig
index d8c848c..e72f916 100644
--- a/drivers/virt/Kconfig
+++ b/drivers/virt/Kconfig
@@ -49,4 +49,6 @@
 
 source "drivers/virt/coco/Kconfig"
 
+source "drivers/virt/pkvm_kunit/Kconfig"
+
 endif
diff --git a/drivers/virt/Makefile b/drivers/virt/Makefile
index f29901b..8cb845a 100644
--- a/drivers/virt/Makefile
+++ b/drivers/virt/Makefile
@@ -10,3 +10,4 @@
 obj-$(CONFIG_NITRO_ENCLAVES)	+= nitro_enclaves/
 obj-$(CONFIG_ACRN_HSM)		+= acrn/
 obj-y				+= coco/
+obj-y				+= pkvm_kunit/
diff --git a/drivers/virt/pkvm_kunit/Kconfig b/drivers/virt/pkvm_kunit/Kconfig
new file mode 100644
index 0000000..956850b
--- /dev/null
+++ b/drivers/virt/pkvm_kunit/Kconfig
@@ -0,0 +1,26 @@
+# SPDX-License-Identifier: GPL-2.0-only
+
+menuconfig PKVM_KUNIT_TESTS
+	bool "pKVM Kunit tests"
+	depends on ARM64 && KVM && KUNIT
+	default KUNIT_ALL_TESTS
+	help
+	  Say Y here to see options for KUnit tests related to Protected KVM
+	  (pKVM) on arm64.
+
+	  If you say N, all options in this submenu will be skipped and disabled.
+
+if PKVM_KUNIT_TESTS
+
+config PKVM_MODULE_OPS_KUNIT_TEST
+	tristate "Test pKVM module_ops"
+	depends on m
+	default KUNIT_ALL_TESTS
+	help
+	  KUnit test for the pkvm_module_ops API. It loads an EL2 module and
+	  invokes the module ops from the host through it.
+
+	  This test breaks the pKVM security guarantees and must not be used
+	  outside of a testing environment. It can only be built as a module.
+
+endif
diff --git a/drivers/virt/pkvm_kunit/Makefile b/drivers/virt/pkvm_kunit/Makefile
new file mode 100644
index 0000000..d069dd3
--- /dev/null
+++ b/drivers/virt/pkvm_kunit/Makefile
@@ -0,0 +1,3 @@
+# SPDX-License-Identifier: GPL-2.0-only
+
+obj-$(CONFIG_PKVM_MODULE_OPS_KUNIT_TEST)	+= module_ops/
diff --git a/drivers/virt/pkvm_kunit/module_ops/.gitignore b/drivers/virt/pkvm_kunit/module_ops/.gitignore
new file mode 100644
index 0000000..899547d8
--- /dev/null
+++ b/drivers/virt/pkvm_kunit/module_ops/.gitignore
@@ -0,0 +1,3 @@
+# SPDX-License-Identifier: GPL-2.0-only
+hyp.lds
+hyp-reloc.S
diff --git a/drivers/virt/pkvm_kunit/module_ops/Makefile b/drivers/virt/pkvm_kunit/module_ops/Makefile
new file mode 100644
index 0000000..c9e91e4
--- /dev/null
+++ b/drivers/virt/pkvm_kunit/module_ops/Makefile
@@ -0,0 +1,12 @@
+# SPDX-License-Identifier: GPL-2.0-only
+
+obj-m += pkvm_module_ops_kunit.o
+
+# The EL2 object is built by a recursive make in hyp/ (pKVM module rules)
+# and linked into the same kernel module as the host code.
+$(obj)/hyp/kvm_nvhe.o: FORCE
+	$(Q)$(MAKE) $(build)=$(obj)/hyp $(obj)/hyp/kvm_nvhe.o
+
+clean-files := hyp/hyp.lds hyp/hyp-reloc.S
+
+pkvm_module_ops_kunit-y := module_ops-host.o hyp/kvm_nvhe.o
diff --git a/drivers/virt/pkvm_kunit/module_ops/hyp/Makefile b/drivers/virt/pkvm_kunit/module_ops/hyp/Makefile
new file mode 100644
index 0000000..b20b7f5
--- /dev/null
+++ b/drivers/virt/pkvm_kunit/module_ops/hyp/Makefile
@@ -0,0 +1,5 @@
+# SPDX-License-Identifier: GPL-2.0-only
+
+# EL2 (nVHE) half of the module_ops test.
+hyp-obj-y := module_ops-hyp.o
+include $(srctree)/arch/arm64/kvm/hyp/nvhe/Makefile.module
diff --git a/drivers/virt/pkvm_kunit/module_ops/hyp/module_ops-hyp.c b/drivers/virt/pkvm_kunit/module_ops/hyp/module_ops-hyp.c
new file mode 100644
index 0000000..b0282ae
--- /dev/null
+++ b/drivers/virt/pkvm_kunit/module_ops/hyp/module_ops-hyp.c
@@ -0,0 +1,128 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * EL2 side of the pKVM module_ops KUnit test. The host drives it through
+ * the do_mod_op() and do_cmd() HVC handlers.
+ */
+
+#include <asm/kvm_hyp.h>
+#include <asm/kvm_pkvm_module.h>
+
+#include "module_ops.h"
+
+/* Module ops table passed to module_ops_hyp_init(). */
+static const struct pkvm_module_ops *ops;
+
+/* Prototype used to call any module op with up to six u64 arguments. */
+typedef u64 (*module_op_cb_t)(u64, u64, u64, u64, u64, u64);
+
+/*
+ * Call the module op at byte offset r[1] in struct pkvm_module_ops with
+ * r[2]..r[7] as arguments. On success r[0] is SMCCC_RET_SUCCESS and r[1]
+ * holds the op's return value; misaligned or out-of-range offsets are
+ * rejected with SMCCC_RET_INVALID_PARAMETER.
+ *
+ * This treats struct pkvm_module_ops as an array of function pointers and
+ * calls through a cast prototype, hence __nocfi.
+ */
+void __nocfi do_mod_op(struct user_pt_regs *regs)
+{
+	const module_op_cb_t *raw_ops_arr = (const module_op_cb_t *)ops;
+	module_op_cb_t raw_op;
+	u64 *r = regs->regs;
+
+	if (r[1] >= sizeof(struct pkvm_module_ops) || r[1] % sizeof(module_op_cb_t)) {
+		r[0] = SMCCC_RET_INVALID_PARAMETER;
+		return;
+	}
+
+	raw_op = raw_ops_arr[r[1] / sizeof(module_op_cb_t)];
+	r[1] = raw_op(r[2], r[3], r[4], r[5], r[6], r[7]);
+	r[0] = SMCCC_RET_SUCCESS;
+}
+
+/*
+ * Last host permission fault reported to this module. Not static: the host
+ * reads it as __kvm_nvhe_fault_desc.
+ */
+struct perm_fault_desc fault_desc;
+
+/*
+ * Host stage-2 permission fault handler: record the faulting address and ESR,
+ * then step the host over the faulting instruction (ELR_EL2 += 4).
+ */
+static int handle_perm_fault(struct user_pt_regs *regs, u64 esr, u64 addr)
+{
+	fault_desc.addr = addr;
+	fault_desc.esr = esr;
+	write_sysreg_el2(read_sysreg_el2(SYS_ELR) + 4, SYS_ELR);
+	return 0;
+}
+
+/* CMD_SET_VAL: store @val at hyp VA @ptr. */
+static int __do_set_val(u64 *ptr, u64 val)
+{
+	WRITE_ONCE(*ptr, val);
+	return 0;
+}
+
+/* CMD_GET_VAL: load the u64 at hyp VA @ptr. */
+static u64 __do_get_val(u64 *ptr)
+{
+	return READ_ONCE(*ptr);
+}
+
+/*
+ * Run the test command in r[1] (enum PKVM_MOD_OP_CMD) with the arguments in
+ * r[2].. and return its result in r[1]; unknown commands set r[0] to
+ * SMCCC_RET_NOT_SUPPORTED. The hyp VAs are taken from the host as-is, with
+ * no validation.
+ */
+void __nocfi do_cmd(struct user_pt_regs *regs)
+{
+	u64 *r = regs->regs;
+
+	r[0] = SMCCC_RET_SUCCESS;
+	switch (r[1]) {
+	case CMD_SET_VAL:
+		r[1] = __do_set_val((u64 *)r[2], r[3]);
+		return;
+	case CMD_GET_VAL:
+		r[1] = __do_get_val((u64 *)r[2]);
+		return;
+	default:
+		r[0] = SMCCC_RET_NOT_SUPPORTED;
+	}
+}
+
+/*
+ * Characters received by dummy_putc(). Not static: exported to
+ * the host as __kvm_nvhe_serial_ring_buf and __kvm_nvhe_serial_idx.
+ */
+char serial_ring_buf[SERIAL_RING_BUF_LEN];
+u64 serial_idx;
+
+/* Serial driver callback: append @c to serial_ring_buf, wrapping around. */
+static void dummy_putc(char c)
+{
+	serial_ring_buf[serial_idx] = c;
+	serial_idx = (serial_idx + 1) % SERIAL_RING_BUF_LEN;
+}
+
+/*
+ * Module init, called with the hypervisor's module_ops table: install the
+ * host permission fault handler and the serial driver used by the tests.
+ * Registration errors are returned to the caller rather than ignored, so
+ * the tests do not run without their EL2 hooks.
+ */
+int module_ops_hyp_init(const struct pkvm_module_ops *__ops)
+{
+	int ret;
+
+	ops = __ops;
+
+	ret = ops->register_host_perm_fault_handler(handle_perm_fault);
+	if (ret)
+		return ret;
+
+	return ops->register_serial_driver(dummy_putc);
+}
diff --git a/drivers/virt/pkvm_kunit/module_ops/hyp/module_ops.h b/drivers/virt/pkvm_kunit/module_ops/hyp/module_ops.h
new file mode 100644
index 0000000..f827761
--- /dev/null
+++ b/drivers/virt/pkvm_kunit/module_ops/hyp/module_ops.h
@@ -0,0 +1,43 @@
+/* SPDX-License-Identifier: GPL-2.0 */
+/*
+ * Definitions shared by the host and EL2 halves of the pKVM module_ops
+ * KUnit test.
+ */
+#ifndef PKVM_KUNIT_MODULE_OPS_H
+#define PKVM_KUNIT_MODULE_OPS_H
+
+#include <linux/types.h>
+
+struct pkvm_module_ops;
+struct user_pt_regs;
+
+/* Commands handled by do_cmd(), with their HVC arguments. */
+enum PKVM_MOD_OP_CMD {
+	CMD_SET_VAL,	/* u64 *hyp_va, u64 val */
+	CMD_GET_VAL,	/* u64 *hyp_va */
+};
+
+/* Address and ESR of a host permission fault, as recorded at EL2. */
+struct perm_fault_desc {
+	u64 addr;
+	u64 esr;
+};
+
+/* Size, in bytes, of the EL2 serial ring buffer. */
+#define SERIAL_RING_BUF_LEN 128
+
+#ifdef __KVM_NVHE_HYPERVISOR__
+int module_ops_hyp_init(const struct pkvm_module_ops *ops);
+void do_mod_op(struct user_pt_regs *regs);
+void do_cmd(struct user_pt_regs *regs);
+#else
+/* The same EL2 symbols, as seen from the host (__kvm_nvhe_ prefix). */
+int __kvm_nvhe_module_ops_hyp_init(const struct pkvm_module_ops *ops);
+void __kvm_nvhe_do_mod_op(struct user_pt_regs *regs);
+void __kvm_nvhe_do_cmd(struct user_pt_regs *regs);
+extern struct perm_fault_desc __kvm_nvhe_fault_desc;
+extern char __kvm_nvhe_serial_ring_buf[SERIAL_RING_BUF_LEN];
+extern u64 __kvm_nvhe_serial_idx;
+#endif
+
+#endif /* PKVM_KUNIT_MODULE_OPS_H */
diff --git a/drivers/virt/pkvm_kunit/module_ops/module_ops-host.c b/drivers/virt/pkvm_kunit/module_ops/module_ops-host.c
new file mode 100644
index 0000000..cd4985b
--- /dev/null
+++ b/drivers/virt/pkvm_kunit/module_ops/module_ops-host.c
@@ -0,0 +1,365 @@
+// SPDX-License-Identifier: GPL-2.0
+/*
+ * Host side of the pKVM module_ops KUnit test: loads the EL2 half
+ * (hyp/module_ops-hyp.c) and drives struct pkvm_module_ops through it.
+ */
+
+#include <asm/esr.h>
+#include <asm/kvm_pkvm_module.h>
+#include <asm/word-at-a-time.h>
+#include <kunit/test.h>
+#include <linux/init.h>
+#include <linux/module.h>
+#include <linux/kernel.h>
+
+#include "hyp/module_ops.h"
+
+#ifndef MODULE
+#error "pKVM Kunit module_ops must be compiled as a module"
+#endif
+
+/* HVC handles of the EL2 do_mod_op() and do_cmd() handlers. */
+static int __mod_op_hc, __cmd_hc;
+/* Hyp VAs of the EL2 fault_desc and serial_ring_buf variables. */
+static u64 __hyp_fault_va, __hyp_ring_buf_va;
+
+/* Call pkvm_module_ops.@name at EL2, forwarding the remaining arguments. */
+#define HYP_MOD_OP(name, ...)	pkvm_el2_mod_call(__mod_op_hc,					\
+						  offsetof(struct pkvm_module_ops, name),	\
+						  ##__VA_ARGS__)
+/* Run the test command @name (enum PKVM_MOD_OP_CMD) at EL2. */
+#define HYP_CMD(name, ...)	pkvm_el2_mod_call(__cmd_hc, name, ##__VA_ARGS__)
+
+/* Map a host page through the hyp fixmap and check that hyp writes reach it. */
+static void module_ops_fixmap_test(struct kunit *test)
+{
+	u64 pa, hyp_va, val, randval;
+	u64 *ptr;
+
+	ptr = (u64 *)__get_free_page(GFP_KERNEL);
+	KUNIT_ASSERT_TRUE(test, ptr);
+	pa = virt_to_phys(ptr);
+
+	/* Hyp write through fixmap */
+	randval = ktime_get_ns();
+	hyp_va = HYP_MOD_OP(fixmap_map, pa);
+	HYP_CMD(CMD_SET_VAL, hyp_va, randval);
+	val = READ_ONCE(*ptr);
+	KUNIT_EXPECT_EQ(test, val, randval);
+	HYP_MOD_OP(fixmap_unmap);
+
+	free_page((unsigned long)ptr);
+}
+
+/* Copy the EL2 fault_desc into @fault, one u64 at a time. */
+static void copy_hyp_fault_desc(struct perm_fault_desc *fault)
+{
+	u64 *dst = (u64 *)fault, *src = (u64 *)__hyp_fault_va;
+
+	for (size_t off = 0; off < sizeof(*fault); dst++, src++, off += sizeof(u64))
+		*dst = HYP_CMD(CMD_GET_VAL, src);
+}
+
+/* Zero the EL2 fault_desc so that a stale fault is not reported again. */
+static void clear_hyp_fault_desc(void)
+{
+	u64 *dst = (u64 *)__hyp_fault_va;
+
+	for (size_t off = 0; off < sizeof(struct perm_fault_desc); dst++, off += sizeof(u64))
+		HYP_CMD(CMD_SET_VAL, dst, 0ULL);
+}
+
+/*
+ * Restrict the host stage-2 permissions of a page with host_stage2_mod_prot(),
+ * check that disallowed host accesses reach the EL2 permission fault handler,
+ * and that ownership transitions of the restricted page are refused.
+ */
+static void module_ops_host_stage2_mod_prot_test(struct kunit *test)
+{
+	u64 pfn, pa, val, randval;
+	struct perm_fault_desc fault;
+	u64 *ptr;
+	int ret;
+
+	ptr = (u64 *)__get_free_page(GFP_KERNEL);
+	KUNIT_ASSERT_TRUE(test, ptr);
+	pfn = virt_to_pfn(ptr);
+	pa = virt_to_phys(ptr);
+
+	randval = ktime_get_ns();
+	*ptr = randval;
+
+	/* Read from RO */
+	clear_hyp_fault_desc();
+	ret = HYP_MOD_OP(host_stage2_mod_prot, pfn, KVM_PGTABLE_PROT_R, 1);
+	KUNIT_EXPECT_FALSE(test, ret);
+	val = READ_ONCE(*ptr);
+	KUNIT_EXPECT_EQ(test, val, randval);
+	copy_hyp_fault_desc(&fault);
+	KUNIT_EXPECT_FALSE(test, fault.esr);
+
+	/* Write to RO: must fault, the EL2 handler skips the store */
+	randval = ktime_get_ns();
+	WRITE_ONCE(*ptr, randval);
+	copy_hyp_fault_desc(&fault);
+	KUNIT_EXPECT_EQ(test, fault.addr, pa);
+	KUNIT_EXPECT_TRUE(test, esr_fsc_is_permission_fault(fault.esr));
+	clear_hyp_fault_desc();
+
+	/* Read and write from/to RW */
+	ret = HYP_MOD_OP(host_stage2_mod_prot, pfn, KVM_PGTABLE_PROT_R | KVM_PGTABLE_PROT_W, 1);
+	KUNIT_EXPECT_FALSE(test, ret);
+	WRITE_ONCE(*ptr, randval);
+	copy_hyp_fault_desc(&fault);
+	KUNIT_EXPECT_FALSE(test, fault.esr);
+
+	val = READ_ONCE(*ptr);
+	copy_hyp_fault_desc(&fault);
+	KUNIT_EXPECT_FALSE(test, fault.esr);
+	KUNIT_EXPECT_EQ(test, val, randval);
+
+	/* Read and write from/to X */
+	ret = HYP_MOD_OP(host_stage2_mod_prot, pfn, KVM_PGTABLE_PROT_X, 1);
+	KUNIT_EXPECT_FALSE(test, ret);
+	WRITE_ONCE(*(ptr + 1), randval);
+	copy_hyp_fault_desc(&fault);
+	KUNIT_EXPECT_EQ(test, fault.addr, pa + 8);
+	KUNIT_EXPECT_TRUE(test, esr_fsc_is_permission_fault(fault.esr));
+	clear_hyp_fault_desc();
+
+	val = READ_ONCE(*(ptr + 2));
+	copy_hyp_fault_desc(&fault);
+	KUNIT_EXPECT_EQ(test, fault.addr, pa + 16);
+	KUNIT_EXPECT_TRUE(test, esr_fsc_is_permission_fault(fault.esr));
+	clear_hyp_fault_desc();
+
+	/* Transitions from a MODULE_OWNED state must be blocked */
+	ret = HYP_MOD_OP(host_donate_hyp, pfn, 1);
+	KUNIT_EXPECT_EQ(test, ret, -EPERM);
+	ret = HYP_MOD_OP(hyp_donate_host, pfn, 1);
+	KUNIT_EXPECT_EQ(test, ret, -EPERM);
+	ret = HYP_MOD_OP(host_share_hyp, pfn);
+	KUNIT_EXPECT_EQ(test, ret, -EPERM);
+	ret = HYP_MOD_OP(host_unshare_hyp, pfn);
+	KUNIT_EXPECT_EQ(test, ret, -EPERM);
+
+	/*
+	 * Restore to a pristine state. If that fails, leak the page rather than
+	 * hand it back to the page allocator with restricted permissions.
+	 */
+	ret = HYP_MOD_OP(host_stage2_mod_prot, pfn, KVM_PGTABLE_PROT_RWX, 1);
+	KUNIT_ASSERT_FALSE(test, ret);
+
+	free_page((unsigned long)ptr);
+}
+
+/*
+ * Donate a page to hyp, check that the module cannot change the host stage-2
+ * permissions of a hyp-owned page, then give the page back to the host.
+ */
+static void module_ops_host_donate_hyp_test(struct kunit *test)
+{
+	u64 *ptr;
+	u64 pfn;
+	int ret;
+
+	ptr = (u64 *)__get_free_page(GFP_KERNEL);
+	KUNIT_ASSERT_TRUE(test, ptr);
+	pfn = virt_to_pfn(ptr);
+
+	/* No module-initiated transitions on hyp-owned pages */
+	ret = HYP_MOD_OP(host_donate_hyp, pfn, 1);
+	KUNIT_EXPECT_FALSE(test, ret);
+	ret = HYP_MOD_OP(host_stage2_mod_prot, pfn, KVM_PGTABLE_PROT_R, 1);
+	KUNIT_EXPECT_EQ(test, ret, -EPERM);
+	ret = HYP_MOD_OP(host_stage2_mod_prot, pfn, KVM_PGTABLE_PROT_RWX, 1);
+	KUNIT_EXPECT_EQ(test, ret, -EPERM);
+
+	/* Leak the page rather than free it while it may still be owned by hyp */
+	ret = HYP_MOD_OP(hyp_donate_host, pfn, 1);
+	KUNIT_ASSERT_FALSE(test, ret);
+
+	free_page((unsigned long)ptr);
+}
+
+/*
+ * Share a page with hyp, check that the module cannot change its host stage-2
+ * permissions and that it cannot be unshared while pinned, then unshare it.
+ */
+static void module_ops_host_share_hyp_test(struct kunit *test)
+{
+	u64 pfn, hyp_va, pa, randval;
+	u64 *ptr;
+	int ret;
+
+	ptr = (u64 *)__get_free_page(GFP_KERNEL);
+	KUNIT_ASSERT_TRUE(test, ptr);
+	pfn = virt_to_pfn(ptr);
+	pa = virt_to_phys(ptr);
+
+	/* No module-initiated transitions on hyp-borrowed pages */
+	ret = HYP_MOD_OP(host_share_hyp, pfn);
+	KUNIT_EXPECT_FALSE(test, ret);
+	ret = HYP_MOD_OP(host_stage2_mod_prot, pfn, KVM_PGTABLE_PROT_R, 1);
+	KUNIT_EXPECT_EQ(test, ret, -EPERM);
+	ret = HYP_MOD_OP(host_stage2_mod_prot, pfn, KVM_PGTABLE_PROT_RWX, 1);
+	KUNIT_EXPECT_EQ(test, ret, -EPERM);
+
+	/* Pin shared pages */
+	hyp_va = HYP_MOD_OP(hyp_va, pa);
+	ret = HYP_MOD_OP(pin_shared_mem, hyp_va, hyp_va + PAGE_SIZE);
+	KUNIT_EXPECT_FALSE(test, ret);
+	ret = HYP_MOD_OP(host_unshare_hyp, pfn);
+	KUNIT_EXPECT_EQ(test, ret, -EBUSY);
+	randval = ktime_get_ns();
+	HYP_CMD(CMD_SET_VAL, hyp_va, randval);
+	KUNIT_EXPECT_EQ(test, *ptr, randval);
+	HYP_MOD_OP(unpin_shared_mem, hyp_va, hyp_va + PAGE_SIZE);
+	ret = HYP_MOD_OP(host_unshare_hyp, pfn);
+	KUNIT_EXPECT_FALSE(test, ret);
+
+	free_page((unsigned long)ptr);
+}
+
+/*
+ * Map host pages into the module's private hyp VA range, with and without a
+ * donation, and read the host stage-2 leaf of the donated page.
+ */
+static void module_ops_private_range_test(struct kunit *test)
+{
+	u64 val, randval, hyp_va, expected;
+	u64 *ptr;
+	int ret;
+
+	/* All allocations in this test are leaked, unmapping from private range not supported */
+	ptr = (u64 *)__get_free_page(GFP_KERNEL);
+	KUNIT_ASSERT_TRUE(test, ptr);
+
+	/* Alloc + map in the private range, w/o a donation */
+	randval = ktime_get_ns();
+	*ptr = randval;
+	hyp_va = HYP_MOD_OP(alloc_module_va, 1);
+	ret = HYP_MOD_OP(map_module_page, virt_to_pfn(ptr), hyp_va, KVM_PGTABLE_PROT_RW, true);
+	KUNIT_EXPECT_FALSE(test, ret);
+
+	val = HYP_CMD(CMD_GET_VAL, hyp_va);
+	KUNIT_EXPECT_EQ(test, val, randval);
+
+	randval = ktime_get_ns();
+	WRITE_ONCE(*ptr, randval);
+	val = HYP_CMD(CMD_GET_VAL, hyp_va);
+	KUNIT_EXPECT_EQ(test, val, randval);
+
+	/* Alloc + map in the private range, w/ a donation */
+	ptr = (u64 *)__get_free_pages(GFP_KERNEL, 1); /* 2 pages to allow misaligned access below */
+	KUNIT_ASSERT_TRUE(test, ptr);
+	memset(ptr, ~0, PAGE_SIZE * 2);
+	hyp_va = HYP_MOD_OP(alloc_module_va, 1);
+	ret = HYP_MOD_OP(map_module_page, virt_to_pfn(ptr) + 1, hyp_va, KVM_PGTABLE_PROT_RW, false);
+	KUNIT_EXPECT_FALSE(test, ret);
+	val = HYP_CMD(CMD_GET_VAL, hyp_va);
+	KUNIT_EXPECT_EQ(test, val, ~0ULL);
+
+	/*
+	 * Forcefully misalign an access straddling the end of the first (host)
+	 * page and the start of the second (donated) page to trigger the extable
+	 * handling and 0 padding: only the 4 bytes from the first page are
+	 * loaded, the bytes that would come from the second page read as zero.
+	 */
+	expected = 0;
+	memcpy(&expected, (void *)ptr + PAGE_SIZE - sizeof(u32), sizeof(u32));
+	val = load_unaligned_zeropad((void *)ptr + PAGE_SIZE - sizeof(u32));
+	KUNIT_EXPECT_EQ(test, val, expected);
+
+	/* Find leaf PTE in host stage-2 for a hyp-owned page */
+	HYP_CMD(CMD_SET_VAL, hyp_va, 0x0);
+	ret = HYP_MOD_OP(host_stage2_get_leaf, virt_to_phys(ptr) + PAGE_SIZE, hyp_va, 0);
+	KUNIT_EXPECT_FALSE(test, ret);
+	val = HYP_CMD(CMD_GET_VAL, hyp_va);
+	/* PTE must be an annotation */
+	KUNIT_EXPECT_TRUE(test, val);
+	KUNIT_EXPECT_FALSE(test, val & 0x1);
+}
+
+/*
+ * Print through the hyp serial API and read the characters back from the
+ * ring buffer filled by the EL2 dummy serial driver.
+ */
+static void module_ops_serial_test(struct kunit *test)
+{
+	char str1[] = "hello", str2[] = " world!";
+	u64 buf[2] = {};	/* buf[1] stays 0 to NUL-terminate the first check */
+	u64 pa, hyp_va;
+	void *ptr;
+
+	for (size_t i = 0; i < sizeof(str1) - 1; i++)
+		HYP_MOD_OP(putc, str1[i]);
+	buf[0] = HYP_CMD(CMD_GET_VAL, __hyp_ring_buf_va);
+	KUNIT_EXPECT_FALSE(test, strcmp((char *)buf, str1));
+
+	ptr = (void *)__get_free_page(GFP_KERNEL);
+	KUNIT_ASSERT_TRUE(test, ptr);
+	pa = virt_to_phys(ptr);
+
+	memcpy(ptr, str2, sizeof(str2));
+	hyp_va = HYP_MOD_OP(fixmap_map, pa);
+	HYP_MOD_OP(puts, hyp_va);
+	HYP_MOD_OP(fixmap_unmap);
+	buf[0] = HYP_CMD(CMD_GET_VAL, __hyp_ring_buf_va);
+	buf[1] = HYP_CMD(CMD_GET_VAL, __hyp_ring_buf_va + 8);
+	/* XXX - broken: buf should now hold "hello world!", not checked yet */
+
+	free_page((unsigned long)ptr);
+}
+
+/*
+ * Suite init: load the EL2 half of the test, register its HVC handlers and
+ * look up the hyp VAs of the EL2 variables read by the test cases.
+ */
+static int module_ops_test_init(struct kunit_suite *suite)
+{
+	unsigned long token;
+	int ret;
+
+	pr_warn("pKVM module ops Kunit test: DON'T USE IN PRODUCTION!\n"
+		"This module breaks the pKVM security guarantees for \n"
+		"testing purposes and must not be used outside of a  \n"
+		"testing environment!\n");
+
+	ret = pkvm_load_el2_module(__kvm_nvhe_module_ops_hyp_init, &token);
+	if (ret)
+		return ret;
+
+	__mod_op_hc = pkvm_register_el2_mod_call(__kvm_nvhe_do_mod_op, token);
+	if (__mod_op_hc < 0)
+		return __mod_op_hc;
+
+	__cmd_hc = pkvm_register_el2_mod_call(__kvm_nvhe_do_cmd, token);
+	if (__cmd_hc < 0)
+		return __cmd_hc;
+
+	__hyp_fault_va = pkvm_el2_mod_va(&__kvm_nvhe_fault_desc, token);
+	__hyp_ring_buf_va = pkvm_el2_mod_va(&__kvm_nvhe_serial_ring_buf, token);
+
+	return 0;
+}
+
+static struct kunit_case module_ops_test_cases[] = {
+	KUNIT_CASE(module_ops_fixmap_test),
+	KUNIT_CASE(module_ops_host_stage2_mod_prot_test),
+	KUNIT_CASE(module_ops_host_donate_hyp_test),
+	KUNIT_CASE(module_ops_host_share_hyp_test),
+	KUNIT_CASE(module_ops_private_range_test),
+	KUNIT_CASE(module_ops_serial_test),
+	{}
+};
+
+static struct kunit_suite module_ops_test_suite = {
+	.name = "pkvm_module_ops",
+	.test_cases = module_ops_test_cases,
+	.suite_init = module_ops_test_init,
+};
+kunit_test_suite(module_ops_test_suite);
+
+MODULE_DESCRIPTION("pKVM module_ops Kunit test");
+MODULE_LICENSE("GPL v2");