kvmtest: move KVM setup into shared helpers and run CoW shmem test guests from cow-client
diff --git a/kvmtest/.gitignore b/kvmtest/.gitignore
index ff9945a..ceecb63 100644
--- a/kvmtest/.gitignore
+++ b/kvmtest/.gitignore
@@ -1,2 +1,5 @@
 cow-client
 cow-shmem
+kvmtest-arm64
+*.o
+*~
diff --git a/kvmtest/Makefile b/kvmtest/Makefile
index 7b22f02..243995b 100644
--- a/kvmtest/Makefile
+++ b/kvmtest/Makefile
@@ -1,7 +1,28 @@
-all: cow-shmem cow-client
+TARGETS := cow-shmem cow-client kvmtest-arm64
+
+# Cross-compiler prefix; e.g. "make CROSS_COMPILE=" for a native arm64 build.
+CROSS_COMPILE ?= aarch64-linux-gnu-
+CC := $(CROSS_COMPILE)gcc
+
+# Destination for "make upload"; override on the command line.
+UPLOAD_HOST ?= keirf@localhost
+UPLOAD_PORT ?= 8022
+
+.PHONY: all clean upload
 
-%: %.c
-	aarch64-linux-gnu-gcc -o $@ $< -static -Wall -Werror -D_GNU_SOURCE
+all: $(TARGETS)
+
+FLAGS := -static -Wall -Werror -D_GNU_SOURCE
+
+$(TARGETS): %: %.o helpers.o
+	$(CC) -o $@ $^ $(FLAGS)
+
+# Rebuild objects when the shared helpers.h header changes.
+%.o: %.c helpers.h
+	$(CC) -c -o $@ $< $(FLAGS)
 
 clean:
-	rm cow-shmem cow-client
+	$(RM) $(TARGETS) *.o *~
+
+upload: all
+	scp -P $(UPLOAD_PORT) $(TARGETS) $(UPLOAD_HOST):~/.
diff --git a/kvmtest/cow-client.c b/kvmtest/cow-client.c
index 845a4fd..f2b8c91 100644
--- a/kvmtest/cow-client.c
+++ b/kvmtest/cow-client.c
@@ -1,3 +1,4 @@
+#include <inttypes.h>
 #include <stdio.h>
 #include <err.h>
 #include <string.h>
@@ -21,12 +22,14 @@
 #include <sys/socket.h>
 #include <errno.h>
 
+#include "helpers.h"
+
 /* ++ uapi/pkvm_shmem.h */
 #define KVM_SHMEM_ALLOC_TYPE 1
 
 enum kvm_shmem_alloc_type { 
-	KVM_SHMEM_ALLOCTYPE_VMALLOC,
-	KVM_SHMEM_ALLOCTYPE_PAGES_EXACT };
+    KVM_SHMEM_ALLOCTYPE_VMALLOC,
+    KVM_SHMEM_ALLOCTYPE_PAGES_EXACT };
 
 #define KVM_SHMEM_ALLOC(alloc) _IO(KVM_SHMEM_ALLOC_TYPE, alloc)
 #define KVM_SHMEM_ALLOC_PAGES KVM_SHMEM_ALLOC(KVM_SHMEM_ALLOCTYPE_PAGES_EXACT)
@@ -38,93 +41,262 @@
 #define KVM_SHMEM_AREA_MAKE_READONLY _IO(1, 2)
 /* -- uapi/pkvm_shmem.h */
 
-#define PAGE_SHIFT 12
-#define PAGE_SIZE (1 << PAGE_SHIFT)
-
 struct shmem_area {
     int size;
     int fd;
-    void* mmap;
-    ulong kaddr;
+    void *mmap;  /* userspace mapping, set by shmem_area_mmap() */
     ulong phys;
 };
 
-void shmem_area_getinfo(struct shmem_area *area)
-{
-	if(ioctl(area->fd, KVM_SHMEM_AREA_KADDR, &area->kaddr))
-		err(1, "Can't get kernel area address");
-	printf("Allocated area at %lx of size 0x%x bytes\n",
-	       area->kaddr, area->size);
+/* Early CoW breaks the share immediately by writing at an offset into
+ * the shared page. This should not affect the shared page, and the
+ * copy-break page must still start with the expected magic value.
+ */
+asm(
+    ".global guest_code_early_cow_start, guest_code_early_cow_end\n"
+    "guest_code_early_cow_start:\n"
+    "  mov x10, x0       \n" /* x10 = input block (X0, set by run_vm) */
+    "  ldp x0, x1, [x10] \n" /* x0, x1 = the two addends */
+    "  ldp x2, x3, [x10, #16]\n" /* x2 = mmio_address, x3 = cow_address */
+    "  str x0,[x3, #16]  \n" /* early CoW: write at offset 16 of the page */
+    "  ldr w9, [x3]      \n" /* w9 = first word of the page (the magic) */
+    "  add x0, x9, x0    \n"
+    "  add x0, x1, x0    \n" /* x0 = magic + both addends */
+    "  str x0,[x2]       \n" /* store the result at mmio_address */
+    "  brk #0            \n"
+    "guest_code_early_cow_end:\n"
+    );
 
-	if (ioctl(area->fd, KVM_SHMEM_AREA_PHYS, &area->phys))
-		err(1, "Can't get kernel area physical address");
+extern char guest_code_early_cow_start[], guest_code_early_cow_end[]; /* bytes run_vm copies */
+
+/* Late CoW writes to the share after reading from it. It overwrites the
+ * magic value at the start of the page. It should not affect the magic
+ * value in the original read-only shared page.
+ */
+asm(
+    ".global guest_code_start, guest_code_end\n"
+    "guest_code_start:   \n"
+    "  mov x10, x0       \n" /* x10 = input block (X0, set by run_vm) */
+    "  ldp x0, x1, [x10] \n" /* x0, x1 = the two addends */
+    "  ldp x2, x3, [x10, #16]\n" /* x2 = mmio_address, x3 = cow_address */
+    "  ldr w9, [x3]      \n" /* w9 = first word of the page (the magic) */
+    "  add x0, x9, x0    \n"
+    "  add x0, x1, x0    \n" /* x0 = magic + both addends */
+    "  str x0,[x3]       \n" /* late CoW: overwrite the magic word */
+    "  str x0,[x2]       \n" /* store the result at mmio_address */
+    "  brk #0            \n"
+    "guest_code_end:     \n"
+    );
+
+extern char guest_code_start[], guest_code_end[]; /* bytes run_vm copies */
+
+/* Boot a guest whose page at cow_address is backed by the shared area and
+ * run it until its first debug or MMIO exit (returns 0); other exits are fatal. */
+static int run_vm(struct shmem_area *area, bool is_protected, bool is_early_cow)
+{
+    int kvm, vmfd, vcpufd;
+    const uint64_t code_address = 0x1000;
+    const uint64_t mmio_address = 0x2000;
+    const uint64_t cow_address = 0x3000;
+    uint64_t *mem_code;
+    struct kvm_run *run;
+
+    printf("Making %sprotected VM\n", is_protected ? "" : "un");
+    printf("CoW happens %s\n", is_early_cow ? "early" : "late");
+
+    kvm = get_kvm();
+    vmfd = create_vm(kvm, is_protected);
+
+    /* Allocate one aligned page of guest memory to hold the code. */
+    mem_code = mmap(NULL, PAGE_SIZE,
+        PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0);
+    if (mem_code == MAP_FAILED) /* mmap() never returns NULL on failure */
+        err(1, "allocating guest memory");
+
+    if (is_early_cow)
+        memcpy(mem_code, guest_code_early_cow_start,
+               guest_code_early_cow_end-guest_code_early_cow_start);
+    else
+        memcpy(mem_code, guest_code_start, guest_code_end-guest_code_start);
+
+    mem_code[0x100+0] = 0x110; /* guest inputs at byte 0x800, i.e. X0 below */
+    mem_code[0x100+1] = 0x220;
+    mem_code[0x100+2] = mmio_address;
+    mem_code[0x100+3] = cow_address;
+
+    /* Guest memory: the code page, plus the shared area as the CoW page. */
+    vm_add_mem_page(vmfd, 0, code_address, mem_code);
+    //vm_add_mmio_page(vmfd, 1, mmio_address);
+    vm_add_mem_page(vmfd, 2, cow_address, area->mmap);
+
+    /* Create one CPU to run in the VM. */
+    vcpufd = create_vcpu(kvm, vmfd, &run, false);
+
+    /* Protected VM with no firmware: Note that we only control X0 and PC. */
+    set_one_reg(vcpufd, REG_PC, code_address);
+    set_one_reg(vcpufd, REG_X(0), code_address + 0x800);
+
+    /* Repeatedly run code and handle VM exits. */
+    for (;;) {
+        KVM_IOCTL(vcpufd, KVM_RUN, NULL);
+        switch (run->exit_reason) {
+        case KVM_EXIT_DEBUG:
+            puts("KVM_EXIT_DEBUG");
+            return 0;
+        case KVM_EXIT_MMIO:
+        {
+            uint64_t payload;
+            memcpy(&payload, run->mmio.data, sizeof(payload)); /* no aliasing cast */
+            printf("KVM_EXIT_MMIO: addr = 0x%llx, len = %u, is_write = %u, data = 0x%"PRIx64"\n",
+                run->mmio.phys_addr, run->mmio.len, run->mmio.is_write,
+                payload);
+            return 0; /* XXX: stop at the first MMIO exit (the result store) */
+        }
+        case KVM_EXIT_FAIL_ENTRY:
+            errx(1, "KVM_EXIT_FAIL_ENTRY: hardware_entry_failure_reason = 0x%llx",
+                 (unsigned long long)run->fail_entry.hardware_entry_failure_reason);
+        case KVM_EXIT_INTERNAL_ERROR:
+            errx(1, "KVM_EXIT_INTERNAL_ERROR: suberror = 0x%x",
+                run->internal.suberror);
+        default:
+            errx(1, "exit_reason = 0x%x", run->exit_reason);
+        }
+    }
 }
 
-void shmem_area_mmap(struct shmem_area *area)
+static void shmem_area_getinfo(struct shmem_area *area)
 {
-	printf("Try to mmap kernel address %lx fd %d\n", area->kaddr, area->fd);
-	area->mmap = mmap(NULL, area->size, PROT_READ | PROT_WRITE,
-			  MAP_SHARED, area->fd, 0);
-	if ((void*)area->mmap == MAP_FAILED)
-		err(1, "Can't mmap kernel area");
+    if (ioctl(area->fd, KVM_SHMEM_AREA_PHYS, &area->phys) != 0) /* fills area->phys */
+        err(1, "Can't get kernel area physical address");
+}
+
+static void shmem_area_mmap(struct shmem_area *area)
+{
+    void *va = mmap(NULL, area->size, PROT_READ | PROT_WRITE, MAP_SHARED, area->fd, 0);
+    if (va == MAP_FAILED)
+        err(1, "Can't mmap kernel area");
+    area->mmap = va;
 }
 
 static void recv_fd(int socket, int *fds, int n)
 {
-	struct msghdr msg = {0};
-	struct cmsghdr *cmsg;
-	char buf[CMSG_SPACE(n * sizeof(int))], dup[256];
-	struct iovec io = { .iov_base = &dup, .iov_len = sizeof(dup) };
+    char buf[CMSG_SPACE(n * sizeof(int))], data[256];
+    struct iovec io = { .iov_base = data, .iov_len = sizeof(data) };
+    struct msghdr msg = { .msg_iov = &io, .msg_iovlen = 1,
+                          .msg_control = buf, .msg_controllen = sizeof(buf) };
+    struct cmsghdr *cmsg;
 
-	memset(buf, 0, sizeof(buf));
+    memset(buf, 0, sizeof(buf));
 
-	msg.msg_iov = &io;
-	msg.msg_iovlen = 1;
-	msg.msg_control = buf;
-	msg.msg_controllen = sizeof(buf);
+    if (recvmsg(socket, &msg, 0) < 0)
+        err(1, "Failed to receive message");
 
-	if (recvmsg (socket, &msg, 0) < 0)
-		err(1, "Failed to receive message");
+    cmsg = CMSG_FIRSTHDR(&msg); /* sender attaches the n fds as SCM_RIGHTS */
 
-	cmsg = CMSG_FIRSTHDR(&msg);
+    if (!cmsg || cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS ||
+        cmsg->cmsg_len != CMSG_LEN(n * sizeof(int)))
+        errx(1, "Expected %d file descriptor(s) via SCM_RIGHTS", n);
 
-	memcpy(fds, (int *)CMSG_DATA(cmsg), n * sizeof(int));
+    memcpy(fds, CMSG_DATA(cmsg), n * sizeof(int));
 }
 
-int mk_uds(const char *path)
+static int mk_uds(const char *path) /* connect to the Unix socket at @path */
 {
-	int sfd;
-	struct sockaddr_un addr;
+    struct sockaddr_un addr = { .sun_family = AF_UNIX };
+    int sfd;
 
-	sfd = socket(AF_UNIX, SOCK_STREAM, 0);
-	if (sfd == -1)
-		err(1, "Failed to create socket");
+    sfd = socket(AF_UNIX, SOCK_STREAM, 0);
+    if (sfd == -1)
+        err(1, "Failed to create socket");
 
-	memset(&addr, 0, sizeof(struct sockaddr_un));
-	addr.sun_family = AF_UNIX;
-	strncpy(addr.sun_path, path, sizeof(addr.sun_path) - 1);
+    if (strlen(path) >= sizeof(addr.sun_path)) /* refuse rather than truncate */
+        errx(1, "Socket path too long: %s", path);
+    memcpy(addr.sun_path, path, strlen(path) + 1);
 
-	if (connect(sfd, (struct sockaddr *)&addr, sizeof(struct sockaddr_un)) == -1)
-                err(1, "Failed to connect to socket");
+    if (connect(sfd, (struct sockaddr *)&addr, sizeof(addr)) == -1)
+        err(1, "Failed to connect to socket");
 
-	return sfd;
+    return sfd;
+}
+
+/* Number of SIGSEGVs taken by the write-protection probe in main(). */
+static volatile sig_atomic_t nr_sigsegv;
+/* Record the fault and step over the faulting instruction so main() resumes. */
+static void sigsegv_handler(int signr, siginfo_t *info, void *raw_context)
+{
+    ucontext_t *uc = raw_context;
+    mcontext_t *mc = &uc->uc_mcontext;
+    nr_sigsegv++;
+    printf("SIGSEGV #%d at PC %llx\n", nr_sigsegv, mc->pc);
+    mc->pc += 4; /* A64 instructions are 4 bytes long */
+}
+
+/* Install sigsegv_handler for SIGSEGV (SA_SIGINFO, so it receives the context). */
+static void set_sigv_handler(void)
+{
+    struct sigaction sa = { 0 };
+
+    sigemptyset(&sa.sa_mask);
+    sa.sa_sigaction = sigsegv_handler;
+    sa.sa_flags = SA_SIGINFO;
+    if (sigaction(SIGSEGV, &sa, NULL) != 0)
+        err(1, NULL);
+}
+
+/* Restore the default SIGSEGV disposition. */
+static void clear_sigv_handler(void)
+{
+    if (signal(SIGSEGV, SIG_DFL) == SIG_ERR)
+        err(1, NULL);
 }
 
 int main(int argc, char** argv)
 {
-	int sock_fd;
-	struct shmem_area area = { .size = PAGE_SIZE };
+    int sock_fd;
+    struct shmem_area area = { .size = PAGE_SIZE };
+    bool is_protected_vm, is_early_cow;
 
-	sock_fd = mk_uds("shmem.sock");
-	recv_fd(sock_fd, &area.fd, 1);
-	printf("Received FD %d\n", area.fd);
+    /* argv[1]: 'u'/'p' for an un/protected VM, then 'e'/'l' for early/late CoW. */
+    if (argc != 2 || strlen(argv[1]) != 2 ||
+        !strchr("up", argv[1][0]) || !strchr("el", argv[1][1])) {
+        printf("Usage: %s {up}{el}\n", argv[0]);
+        return 1;
+    }
 
-	shmem_area_getinfo(&area);
-	shmem_area_mmap(&area);
-	printf("Area: ka=%lx pa=%lx va=%p size=%d\n",
-	       area.kaddr, area.phys, area.mmap, area.size);
-	for (;;);
+    is_protected_vm = (argv[1][0] == 'p');
+    is_early_cow = (argv[1][1] == 'e');
 
-	return 0;
+    sock_fd = mk_uds("shmem.sock");
+    recv_fd(sock_fd, &area.fd, 1);
+    printf("Received FD %d\n", area.fd);
+
+    shmem_area_getinfo(&area);
+    shmem_area_mmap(&area);
+    printf("Area: pa=%lx va=%p size=%d\n",
+           area.phys, area.mmap, area.size);
+    printf("Value @ %lx: %08"PRIx32"\n",
+           area.phys, *(volatile uint32_t *)area.mmap);
+
+    /* Attempt to write to the read-only mapping. We expect SIGSEGV. */
+    printf("Testing write protection @ %lx\n", area.phys);
+    set_sigv_handler();
+    *(volatile uint32_t *)area.mmap = 0xdeadbeef;
+    clear_sigv_handler();
+
+    if (nr_sigsegv != 1) {
+        printf("Expected precisely one SIGSEGV signal\n");
+        return 1;
+    }
+
+    return run_vm(&area, is_protected_vm, is_early_cow);
 }
 
+/*
+ * Local variables:
+ * mode: C
+ * c-file-style: "Linux"
+ * c-basic-offset: 4
+ * tab-width: 4
+ * indent-tabs-mode: nil
+ * End:
+ */
diff --git a/kvmtest/cow-shmem.c b/kvmtest/cow-shmem.c
index 4283c72..379b7ca 100644
--- a/kvmtest/cow-shmem.c
+++ b/kvmtest/cow-shmem.c
@@ -4,6 +4,7 @@
 #include <stddef.h>
 #include <stdint.h>
 #include <stdlib.h>
+#include <inttypes.h>
 #include <sys/types.h>
 #include <sys/stat.h>
 #include <sys/ioctl.h>
@@ -25,8 +26,8 @@
 #define KVM_SHMEM_ALLOC_TYPE 1
 
 enum kvm_shmem_alloc_type { 
-	KVM_SHMEM_ALLOCTYPE_VMALLOC,
-	KVM_SHMEM_ALLOCTYPE_PAGES_EXACT };
+    KVM_SHMEM_ALLOCTYPE_VMALLOC,
+    KVM_SHMEM_ALLOCTYPE_PAGES_EXACT };
 
 #define KVM_SHMEM_ALLOC(alloc) _IO(KVM_SHMEM_ALLOC_TYPE, alloc)
 #define KVM_SHMEM_ALLOC_PAGES KVM_SHMEM_ALLOC(KVM_SHMEM_ALLOCTYPE_PAGES_EXACT)
@@ -45,116 +46,123 @@
     int size;
     int fd;
     void* mmap;
-    ulong kaddr;
     ulong phys;
 };
 
 void shmem_area_alloc(int fd, struct shmem_area *area)
 {
-	area->fd = ioctl(fd, KVM_SHMEM_ALLOC_PAGES, area->size);
-	if (area->fd < 0)
-		err(1, "Can't allocate shmem_area");
+    area->fd = ioctl(fd, KVM_SHMEM_ALLOC_PAGES, area->size); /* returns the new area's fd */
+    if (area->fd < 0)
+        err(1, "Can't allocate shmem_area");
 
-	if (ioctl(area->fd, KVM_SHMEM_AREA_KADDR, &area->kaddr))
-		err(1, "Can't get kernel area address");
-	printf("Allocated area at %lx of size 0x%x bytes\n",
-	       area->kaddr, area->size);
-
-	if (ioctl(area->fd, KVM_SHMEM_AREA_PHYS, &area->phys))
-		err(1, "Can't get kernel area physical address");
+    if (ioctl(area->fd, KVM_SHMEM_AREA_PHYS, &area->phys)) /* fills area->phys */
+        err(1, "Can't get kernel area physical address");
 }
 
 void shmem_area_mmap(struct shmem_area *area)
 {
-	printf("Try to mmap kernel address %lx fd %d\n", area->kaddr, area->fd);
-	area->mmap = mmap(NULL, area->size, PROT_READ | PROT_WRITE,
-			  MAP_SHARED, area->fd, 0);
-	if ((void*)area->mmap == MAP_FAILED)
-		err(1, "Can't mmap kernel area");
+    area->mmap = mmap(NULL, area->size, PROT_READ | PROT_WRITE,
+                      MAP_SHARED, area->fd, 0);
+    if ((void*)area->mmap == MAP_FAILED) /* mmap() reports failure as MAP_FAILED */
+        err(1, "Can't mmap kernel area");
 }
 
 void shmem_area_make_readonly(struct shmem_area *area)
 {
-	if (ioctl(area->fd, KVM_SHMEM_AREA_MAKE_READONLY))
-		err(1, "Can't make area r/o");
+    if (ioctl(area->fd, KVM_SHMEM_AREA_MAKE_READONLY)) /* cow-client expects later writes to fault */
+        err(1, "Can't make area r/o");
 }
 
 static
 void send_fd(int socket, int *fds, int n)  // send fd by socket
 {
-	struct msghdr msg = {0};
-	struct cmsghdr *cmsg;
-	char buf[CMSG_SPACE(n * sizeof(int))], dup[256];
-	struct iovec io = { .iov_base = &dup, .iov_len = sizeof(dup) };
+    struct msghdr msg = {0};
+    struct cmsghdr *cmsg;
+    char buf[CMSG_SPACE(n * sizeof(int))], payload[256] = { 0 }; /* zeroed: no stack leak */
+    struct iovec io = { .iov_base = payload, .iov_len = sizeof(payload) };
 
-	memset(buf, 0, sizeof(buf));
+    memset(buf, 0, sizeof(buf));
 
-	msg.msg_iov = &io;
-	msg.msg_iovlen = 1;
-	msg.msg_control = buf;
-	msg.msg_controllen = sizeof(buf);
+    msg.msg_iov = &io;
+    msg.msg_iovlen = 1;
+    msg.msg_control = buf;
+    msg.msg_controllen = sizeof(buf);
 
-	cmsg = CMSG_FIRSTHDR(&msg);
-	cmsg->cmsg_level = SOL_SOCKET;
-	cmsg->cmsg_type = SCM_RIGHTS;
-	cmsg->cmsg_len = CMSG_LEN(n * sizeof(int));
+    cmsg = CMSG_FIRSTHDR(&msg); /* non-NULL: buf is CMSG_SPACE() bytes long */
+    cmsg->cmsg_level = SOL_SOCKET;
+    cmsg->cmsg_type = SCM_RIGHTS;
+    cmsg->cmsg_len = CMSG_LEN(n * sizeof(int));
 
-	memcpy ((int *) CMSG_DATA(cmsg), fds, n * sizeof (int));
+    memcpy(CMSG_DATA(cmsg), fds, n * sizeof(int));
 
-	if (sendmsg(socket, &msg, 0) < 0)
-		err(1, "Failed to send message");
+    if (sendmsg(socket, &msg, 0) < 0)
+        err(1, "Failed to send message");
 }
 
 
 int mk_uds(const char *path)
 {
-	int sfd;
-	struct sockaddr_un addr;
+    struct sockaddr_un addr = { .sun_family = AF_UNIX };
+    int sfd;
 
-	sfd = socket(AF_UNIX, SOCK_STREAM, 0);
-	if (sfd == -1)
-		err(1, "Failed to create socket");
+    sfd = socket(AF_UNIX, SOCK_STREAM, 0);
+    if (sfd == -1)
+        err(1, "Failed to create socket");
 
-	if (unlink(path) == -1 && errno != ENOENT)
-		err(1, "Removing socket file failed");
+    if (unlink(path) == -1 && errno != ENOENT) /* bind() fails if the path exists */
+        err(1, "Removing socket file failed");
 
-	memset(&addr, 0, sizeof(struct sockaddr_un));
-	addr.sun_family = AF_UNIX;
-	strncpy(addr.sun_path, path, sizeof(addr.sun_path) - 1);
+    if (strlen(path) >= sizeof(addr.sun_path)) /* refuse rather than truncate */
+        errx(1, "Socket path too long: %s", path);
+    memcpy(addr.sun_path, path, strlen(path) + 1);
 
-	if (bind(sfd, (struct sockaddr *)&addr,
-		 sizeof(struct sockaddr_un)) == -1)
-		err(1, "Failed to bind to socket");
+    if (bind(sfd, (struct sockaddr *)&addr,
+             sizeof(struct sockaddr_un)) == -1)
+        err(1, "Failed to bind to socket");
 
-	if (listen(sfd, 5) == -1)
-		err(1, NULL);
+    if (listen(sfd, 5) == -1)
+        err(1, NULL);
 
-	return sfd;
+    return sfd;
 }
 
 int main(int argc, char** argv)
 {
-	int fd, sock_fd;
-	struct shmem_area area = { .size = PAGE_SIZE };
+    int fd, sock_fd;
+    struct shmem_area area = { .size = PAGE_SIZE };
 
-	fd = open("/sys/kernel/debug/pkvm_shmem", O_RDWR);
-	if (fd == -1) err(1, "pkvm_shmem");
+    fd = open("/sys/kernel/debug/pkvm_shmem", O_RDWR); /* pkvm_shmem control file */
+    if (fd == -1) err(1, "pkvm_shmem");
 
-	shmem_area_alloc(fd, &area);
-	shmem_area_mmap(&area);
-	printf("Area: ka=%lx pa=%lx va=%p size=%d\n",
-	       area.kaddr, area.phys, area.mmap, area.size);
-	shmem_area_make_readonly(&area);
+    shmem_area_alloc(fd, &area);
+    shmem_area_mmap(&area);
+    printf("Area: pa=%lx va=%p size=%d\n",
+           area.phys, area.mmap, area.size);
+    *(volatile uint32_t *)area.mmap = 0xdeadf00d; /* magic word the cow-client guests read */
+    shmem_area_make_readonly(&area);
+    printf("Written %08"PRIx32" to %lx\n",
+           *(volatile uint32_t *)area.mmap, area.phys); /* read back after going r/o */
+    printf("Area %lx-%lx is now read-only\n",
+           area.phys, area.phys+area.size-1);
 
-	sock_fd = mk_uds("shmem.sock");
-	for (;;) {
-		int cfd = accept(sock_fd, NULL, NULL);
-		if (cfd == -1)
-			err(1, "Failed to accept incoming connection");
-		send_fd(cfd, &area.fd, 1);
-		close(cfd);
-	}
+    sock_fd = mk_uds("shmem.sock");
+    for (;;) { /* hand the area's fd to every client that connects */
+        int cfd = accept(sock_fd, NULL, NULL);
+        if (cfd == -1)
+            err(1, "Failed to accept incoming connection");
+        send_fd(cfd, &area.fd, 1);
+        close(cfd);
+    }
 
-	return 0;
+    return 0;
 }
 
+/*
+ * Local variables:
+ * mode: C
+ * c-file-style: "Linux"
+ * c-basic-offset: 4
+ * tab-width: 4
+ * indent-tabs-mode: nil
+ * End:
+ */
diff --git a/kvmtest/helpers.c b/kvmtest/helpers.c
new file mode 100644
index 0000000..6c30ba7
--- /dev/null
+++ b/kvmtest/helpers.c
@@ -0,0 +1,136 @@
+/* KVM setup helpers shared by the arm64 kvmtest programs. */
+
+#include <err.h>
+#include <errno.h>
+#include <fcntl.h>
+#include <linux/kvm.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/ioctl.h>
+#include <sys/mman.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+
+#include "helpers.h"
+
+/*
+ * Open /dev/kvm and check that it implements the stable KVM API (version 12).
+ * Returns the /dev/kvm fd; exits with a diagnostic on any failure.
+ */
+int get_kvm(void)
+{
+    int kvm, ret;
+
+    kvm = open("/dev/kvm", O_RDWR | O_CLOEXEC);
+    if (kvm < 0)
+        err(1, "/dev/kvm");
+
+    ret = KVM_IOCTL(kvm, KVM_GET_API_VERSION, NULL);
+    if (ret != 12)
+        errx(1, "KVM_GET_API_VERSION %d, expected 12", ret);
+
+    return kvm;
+}
+
+/*
+ * Create a VM through the /dev/kvm fd @kvm and return the VM fd.
+ * @is_protected sets bit 31 of the KVM_CREATE_VM machine type; presumably the
+ * pKVM protected-VM flag -- confirm against the kernel's uapi headers.
+ */
+int create_vm(int kvm, bool is_protected)
+{
+    unsigned long flags = 0;
+
+    if (is_protected)
+        flags |= 1ul << 31;
+    return KVM_IOCTL(kvm, KVM_CREATE_VM, flags);
+}
+
+/*
+ * Create vCPU 0 in @vmfd, map its shared kvm_run area into *@run and
+ * initialise it with KVM's preferred target. @enable_debug additionally
+ * enables guest debugging (KVM_GUESTDBG_ENABLE). Returns the vCPU fd;
+ * exits with a diagnostic on any failure.
+ */
+int create_vcpu(int kvm, int vmfd, struct kvm_run **run, bool enable_debug)
+{
+    struct kvm_guest_debug debug = {
+        .control = KVM_GUESTDBG_ENABLE,
+    };
+    struct kvm_vcpu_init vcpu_init;
+    size_t mmap_size;
+    int vcpufd;
+
+    vcpufd = KVM_IOCTL(vmfd, KVM_CREATE_VCPU, (unsigned long)0);
+
+    /* The mapping must hold a whole struct kvm_run. run points to the
+     * caller's pointer, so the size to check is sizeof(**run). */
+    mmap_size = KVM_IOCTL(kvm, KVM_GET_VCPU_MMAP_SIZE, NULL);
+    if (mmap_size < sizeof(**run))
+        errx(1, "KVM_GET_VCPU_MMAP_SIZE unexpectedly small");
+    *run = mmap(NULL, mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED, vcpufd, 0);
+    if (*run == MAP_FAILED)   /* mmap() never returns NULL on failure */
+        err(1, "mmap vcpu");
+
+    /* Initialise the vCPU with the CPU target type KVM prefers. */
+    KVM_IOCTL(vmfd, KVM_ARM_PREFERRED_TARGET, &vcpu_init);
+    KVM_IOCTL(vcpufd, KVM_ARM_VCPU_INIT, &vcpu_init);
+
+    if (enable_debug)
+        KVM_IOCTL(vcpufd, KVM_SET_GUEST_DEBUG, &debug);
+
+    return vcpufd;
+}
+
+/* Back the guest-physical page at @addr with the host page at @uaddr,
+ * using memslot @slot. */
+void vm_add_mem_page(int vmfd, int slot, uint64_t addr, void *uaddr)
+{
+    struct kvm_userspace_memory_region region = {
+        .slot = slot,
+        .guest_phys_addr = addr,
+        .userspace_addr = (uint64_t)(uintptr_t)uaddr,
+        .memory_size = PAGE_SIZE,
+    };
+
+    KVM_IOCTL(vmfd, KVM_SET_USER_MEMORY_REGION, &region);
+}
+
+/* Register a read-only page with no host backing at @addr in memslot @slot,
+ * so that guest writes to it exit to userspace as KVM_EXIT_MMIO. */
+void vm_add_mmio_page(int vmfd, int slot, uint64_t addr)
+{
+    struct kvm_userspace_memory_region region = {
+        .flags = KVM_MEM_READONLY,
+        .slot = slot,
+        .guest_phys_addr = addr,
+        .userspace_addr = 0ULL,
+        .memory_size = PAGE_SIZE,
+    };
+
+    KVM_IOCTL(vmfd, KVM_SET_USER_MEMORY_REGION, &region);
+}
+
+/* Set vCPU register @reg_id (a KVM_SET_ONE_REG id such as REG_PC) to @val. */
+void set_one_reg(int vcpufd, uint64_t reg_id, uint64_t val)
+{
+    uint64_t reg_data = val;
+    struct kvm_one_reg reg = {
+        .id = reg_id,
+        .addr = (uint64_t)(uintptr_t)&reg_data,
+    };
+
+    KVM_IOCTL(vcpufd, KVM_SET_ONE_REG, &reg);
+}
+
+/*
+ * Local variables:
+ * mode: C
+ * c-file-style: "Linux"
+ * c-basic-offset: 4
+ * tab-width: 4
+ * indent-tabs-mode: nil
+ * End:
+ */
diff --git a/kvmtest/helpers.h b/kvmtest/helpers.h
new file mode 100644
index 0000000..ba93f96
--- /dev/null
+++ b/kvmtest/helpers.h
@@ -0,0 +1,72 @@
+#ifndef HELPERS_H
+#define HELPERS_H
+
+/* Constants, macros and helper declarations shared by the kvmtest programs. */
+
+#include <err.h>
+#include <stdbool.h>
+#include <stdint.h>
+#include <sys/ioctl.h>
+#include <linux/kvm.h>
+
+#define PTRS_PER_PTE    512UL
+#define PAGE_SHIFT      12
+#define PAGE_SIZE       (1 << PAGE_SHIFT)
+#define PMD_SIZE        (PTRS_PER_PTE * PAGE_SIZE)
+#define PUD_SIZE        (PTRS_PER_PTE * PMD_SIZE)
+
+/* Register ids for KVM_SET_ONE_REG: general-purpose register Xn, and PC. */
+#define REG_X(number)    (0x6030000000100000ULL + (number) * 2UL)
+#define REG_PC        0x6030000000100040ULL
+
+/*
+ * Perform an ioctl and exit with a diagnostic if it fails; otherwise evaluate
+ * to its (non-negative) return value. The local uses a distinctive name so
+ * that arguments mentioning a caller variable named "ret" are not captured.
+ */
+#define KVM_IOCTL(fd, ioctl_id, ...)                           \
+({                                                             \
+    int kvm_ioctl_ret_ = ioctl(fd, ioctl_id, __VA_ARGS__);     \
+    if (kvm_ioctl_ret_ < 0)                                    \
+        err(1, #ioctl_id);                                     \
+    kvm_ioctl_ret_;                                            \
+})
+
+int get_kvm(void);
+int create_vm(int kvm, bool is_protected);
+int create_vcpu(int kvm, int vmfd, struct kvm_run **run, bool debug);
+void vm_add_mem_page(int vmfd, int slot, uint64_t addr, void *uaddr);
+void vm_add_mmio_page(int vmfd, int slot, uint64_t addr);
+void set_one_reg(int vcpufd, uint64_t reg_id, uint64_t val);
+
+#define ALIGN(x, y)  (((x)+(y)-1) & ~((y)-1))
+
+/* Page bitmaps with one bit per PAGE_SIZE page: bytes needed to cover @mem
+ * bytes (rounded up to 64 pages), and the byte index / bit mask for @offset. */
+#define BITMAP_SIZE(mem) (ALIGN((mem) / PAGE_SIZE, 64) / 8)
+#define OFFSET_IDX(offset) ((offset) / PAGE_SIZE / 8)
+#define OFFSET_BIT(offset) (1 << (((offset) / PAGE_SIZE) % 8))
+
+/* Return true if the first @len bytes of the two bitmaps are equal. */
+static inline bool bitmap_equal(const uint8_t *bm_a, const uint8_t *bm_b,
+                                uint64_t len)
+{
+    for (uint64_t i = 0; i < len; i++) {
+        if (bm_a[i] != bm_b[i])
+            return false;
+    }
+
+    return true;
+}
+
+#endif /* HELPERS_H */
+
+/*
+ * Local variables:
+ * mode: C
+ * c-file-style: "Linux"
+ * c-basic-offset: 4
+ * tab-width: 4
+ * indent-tabs-mode: nil
+ * End:
+ */
diff --git a/kvmtest/kvmtest-arm64.c b/kvmtest/kvmtest-arm64.c
index 4363ff5..fec4385 100644
--- a/kvmtest/kvmtest-arm64.c
+++ b/kvmtest/kvmtest-arm64.c
@@ -35,6 +35,7 @@
 #include <err.h>
 #include <fcntl.h>
 #include <linux/kvm.h>
+#include <inttypes.h>
 #include <stdint.h>
 #include <stdio.h>
 #include <stdlib.h>
@@ -44,134 +45,62 @@
 #include <sys/stat.h>
 #include <sys/types.h>
 
-int main(void)
-{
-    int kvm, vmfd, vcpufd, ret;
+#include "helpers.h"
 
-    /* Add x0 to x1 and outputs the result to MMIO at address in x2. */
-    const uint8_t code[] = {
-        0x20, 0x00, 0x00, 0x8b, /* add x0, x1, x0 */
-        0x40, 0x00, 0x00, 0xf9, /* str x0, [x2]*/
-        0x00, 0x00, 0x20, 0xd4, /* brk */
-    };
+asm(
+    ".global guest_code_start, guest_code_end\n"
+    "guest_code_start:   \n"
+    "  mov x10, x0       \n" /* x10 = input block (X0, set by run_vm) */
+    "  ldp x0, x1, [x10] \n" /* x0, x1 = the two addends */
+    "  ldp x2, x3, [x10, #16]\n" /* x2 = mmio_address; x3 is not used here */
+    "  add x0, x1, x0    \n"
+    "  str x0,[x2]       \n" /* write the sum to the MMIO page */
+    "  brk #0            \n"
+    "guest_code_end:     \n"
+    );
+
+extern char guest_code_start[], guest_code_end[]; /* bytes run_vm copies */
+
+int run_vm(bool is_protected)
+{
+    int kvm, vmfd, vcpufd;
     const uint64_t code_address = 0x1000;
     const uint64_t mmio_address = 0x2000;
-    uint8_t *mem_code = NULL;
-    size_t mmap_size;
-    struct kvm_run *run = NULL;
+    uint64_t *mem_code;
+    struct kvm_run *run;
 
-    kvm = open("/dev/kvm", O_RDWR | O_CLOEXEC);
-    if (kvm < 0)
-        err(1, "/dev/kvm");
+    kvm = get_kvm();
 
-    /* Ensure this is the stable version of the KVM API (defined as 12) */
-    ret = ioctl(kvm, KVM_GET_API_VERSION, NULL);
-    if (ret < 0)
-        err(1, "KVM_GET_API_VERSION");
-    if (ret != 12)
-        errx(1, "KVM_GET_API_VERSION %d, expected 12", ret);
-
-    vmfd = ioctl(kvm, KVM_CREATE_VM, (unsigned long)0);
-    if (vmfd < 0)
-        err(1, "KVM_CREATE_VM");
+    vmfd = create_vm(kvm, is_protected);
 
     /* Allocate one aligned page of guest memory to hold the code. */
-    mem_code = mmap(NULL, 0x1000,
+    mem_code = mmap(NULL, PAGE_SIZE,
         PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1, 0);
-    if (!mem_code)
+    if (mem_code == MAP_FAILED) /* mmap() never returns NULL on failure */
         err(1, "allocating guest memory");
-    memcpy(mem_code, code, sizeof(code));
+
+    memcpy(mem_code, guest_code_start, guest_code_end-guest_code_start);
+    mem_code[0x100+0] = 0x101; /* guest inputs at byte 0x800 (X0) */
+    mem_code[0x100+1] = 0x202;
+    mem_code[0x100+2] = mmio_address;
+    mem_code[0x100+3] = 0xdeadbeef;
 
     /* Map code memory to the second page frame. */
-    struct kvm_userspace_memory_region region = {
-        .slot = 0,
-        .guest_phys_addr = code_address,
-        .memory_size = 0x1000,
-        .userspace_addr = (uint64_t)mem_code,
-    };
-    ret = ioctl(vmfd, KVM_SET_USER_MEMORY_REGION, &region);
-    if (ret < 0)
-        err(1, "KVM_SET_USER_MEMORY_REGION");
+    vm_add_mem_page(vmfd, 0, code_address, mem_code);
 
     /* Use third page frame of guest memory to simulate MMIO. */
-    region.flags = KVM_MEM_READONLY; /* triggers KVM_EXIT_MEMIO on write */
-    region.slot = 1;
-    region.guest_phys_addr = mmio_address;
-    region.userspace_addr = 0ULL;
-    ret = ioctl(vmfd, KVM_SET_USER_MEMORY_REGION, &region);
-    if (ret < 0)
-        err(1, "KVM_SET_USER_MEMORY_REGION");
+    vm_add_mmio_page(vmfd, 1, mmio_address);
 
     /* Create one CPU to run in the VM. */
-    vcpufd = ioctl(vmfd, KVM_CREATE_VCPU, (unsigned long)0);
-    if (vcpufd < 0)
-        err(1, "KVM_CREATE_VCPU");
+    vcpufd = create_vcpu(kvm, vmfd, &run, false);
 
-    /* Map the shared kvm_run structure and following data. */
-    ret = ioctl(kvm, KVM_GET_VCPU_MMAP_SIZE, NULL);
-    if (ret < 0)
-        err(1, "KVM_GET_VCPU_MMAP_SIZE");
-    mmap_size = ret;
-    if (mmap_size < sizeof(*run))
-        errx(1, "KVM_GET_VCPU_MMAP_SIZE unexpectedly small");
-    run = mmap(NULL, mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED, vcpufd, 0);
-    if (!run)
-        err(1, "mmap vcpu");
-
-    /* Query KVM for preferred CPU target type that can be emulated. */
-    struct kvm_vcpu_init vcpu_init;
-    ret = ioctl(vmfd, KVM_ARM_PREFERRED_TARGET, &vcpu_init);
-    if (ret < 0)
-        err(1, "KVM_PREFERRED_TARGET");
-
-    /* Initialize VCPU with the preferred type obtained above. */
-    ret = ioctl(vcpufd, KVM_ARM_VCPU_INIT, &vcpu_init);
-    if (ret < 0)
-        err(1, "KVM_ARM_VCPU_INIT");
-
-    /* Prepare the kvm_one_reg structure to use for populating registers. */
-    uint64_t reg_data;
-    struct kvm_one_reg reg;
-    reg.addr = (__u64) &reg_data;
-
-    // Initialize input registers (x0 and x1) to 2.
-    reg_data = 2;
-    reg.id = 0x6030000000100000; // x0 id
-    ret = ioctl(vcpufd, KVM_SET_ONE_REG, &reg);
-    if (ret != 0)
-        err(1, "KVM_SET_ONE_REG");
-    reg.id = 0x6030000000100002; // x1 id
-    ret = ioctl(vcpufd, KVM_SET_ONE_REG, &reg);
-    if (ret != 0)
-        err(1, "KVM_SET_ONE_REG");
-
-    // Initialize x3 to point to the simulated MMIO region.
-    reg.id = 0x6030000000100004; // x3 id
-    reg_data = mmio_address;
-    ret = ioctl(vcpufd, KVM_SET_ONE_REG, &reg);
-    if (ret != 0)
-        err(1, "KVM_SET_ONE_REG");
-
-    // Initialize the PC to point to the start of the code.
-    reg.id = 0x6030000000100040; // pc id
-    reg_data = code_address;
-    ret = ioctl(vcpufd, KVM_SET_ONE_REG, &reg);
-    if (ret != 0)
-        err(1, "KVM_SET_ONE_REG");
-
-    // Enable debug so that brk instruction would exit KVM_RUN (KVM_EXIT_DEBUG).
-    struct kvm_guest_debug debug = {
-        .control = KVM_GUESTDBG_ENABLE,
-    };
-    ret = ioctl(vcpufd, KVM_SET_GUEST_DEBUG, &debug);
-    if (ret < 0)
-        err(1, "KVM_SET_GUEST_DEBUG");
+    /* Protected VM with no firmware: Note that we only control X0 and PC. */
+    set_one_reg(vcpufd, REG_PC, code_address);
+    set_one_reg(vcpufd, REG_X(0), code_address + 0x800);
 
     /* Repeatedly run code and handle VM exits. */
     for (;;) {
-        ret = ioctl(vcpufd, KVM_RUN, NULL);
-        if (ret < 0)
-            err(1, "KVM_RUN");
+        KVM_IOCTL(vcpufd, KVM_RUN, NULL);
         switch (run->exit_reason) {
         case KVM_EXIT_DEBUG:
             puts("KVM_EXIT_DEBUG");
@@ -179,9 +108,10 @@
         case KVM_EXIT_MMIO:
         {
             uint64_t payload = *(uint64_t*)(run->mmio.data); /* sorry */
-            printf("KVM_EXIT_MMIO: addr = 0x%llx, len = %u, is_write = %u, data = 0x%08llx\n",
+            printf("KVM_EXIT_MMIO: addr = 0x%llx, len = %u, is_write = %u, data = 0x%"PRIx64"\n",
                 run->mmio.phys_addr, run->mmio.len, run->mmio.is_write,
                 payload);
+            return 0;/* XXX */
             break;
         }
         case KVM_EXIT_FAIL_ENTRY:
@@ -195,3 +125,18 @@
         }
     }
 }
+
+int main(int argc, char **argv)
+{
+    return run_vm(argc > 1); /* any command-line argument selects a protected VM */
+}
+
+/*
+ * Local variables:
+ * mode: C
+ * c-file-style: "Linux"
+ * c-basic-offset: 4
+ * tab-width: 4
+ * indent-tabs-mode: nil
+ * End:
+ */