bpf: Track alignment of register values in the verifier.

Currently if we add only constant values to pointers we can fully
validate the alignment, and properly check if we need to reject the
program on !CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS architectures.

However, once an unknown value is introduced we only allow byte sized
memory accesses which is too restrictive.

Add logic to track the known minimum alignment of register values,
and propagate this state into registers containing pointers.

The most common paradigm that makes use of this new logic is computing
the transport header using the IP header length field.  For example:

	struct ethhdr *ep = skb->data;
	struct iphdr *iph = (struct iphdr *) (ep + 1);
	struct tcphdr *th;
 ...
	n = iph->ihl;
	th = ((void *)iph + (n * 4));
	port = th->dest;

The existing code will reject the load of th->dest because it cannot
validate that the alignment is at least 2 once "n * 4" is added the
the packet pointer.

In the new code, the register holding "n * 4" will have a reg->min_align
value of 4, because any value multiplied by 4 will be at least 4 byte
aligned.  (actually, the eBPF code emitted by the compiler in this case
is most likely to use a shift left by 2, but the end result is identical)

At the critical addition:

	th = ((void *)iph + (n * 4));

The register holding 'th' will start with reg->off value of 14.  The
pointer addition will transform that reg into something that looks like:

	reg->aux_off = 14
	reg->aux_off_align = 4

Next, the verifier will look at the th->dest load, and it will see
a load offset of 2, and first check:

	if (reg->aux_off_align % size)

which will pass because aux_off_align is 4.  reg_off will be computed:

	reg_off = reg->off;
 ...
		reg_off += reg->aux_off;

plus we have off==2, and it will thus check:

	if ((NET_IP_ALIGN + reg_off + off) % size != 0)

which evaluates to:

	if ((NET_IP_ALIGN + 14 + 2) % size != 0)

On strict alignment architectures, NET_IP_ALIGN is 2, thus:

	if ((2 + 14 + 2) % size != 0)

which passes.

These pointer transformations and checks work regardless of whether
the constant offset or the variable with known alignment is added
first to the pointer register.

Signed-off-by: David S. Miller <davem@davemloft.net>
Acked-by: Daniel Borkmann <daniel@iogearbox.net>
diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
index 5efb4db..7c6a519 100644
--- a/include/linux/bpf_verifier.h
+++ b/include/linux/bpf_verifier.h
@@ -40,6 +40,9 @@ struct bpf_reg_state {
 	 */
 	s64 min_value;
 	u64 max_value;
+	u32 min_align;
+	u32 aux_off;
+	u32 aux_off_align;
 };
 
 enum bpf_stack_slot_type {
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index c5b56c9..cc7b626 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -241,6 +241,12 @@ static void print_verifier_state(struct bpf_verifier_state *state)
 		if (reg->max_value != BPF_REGISTER_MAX_RANGE)
 			verbose(",max_value=%llu",
 				(unsigned long long)reg->max_value);
+		if (reg->min_align)
+			verbose(",min_align=%u", reg->min_align);
+		if (reg->aux_off)
+			verbose(",aux_off=%u", reg->aux_off);
+		if (reg->aux_off_align)
+			verbose(",aux_off_align=%u", reg->aux_off_align);
 	}
 	for (i = 0; i < MAX_BPF_STACK; i += BPF_REG_SIZE) {
 		if (state->stack_slot_type[i] == STACK_SPILL)
@@ -466,6 +472,9 @@ static void init_reg_state(struct bpf_reg_state *regs)
 		regs[i].imm = 0;
 		regs[i].min_value = BPF_REGISTER_MIN_RANGE;
 		regs[i].max_value = BPF_REGISTER_MAX_RANGE;
+		regs[i].min_align = 0;
+		regs[i].aux_off = 0;
+		regs[i].aux_off_align = 0;
 	}
 
 	/* frame pointer */
@@ -492,6 +501,7 @@ static void reset_reg_range_values(struct bpf_reg_state *regs, u32 regno)
 {
 	regs[regno].min_value = BPF_REGISTER_MIN_RANGE;
 	regs[regno].max_value = BPF_REGISTER_MAX_RANGE;
+	regs[regno].min_align = 0;
 }
 
 static void mark_reg_unknown_value_and_range(struct bpf_reg_state *regs,
@@ -779,17 +789,28 @@ static bool is_pointer_value(struct bpf_verifier_env *env, int regno)
 }
 
 static int check_pkt_ptr_alignment(const struct bpf_reg_state *reg,
-				   int off, int size)
+				   int off, int size, bool strict)
 {
-	if (reg->id && size != 1) {
-		verbose("Unknown alignment. Only byte-sized access allowed in packet access.\n");
-		return -EACCES;
+	int reg_off;
+
+	/* Byte size accesses are always allowed. */
+	if (!strict || size == 1)
+		return 0;
+
+	reg_off = reg->off;
+	if (reg->id) {
+		if (reg->aux_off_align % size) {
+			verbose("Packet access is only %u byte aligned, %d byte access not allowed\n",
+				reg->aux_off_align, size);
+			return -EACCES;
+		}
+		reg_off += reg->aux_off;
 	}
 
 	/* skb->data is NET_IP_ALIGN-ed */
-	if ((NET_IP_ALIGN + reg->off + off) % size != 0) {
+	if ((NET_IP_ALIGN + reg_off + off) % size != 0) {
 		verbose("misaligned packet access off %d+%d+%d size %d\n",
-			NET_IP_ALIGN, reg->off, off, size);
+			NET_IP_ALIGN, reg_off, off, size);
 		return -EACCES;
 	}
 
@@ -797,9 +818,9 @@ static int check_pkt_ptr_alignment(const struct bpf_reg_state *reg,
 }
 
 static int check_val_ptr_alignment(const struct bpf_reg_state *reg,
-				   int size)
+				   int size, bool strict)
 {
-	if (size != 1) {
+	if (strict && size != 1) {
 		verbose("Unknown alignment. Only byte-sized access allowed in value access.\n");
 		return -EACCES;
 	}
@@ -810,13 +831,16 @@ static int check_val_ptr_alignment(const struct bpf_reg_state *reg,
 static int check_ptr_alignment(const struct bpf_reg_state *reg,
 			       int off, int size)
 {
+	bool strict = false;
+
+	if (!IS_ENABLED(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS))
+		strict = true;
+
 	switch (reg->type) {
 	case PTR_TO_PACKET:
-		return IS_ENABLED(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS) ? 0 :
-		       check_pkt_ptr_alignment(reg, off, size);
+		return check_pkt_ptr_alignment(reg, off, size, strict);
 	case PTR_TO_MAP_VALUE_ADJ:
-		return IS_ENABLED(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS) ? 0 :
-		       check_val_ptr_alignment(reg, size);
+		return check_val_ptr_alignment(reg, size, strict);
 	default:
 		if (off % size != 0) {
 			verbose("misaligned access off %d size %d\n",
@@ -883,6 +907,8 @@ static int check_mem_access(struct bpf_verifier_env *env, u32 regno, int off,
 							 value_regno);
 			/* note that reg.[id|off|range] == 0 */
 			state->regs[value_regno].type = reg_type;
+			state->regs[value_regno].aux_off = 0;
+			state->regs[value_regno].aux_off_align = 0;
 		}
 
 	} else if (reg->type == FRAME_PTR || reg->type == PTR_TO_STACK) {
@@ -1455,6 +1481,8 @@ static int check_packet_ptr_add(struct bpf_verifier_env *env,
 		 */
 		dst_reg->off += imm;
 	} else {
+		bool had_id;
+
 		if (src_reg->type == PTR_TO_PACKET) {
 			/* R6=pkt(id=0,off=0,r=62) R7=imm22; r7 += r6 */
 			tmp_reg = *dst_reg;  /* save r7 state */
@@ -1488,14 +1516,23 @@ static int check_packet_ptr_add(struct bpf_verifier_env *env,
 				src_reg->imm);
 			return -EACCES;
 		}
+
+		had_id = (dst_reg->id != 0);
+
 		/* dst_reg stays as pkt_ptr type and since some positive
 		 * integer value was added to the pointer, increment its 'id'
 		 */
 		dst_reg->id = ++env->id_gen;
 
-		/* something was added to pkt_ptr, set range and off to zero */
+		/* something was added to pkt_ptr, set range to zero */
+		dst_reg->aux_off = dst_reg->off;
 		dst_reg->off = 0;
 		dst_reg->range = 0;
+		if (had_id)
+			dst_reg->aux_off_align = min(dst_reg->aux_off_align,
+						     src_reg->min_align);
+		else
+			dst_reg->aux_off_align = src_reg->min_align;
 	}
 	return 0;
 }
@@ -1669,6 +1706,13 @@ static void check_reg_overflow(struct bpf_reg_state *reg)
 		reg->min_value = BPF_REGISTER_MIN_RANGE;
 }
 
+static u32 calc_align(u32 imm)
+{
+	if (!imm)
+		return 1U << 31;
+	return imm - ((imm - 1) & imm);
+}
+
 static void adjust_reg_min_max_vals(struct bpf_verifier_env *env,
 				    struct bpf_insn *insn)
 {
@@ -1676,8 +1720,10 @@ static void adjust_reg_min_max_vals(struct bpf_verifier_env *env,
 	s64 min_val = BPF_REGISTER_MIN_RANGE;
 	u64 max_val = BPF_REGISTER_MAX_RANGE;
 	u8 opcode = BPF_OP(insn->code);
+	u32 dst_align, src_align;
 
 	dst_reg = &regs[insn->dst_reg];
+	src_align = 0;
 	if (BPF_SRC(insn->code) == BPF_X) {
 		check_reg_overflow(&regs[insn->src_reg]);
 		min_val = regs[insn->src_reg].min_value;
@@ -1693,12 +1739,18 @@ static void adjust_reg_min_max_vals(struct bpf_verifier_env *env,
 		    regs[insn->src_reg].type != UNKNOWN_VALUE) {
 			min_val = BPF_REGISTER_MIN_RANGE;
 			max_val = BPF_REGISTER_MAX_RANGE;
+			src_align = 0;
+		} else {
+			src_align = regs[insn->src_reg].min_align;
 		}
 	} else if (insn->imm < BPF_REGISTER_MAX_RANGE &&
 		   (s64)insn->imm > BPF_REGISTER_MIN_RANGE) {
 		min_val = max_val = insn->imm;
+		src_align = calc_align(insn->imm);
 	}
 
+	dst_align = dst_reg->min_align;
+
 	/* We don't know anything about what was done to this register, mark it
 	 * as unknown.
 	 */
@@ -1723,18 +1775,21 @@ static void adjust_reg_min_max_vals(struct bpf_verifier_env *env,
 			dst_reg->min_value += min_val;
 		if (dst_reg->max_value != BPF_REGISTER_MAX_RANGE)
 			dst_reg->max_value += max_val;
+		dst_reg->min_align = min(src_align, dst_align);
 		break;
 	case BPF_SUB:
 		if (dst_reg->min_value != BPF_REGISTER_MIN_RANGE)
 			dst_reg->min_value -= min_val;
 		if (dst_reg->max_value != BPF_REGISTER_MAX_RANGE)
 			dst_reg->max_value -= max_val;
+		dst_reg->min_align = min(src_align, dst_align);
 		break;
 	case BPF_MUL:
 		if (dst_reg->min_value != BPF_REGISTER_MIN_RANGE)
 			dst_reg->min_value *= min_val;
 		if (dst_reg->max_value != BPF_REGISTER_MAX_RANGE)
 			dst_reg->max_value *= max_val;
+		dst_reg->min_align = max(src_align, dst_align);
 		break;
 	case BPF_AND:
 		/* Disallow AND'ing of negative numbers, ain't nobody got time
@@ -1746,17 +1801,23 @@ static void adjust_reg_min_max_vals(struct bpf_verifier_env *env,
 		else
 			dst_reg->min_value = 0;
 		dst_reg->max_value = max_val;
+		dst_reg->min_align = max(src_align, dst_align);
 		break;
 	case BPF_LSH:
 		/* Gotta have special overflow logic here, if we're shifting
 		 * more than MAX_RANGE then just assume we have an invalid
 		 * range.
 		 */
-		if (min_val > ilog2(BPF_REGISTER_MAX_RANGE))
+		if (min_val > ilog2(BPF_REGISTER_MAX_RANGE)) {
 			dst_reg->min_value = BPF_REGISTER_MIN_RANGE;
-		else if (dst_reg->min_value != BPF_REGISTER_MIN_RANGE)
-			dst_reg->min_value <<= min_val;
-
+			dst_reg->min_align = 1;
+		} else {
+			if (dst_reg->min_value != BPF_REGISTER_MIN_RANGE)
+				dst_reg->min_value <<= min_val;
+			if (!dst_reg->min_align)
+				dst_reg->min_align = 1;
+			dst_reg->min_align <<= min_val;
+		}
 		if (max_val > ilog2(BPF_REGISTER_MAX_RANGE))
 			dst_reg->max_value = BPF_REGISTER_MAX_RANGE;
 		else if (dst_reg->max_value != BPF_REGISTER_MAX_RANGE)
@@ -1766,11 +1827,19 @@ static void adjust_reg_min_max_vals(struct bpf_verifier_env *env,
 		/* RSH by a negative number is undefined, and the BPF_RSH is an
 		 * unsigned shift, so make the appropriate casts.
 		 */
-		if (min_val < 0 || dst_reg->min_value < 0)
+		if (min_val < 0 || dst_reg->min_value < 0) {
 			dst_reg->min_value = BPF_REGISTER_MIN_RANGE;
-		else
+		} else {
 			dst_reg->min_value =
 				(u64)(dst_reg->min_value) >> min_val;
+		}
+		if (min_val < 0) {
+			dst_reg->min_align = 1;
+		} else {
+			dst_reg->min_align >>= (u64) min_val;
+			if (!dst_reg->min_align)
+				dst_reg->min_align = 1;
+		}
 		if (dst_reg->max_value != BPF_REGISTER_MAX_RANGE)
 			dst_reg->max_value >>= max_val;
 		break;
@@ -1872,6 +1941,7 @@ static int check_alu_op(struct bpf_verifier_env *env, struct bpf_insn *insn)
 			regs[insn->dst_reg].imm = insn->imm;
 			regs[insn->dst_reg].max_value = insn->imm;
 			regs[insn->dst_reg].min_value = insn->imm;
+			regs[insn->dst_reg].min_align = calc_align(insn->imm);
 		}
 
 	} else if (opcode > BPF_END) {