arm64: ioremap/iounmap: use ranges for MMIO guard hypercalls

The hypervisor can accept a range for the MMIO guard hypercalls,
instead of a single granule per call. This is intended to reduce the
back and forth with the hypervisor when mapping or unmapping large
regions.

The feature is advertised via the KVM_FUNC_HAS_RANGE flag returned by
the MMIO_GUARD_INFO HVC.
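
With the range extension, the calling convention implemented here is:

  MMIO_GUARD_MAP:   x1 = PA, x2 = prot, x3 = size
  MMIO_GUARD_UNMAP: x1 = PA, x2 = size

where size is passed as 0 when the hypervisor does not support ranges.
On success, the number of bytes effectively handled is returned in a1
and may be smaller than requested, in which case the caller retries
with the remainder of the range.

Since ranges of arbitrary size must now be refcounted, the per-granule
xarray is replaced with a maple tree storing a refcount per registered
range.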

Signed-off-by: Vincent Donnefort <vdonnefort@google.com>
diff --git a/arch/arm64/mm/ioremap.c b/arch/arm64/mm/ioremap.c
index c1c04567..998f83c 100644
--- a/arch/arm64/mm/ioremap.c
+++ b/arch/arm64/mm/ioremap.c
@@ -2,6 +2,7 @@
 
 #define pr_fmt(fmt)	"ioremap: " fmt
 
+#include <linux/maple_tree.h>
 #include <linux/mm.h>
 #include <linux/slab.h>
 #include <linux/io.h>
@@ -12,7 +13,7 @@
 #ifndef ARM_SMCCC_KVM_FUNC_MMIO_GUARD_INFO
 #define ARM_SMCCC_KVM_FUNC_MMIO_GUARD_INFO	5
 
-#define ARM_SMCCC_VENDOR_HYP_KVM_MMIO_GUARD_INFO_FUNC_ID			\
+#define ARM_SMCCC_VENDOR_HYP_KVM_MMIO_GUARD_INFO_FUNC_ID		\
 	ARM_SMCCC_CALL_VAL(ARM_SMCCC_FAST_CALL,				\
 			   ARM_SMCCC_SMC_64,				\
 			   ARM_SMCCC_OWNER_VENDOR_HYP,			\
@@ -22,7 +23,7 @@
 #ifndef ARM_SMCCC_KVM_FUNC_MMIO_GUARD_ENROLL
 #define ARM_SMCCC_KVM_FUNC_MMIO_GUARD_ENROLL	6
 
-#define ARM_SMCCC_VENDOR_HYP_KVM_MMIO_GUARD_ENROLL_FUNC_ID			\
+#define ARM_SMCCC_VENDOR_HYP_KVM_MMIO_GUARD_ENROLL_FUNC_ID		\
 	ARM_SMCCC_CALL_VAL(ARM_SMCCC_FAST_CALL,				\
 			   ARM_SMCCC_SMC_64,				\
 			   ARM_SMCCC_OWNER_VENDOR_HYP,			\
@@ -42,24 +43,22 @@
 #ifndef ARM_SMCCC_KVM_FUNC_MMIO_GUARD_UNMAP
 #define ARM_SMCCC_KVM_FUNC_MMIO_GUARD_UNMAP	8
 
-#define ARM_SMCCC_VENDOR_HYP_KVM_MMIO_GUARD_UNMAP_FUNC_ID			\
+#define ARM_SMCCC_VENDOR_HYP_KVM_MMIO_GUARD_UNMAP_FUNC_ID		\
 	ARM_SMCCC_CALL_VAL(ARM_SMCCC_FAST_CALL,				\
 			   ARM_SMCCC_SMC_64,				\
 			   ARM_SMCCC_OWNER_VENDOR_HYP,			\
 			   ARM_SMCCC_KVM_FUNC_MMIO_GUARD_UNMAP)
 #endif	/* ARM_SMCCC_KVM_FUNC_MMIO_GUARD_UNMAP */
 
-struct ioremap_guard_ref {
-	refcount_t	count;
-};
-
 static DEFINE_STATIC_KEY_FALSE(ioremap_guard_key);
-static DEFINE_XARRAY(ioremap_guard_array);
+static DEFINE_MTREE(ioremap_guard_refcount);
 static DEFINE_MUTEX(ioremap_guard_lock);
 
 static size_t guard_granule;
+static bool guard_has_range;
 
 static bool ioremap_guard;
+
 static int __init ioremap_guard_setup(char *str)
 {
 	ioremap_guard = true;
@@ -93,6 +92,8 @@
 		return;
 	}
 
+	guard_has_range = !!(res.a1 & KVM_FUNC_HAS_RANGE);
+
 	arm_smccc_1_1_invoke(ARM_SMCCC_VENDOR_HYP_KVM_MMIO_GUARD_ENROLL_FUNC_ID,
 			     &res);
 	if (res.a0 == SMCCC_RET_SUCCESS) {
@@ -104,31 +105,134 @@
 	}
 }
 
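+/*
+ * Issue a single MAP or UNMAP hypercall. When the hypervisor supports
+ * ranges, @size is passed as an extra argument and the number of bytes
+ * effectively handled is reported back through @done; otherwise the
+ * hypercall covers a single granule.
+ */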
+static int __invoke_mmioguard(phys_addr_t phys_addr, size_t size,
+			      unsigned long prot, u32 func_id, size_t *done)
+{
+	u64 arg2, arg3 = 0, arg_size = guard_has_range ? size : 0;
+	struct arm_smccc_res res;
+
+	switch (func_id) {
+	case ARM_SMCCC_VENDOR_HYP_KVM_MMIO_GUARD_MAP_FUNC_ID:
+		arg2 = prot;
+		arg3 = arg_size;
+		break;
+	case ARM_SMCCC_VENDOR_HYP_KVM_MMIO_GUARD_UNMAP_FUNC_ID:
+		arg2 = arg_size;
+		break;
+	default:
+		return -EINVAL;
+	}
+
+	arm_smccc_1_1_hvc(func_id, phys_addr, arg2, arg3, &res);
+	if (res.a0 != SMCCC_RET_SUCCESS)
+		return -EINVAL;
+
+	*done = guard_has_range ? res.a1 : guard_granule;
+
+	return 0;
+}
+
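+/*
+ * Repeatedly invoke @func_id until [phys_addr, phys_addr + size) is
+ * entirely handled or an error occurs. Returns the number of bytes
+ * successfully handled.
+ */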
+static size_t __do_xregister_phys_range(phys_addr_t phys_addr, size_t size,
+					unsigned long prot, u32 func_id)
+{
+	size_t done = 0, __done;
+	int ret;
+
+	while (size) {
+		ret = __invoke_mmioguard(phys_addr, size, prot, func_id, &__done);
+		if (ret)
+			break;
+
+		if (WARN_ON(!__done || __done > size))
+			break;
+
+		done += __done;
+		phys_addr += __done;
+		size -= __done;
+	}
+
+	return done;
+}
+
+static size_t __do_register_phys_range(phys_addr_t phys_addr, size_t size,
+				       unsigned long prot)
+{
+	return __do_xregister_phys_range(phys_addr, size, prot,
+					 ARM_SMCCC_VENDOR_HYP_KVM_MMIO_GUARD_MAP_FUNC_ID);
+}
+
+static size_t __do_unregister_phys_range(phys_addr_t phys_addr, size_t size)
+{
+	return __do_xregister_phys_range(phys_addr, size, 0,
+					 ARM_SMCCC_VENDOR_HYP_KVM_MMIO_GUARD_UNMAP_FUNC_ID);
+}
+
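+/* Unregister a granule-aligned range from the MMIO guard. */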
+static int ioremap_unregister_phys_range(phys_addr_t phys_addr, size_t size)
+{
+	size_t unregistered;
+
+	if (size % guard_granule)
+		return -ERANGE;
+
+	unregistered = __do_unregister_phys_range(phys_addr, size);
+
+	return unregistered == size ? 0 : -EINVAL;
+}
+
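+/*
+ * Register a granule-aligned range with the MMIO guard. On partial
+ * failure, roll back what has already been registered.
+ */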
+static int ioremap_register_phys_range(phys_addr_t phys_addr, size_t size, pgprot_t prot)
+{
+	size_t registered;
+
+	if (size % guard_granule)
+		return -ERANGE;
+
+	registered = __do_register_phys_range(phys_addr, size, prot.pgprot);
+	if (registered != size) {
+		pr_err("Failed to register 0x%llx:0x%llx\n", phys_addr, phys_addr + size);
+		WARN_ON(ioremap_unregister_phys_range(phys_addr, registered));
+		return -EINVAL;
+	}
+
+	return 0;
+}
+
 void ioremap_phys_range_hook(phys_addr_t phys_addr, size_t size, pgprot_t prot)
 {
-	int guard_shift;
 
 	if (!static_branch_unlikely(&ioremap_guard_key))
 		return;
 
-	guard_shift = __builtin_ctzl(guard_granule);
-
 	mutex_lock(&ioremap_guard_lock);
 
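+	/*
+	 * The maple tree stores a refcount for each registered range. For
+	 * every chunk of [phys_addr, phys_addr + size), either take a
+	 * reference on an overlapping entry or register the gap with the
+	 * hypervisor and insert a new entry.
+	 */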
 	while (size) {
-		u64 guard_fn = phys_addr >> guard_shift;
-		struct ioremap_guard_ref *ref;
-		struct arm_smccc_res res;
+		MA_STATE(mas, &ioremap_guard_refcount, phys_addr, ULONG_MAX);
+		void *entry = mas_find(&mas, phys_addr + size - 1);
+		size_t sub_size = size;
+		int ret;
 
-		if (pfn_valid(__phys_to_pfn(phys_addr)))
-			goto next;
+		if (entry) {
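+			/*
+			 * If the entry covers the start of the range, take
+			 * an extra reference on the overlap; the maple tree
+			 * splits the entry as needed. Otherwise, register
+			 * the gap below the entry first.
+			 */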
+			if (mas.index <= phys_addr) {
+				sub_size = min((unsigned long)size,
+					       mas.last + 1 - (unsigned long)phys_addr);
 
-		ref = xa_load(&ioremap_guard_array, guard_fn);
-		if (ref) {
-			refcount_inc(&ref->count);
-			goto next;
+				mas_set_range(&mas, phys_addr, phys_addr + sub_size - 1);
+				ret = mas_store_gfp(&mas, xa_mk_value(xa_to_value(entry) + 1),
+						    GFP_KERNEL);
+				if (ret) {
+					pr_err("Failed to inc refcount for 0x%llx:0x%llx\n",
+					       phys_addr, phys_addr + sub_size);
+				}
+
+				goto next;
+			}
+			sub_size = mas.index - phys_addr;
 		}
 
+		ret = ioremap_register_phys_range(phys_addr, sub_size, prot);
+		if (ret)
+			break;
+
 		/*
 		 * It is acceptable for the allocation to fail, specially
 		 * if trying to ioremap something very early on, like with
@@ -136,78 +240,63 @@
 		 * This page will be permanently accessible, similar to a
 		 * saturated refcount.
 		 */
-		if (slab_is_available())
-			ref = kzalloc(sizeof(*ref), GFP_KERNEL);
-		if (ref) {
-			refcount_set(&ref->count, 1);
-			if (xa_err(xa_store(&ioremap_guard_array, guard_fn, ref,
-					    GFP_KERNEL))) {
-				kfree(ref);
-				ref = NULL;
+		if (slab_is_available()) {
+			mas_set_range(&mas, phys_addr, phys_addr + sub_size - 1);
+			ret = mas_store_gfp(&mas, xa_mk_value(1), GFP_KERNEL);
+			if (ret) {
+				pr_err("Failed to log 0x%llx:0x%llx\n",
+				       phys_addr, phys_addr + sub_size);
 			}
 		}
-
-		arm_smccc_1_1_hvc(ARM_SMCCC_VENDOR_HYP_KVM_MMIO_GUARD_MAP_FUNC_ID,
-				  phys_addr, prot, &res);
-		if (res.a0 != SMCCC_RET_SUCCESS) {
-			pr_warn_ratelimited("Failed to register %llx\n",
-					    phys_addr);
-			xa_erase(&ioremap_guard_array, guard_fn);
-			kfree(ref);
-			goto out;
-		}
-
-	next:
-		size -= guard_granule;
-		phys_addr += guard_granule;
+next:
+		size -= sub_size;
+		phys_addr += sub_size;
 	}
-out:
+
 	mutex_unlock(&ioremap_guard_lock);
 }
 
 void iounmap_phys_range_hook(phys_addr_t phys_addr, size_t size)
 {
-	int guard_shift;
+	void *entry;
+	MA_STATE(mas, &ioremap_guard_refcount, phys_addr, phys_addr + size - 1);
 
 	if (!static_branch_unlikely(&ioremap_guard_key))
 		return;
 
 	VM_BUG_ON(phys_addr & ~PAGE_MASK || size & ~PAGE_MASK);
-	guard_shift = __builtin_ctzl(guard_granule);
 
 	mutex_lock(&ioremap_guard_lock);
 
-	while (size) {
-		u64 guard_fn = phys_addr >> guard_shift;
-		struct ioremap_guard_ref *ref;
-		struct arm_smccc_res res;
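+	/*
+	 * Drop a reference from every entry overlapping the range and
+	 * unregister those whose refcount reaches zero. Entries straddling
+	 * the range boundaries are clamped so only the overlap is updated.
+	 */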
+	mas_for_each(&mas, entry, phys_addr + size - 1) {
+		int refcount = xa_to_value(entry);
 
-		ref = xa_load(&ioremap_guard_array, guard_fn);
-		if (!ref) {
-			pr_warn_ratelimited("%llx not tracked, left mapped\n",
-					    phys_addr);
-			goto next;
+		WARN_ON(!refcount);
+
+		if (mas.index < phys_addr || mas.last > phys_addr + size - 1) {
+			unsigned long start = max((unsigned long)phys_addr, mas.index);
+			unsigned long end = min((unsigned long)phys_addr + size - 1, mas.last);
+
+			mas_set_range(&mas, start, end);
 		}
 
-		if (!refcount_dec_and_test(&ref->count))
-			goto next;
-
-		xa_erase(&ioremap_guard_array, guard_fn);
-		kfree(ref);
-
-		arm_smccc_1_1_hvc(ARM_SMCCC_VENDOR_HYP_KVM_MMIO_GUARD_UNMAP_FUNC_ID,
-				  phys_addr, &res);
-		if (res.a0 != SMCCC_RET_SUCCESS) {
-			pr_warn_ratelimited("Failed to unregister %llx\n",
-					    phys_addr);
-			goto out;
+		if (mas_store_gfp(&mas, xa_mk_value(refcount - 1), GFP_KERNEL)) {
+			pr_err("Failed to dec refcount for 0x%lx:0x%lx\n",
+			       mas.index, mas.last);
+			continue;
 		}
 
-	next:
-		size -= guard_granule;
-		phys_addr += guard_granule;
+		if (refcount <= 1) {
+			WARN_ON(ioremap_unregister_phys_range(mas.index, mas.last - mas.index + 1));
+			mas_erase(&mas);
+		}
 	}
-out:
+
 	mutex_unlock(&ioremap_guard_lock);
 }