arm64: context-switch user tls register tpidr_el0 for compat tasks

Since commit a4780adeefd0 ("ARM: 7735/2: Preserve the user r/w register
TPIDRURW on context switch and fork"), arch/arm/ has context-switched
the user-writable TLS register TPIDRURW. Do the same for compat tasks
running under the arm64 kernel by saving and restoring tpidr_el0, which
is the AArch64 view of TPIDRURW.

Reported-by: André Hentschel <nerv@dawncrow.de>
Tested-by: André Hentschel <nerv@dawncrow.de>
Signed-off-by: Will Deacon <will.deacon@arm.com>
Signed-off-by: Catalin Marinas <catalin.marinas@arm.com>
diff --git a/arch/arm64/include/asm/processor.h b/arch/arm64/include/asm/processor.h
index d2c37a1..e4c893e 100644
--- a/arch/arm64/include/asm/processor.h
+++ b/arch/arm64/include/asm/processor.h
@@ -78,13 +78,30 @@
 
 struct thread_struct {
 	struct cpu_context	cpu_context;	/* cpu context */
-	unsigned long		tp_value;
+	unsigned long		tp_value;	/* TLS (compat: tpidrro_el0) */
+#ifdef CONFIG_COMPAT
+	unsigned long		tp2_value;	/* compat tpidr_el0 */
+#endif
 	struct fpsimd_state	fpsimd_state;
 	unsigned long		fault_address;	/* fault info */
 	unsigned long		fault_code;	/* ESR_EL1 value */
 	struct debug_info	debug;		/* debugging */
 };
 
+#ifdef CONFIG_COMPAT
+#define task_user_tls(t)	/* pointer to t's saved tpidr_el0 */	\
+({									\
+	unsigned long *__tls;						\
+	if (is_compat_thread(task_thread_info(t)))			\
+		__tls = &(t)->thread.tp2_value;				\
+	else								\
+		__tls = &(t)->thread.tp_value;				\
+	__tls;								\
+ })
+#else	/* !CONFIG_COMPAT */
+#define task_user_tls(t)	(&(t)->thread.tp_value)
+#endif	/* CONFIG_COMPAT */
+
 #define INIT_THREAD  {	}
 
 static inline void start_thread_common(struct pt_regs *regs, unsigned long pc)
diff --git a/arch/arm64/kernel/process.c b/arch/arm64/kernel/process.c
index c506bee..369f485 100644
--- a/arch/arm64/kernel/process.c
+++ b/arch/arm64/kernel/process.c
@@ -244,35 +244,35 @@
 		unsigned long stk_sz, struct task_struct *p)
 {
 	struct pt_regs *childregs = task_pt_regs(p);
-	unsigned long tls = p->thread.tp_value;
 
 	memset(&p->thread.cpu_context, 0, sizeof(struct cpu_context));
 
 	if (likely(!(p->flags & PF_KTHREAD))) {
 		*childregs = *current_pt_regs();
 		childregs->regs[0] = 0;
-		if (is_compat_thread(task_thread_info(p))) {
-			if (stack_start)
+
+		/*
+		 * Read the current TLS pointer from tpidr_el0 as it may be
+		 * out-of-sync with the saved value.
+		 */
+		asm("mrs %0, tpidr_el0" : "=r" (*task_user_tls(p)));
+
+		if (stack_start) {
+			if (is_compat_thread(task_thread_info(p)))
 				childregs->compat_sp = stack_start;
-		} else {
-			/*
-			 * Read the current TLS pointer from tpidr_el0 as it may be
-			 * out-of-sync with the saved value.
-			 */
-			asm("mrs %0, tpidr_el0" : "=r" (tls));
-			if (stack_start) {
-				/* 16-byte aligned stack mandatory on AArch64 */
-				if (stack_start & 15)
-					return -EINVAL;
+			/* 16-byte aligned stack mandatory on AArch64 */
+			else if (stack_start & 15)
+				return -EINVAL;
+			else
 				childregs->sp = stack_start;
-			}
 		}
+
 		/*
 		 * If a TLS pointer was passed to clone (4th argument), use it
 		 * for the new thread.
 		 */
 		if (clone_flags & CLONE_SETTLS)
-			tls = childregs->regs[3];
+			p->thread.tp_value = childregs->regs[3];
 	} else {
 		memset(childregs, 0, sizeof(struct pt_regs));
 		childregs->pstate = PSR_MODE_EL1h;
@@ -281,7 +281,6 @@
 	}
 	p->thread.cpu_context.pc = (unsigned long)ret_from_fork;
 	p->thread.cpu_context.sp = (unsigned long)childregs;
-	p->thread.tp_value = tls;
 
 	ptrace_hw_copy_thread(p);
 
@@ -292,18 +291,12 @@
 {
 	unsigned long tpidr, tpidrro;
 
-	if (!is_compat_task()) {
-		asm("mrs %0, tpidr_el0" : "=r" (tpidr));
-		current->thread.tp_value = tpidr;
-	}
+	asm("mrs %0, tpidr_el0" : "=r" (tpidr));
+	*task_user_tls(current) = tpidr;
 
-	if (is_compat_thread(task_thread_info(next))) {
-		tpidr = 0;
-		tpidrro = next->thread.tp_value;
-	} else {
-		tpidr = next->thread.tp_value;
-		tpidrro = 0;
-	}
+	tpidr = *task_user_tls(next);
+	tpidrro = is_compat_thread(task_thread_info(next)) ?
+		  next->thread.tp_value : 0;
 
 	asm(
 	"	msr	tpidr_el0, %0\n"