x86: bpf_jit: small optimization in emit_bpf_tail_call()

Saves 4 bytes replacing following instructions :

lea rax, [rsi + rdx * 8 + offsetof(...)]
mov rax, qword ptr [rax]
cmp rax, 0

by :

mov rax, [rsi + rdx * 8 + offsetof(...)]
test rax, rax

Signed-off-by: Eric Dumazet <edumazet@google.com>
Cc: Alexei Starovoitov <ast@kernel.org>
Cc: Daniel Borkmann <daniel@iogearbox.net>
Acked-by: Daniel Borkmann <daniel@iogearbox.net>
Acked-by: Alexei Starovoitov <ast@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index 8194696..8c95736 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -287,7 +287,7 @@ static void emit_bpf_tail_call(u8 **pprog)
 	EMIT4(0x48, 0x8B, 0x46,                   /* mov rax, qword ptr [rsi + 16] */
 	      offsetof(struct bpf_array, map.max_entries));
 	EMIT3(0x48, 0x39, 0xD0);                  /* cmp rax, rdx */
-#define OFFSET1 47 /* number of bytes to jump */
+#define OFFSET1 43 /* number of bytes to jump */
 	EMIT2(X86_JBE, OFFSET1);                  /* jbe out */
 	label1 = cnt;
 
@@ -296,21 +296,20 @@ static void emit_bpf_tail_call(u8 **pprog)
 	 */
 	EMIT2_off32(0x8B, 0x85, 36);              /* mov eax, dword ptr [rbp + 36] */
 	EMIT3(0x83, 0xF8, MAX_TAIL_CALL_CNT);     /* cmp eax, MAX_TAIL_CALL_CNT */
-#define OFFSET2 36
+#define OFFSET2 32
 	EMIT2(X86_JA, OFFSET2);                   /* ja out */
 	label2 = cnt;
 	EMIT3(0x83, 0xC0, 0x01);                  /* add eax, 1 */
 	EMIT2_off32(0x89, 0x85, 36);              /* mov dword ptr [rbp + 36], eax */
 
 	/* prog = array->ptrs[index]; */
-	EMIT4_off32(0x48, 0x8D, 0x84, 0xD6,       /* lea rax, [rsi + rdx * 8 + offsetof(...)] */
+	EMIT4_off32(0x48, 0x8B, 0x84, 0xD6,       /* mov rax, [rsi + rdx * 8 + offsetof(...)] */
 		    offsetof(struct bpf_array, ptrs));
-	EMIT3(0x48, 0x8B, 0x00);                  /* mov rax, qword ptr [rax] */
 
 	/* if (prog == NULL)
 	 *   goto out;
 	 */
-	EMIT4(0x48, 0x83, 0xF8, 0x00);            /* cmp rax, 0 */
+	EMIT3(0x48, 0x85, 0xC0);		  /* test rax,rax */
 #define OFFSET3 10
 	EMIT2(X86_JE, OFFSET3);                   /* je out */
 	label3 = cnt;