bpf: Add reference tracking to verifier

Allow helper functions to acquire a reference and return it into a
register. Specific pointer types such as the PTR_TO_SOCKET will
implicitly represent such a reference. The verifier must ensure that
these references are released exactly once in each path through the
program.

To achieve this, this commit assigns an id to the pointer and tracks it
in the 'bpf_func_state', then when the function or program exits,
verifies that all of the acquired references have been freed. When the
pointer is passed to a function that frees the reference, it is removed
from the 'bpf_func_state` and all existing copies of the pointer in
registers are marked invalid.

Signed-off-by: Joe Stringer <joe@wand.net.nz>
Acked-by: Alexei Starovoitov <ast@kernel.org>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
index a411363..7b6fd2a 100644
--- a/include/linux/bpf_verifier.h
+++ b/include/linux/bpf_verifier.h
@@ -104,6 +104,17 @@ struct bpf_stack_state {
 	u8 slot_type[BPF_REG_SIZE];
 };
 
+struct bpf_reference_state {
+	/* Track each reference created with a unique id, even if the same
+	 * instruction creates the reference multiple times (eg, via CALL).
+	 */
+	int id;
+	/* Instruction where the allocation of this reference occurred. This
+	 * is used purely to inform the user of a reference leak.
+	 */
+	int insn_idx;
+};
+
 /* state of the program:
  * type of all registers and stack info
  */
@@ -121,7 +132,9 @@ struct bpf_func_state {
 	 */
 	u32 subprogno;
 
-	/* should be second to last. See copy_func_state() */
+	/* The following fields should be last. See copy_func_state() */
+	int acquired_refs;
+	struct bpf_reference_state *refs;
 	int allocated_stack;
 	struct bpf_stack_state *stack;
 };
@@ -217,11 +230,16 @@ __printf(2, 0) void bpf_verifier_vlog(struct bpf_verifier_log *log,
 __printf(2, 3) void bpf_verifier_log_write(struct bpf_verifier_env *env,
 					   const char *fmt, ...);
 
-static inline struct bpf_reg_state *cur_regs(struct bpf_verifier_env *env)
+static inline struct bpf_func_state *cur_func(struct bpf_verifier_env *env)
 {
 	struct bpf_verifier_state *cur = env->cur_state;
 
-	return cur->frame[cur->curframe]->regs;
+	return cur->frame[cur->curframe];
+}
+
+static inline struct bpf_reg_state *cur_regs(struct bpf_verifier_env *env)
+{
+	return cur_func(env)->regs;
 }
 
 int bpf_prog_offload_verifier_prep(struct bpf_verifier_env *env);
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 11e9823..cd0d8bc 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -1,5 +1,6 @@
 /* Copyright (c) 2011-2014 PLUMgrid, http://plumgrid.com
  * Copyright (c) 2016 Facebook
+ * Copyright (c) 2018 Covalent IO, Inc. http://covalent.io
  *
  * This program is free software; you can redistribute it and/or
  * modify it under the terms of version 2 of the GNU General Public
@@ -140,6 +141,18 @@ static const struct bpf_verifier_ops * const bpf_verifier_ops[] = {
  *
  * After the call R0 is set to return type of the function and registers R1-R5
  * are set to NOT_INIT to indicate that they are no longer readable.
+ *
+ * The following reference types represent a potential reference to a kernel
+ * resource which, after first being allocated, must be checked and freed by
+ * the BPF program:
+ * - PTR_TO_SOCKET_OR_NULL, PTR_TO_SOCKET
+ *
+ * When the verifier sees a helper call return a reference type, it allocates a
+ * pointer id for the reference and stores it in the current function state.
+ * Similar to the way that PTR_TO_MAP_VALUE_OR_NULL is converted into
+ * PTR_TO_MAP_VALUE, PTR_TO_SOCKET_OR_NULL becomes PTR_TO_SOCKET when the type
+ * passes through a NULL-check conditional. For the branch wherein the state is
+ * changed to CONST_IMM, the verifier releases the reference.
  */
 
 /* verifier_state + insn_idx are pushed to stack when branch is encountered */
@@ -189,6 +202,7 @@ struct bpf_call_arg_meta {
 	int access_size;
 	s64 msize_smax_value;
 	u64 msize_umax_value;
+	int ptr_id;
 };
 
 static DEFINE_MUTEX(bpf_verifier_lock);
@@ -251,7 +265,42 @@ static bool type_is_pkt_pointer(enum bpf_reg_type type)
 
 static bool reg_type_may_be_null(enum bpf_reg_type type)
 {
-	return type == PTR_TO_MAP_VALUE_OR_NULL;
+	return type == PTR_TO_MAP_VALUE_OR_NULL ||
+	       type == PTR_TO_SOCKET_OR_NULL;
+}
+
+static bool type_is_refcounted(enum bpf_reg_type type)
+{
+	return type == PTR_TO_SOCKET;
+}
+
+static bool type_is_refcounted_or_null(enum bpf_reg_type type)
+{
+	return type == PTR_TO_SOCKET || type == PTR_TO_SOCKET_OR_NULL;
+}
+
+static bool reg_is_refcounted(const struct bpf_reg_state *reg)
+{
+	return type_is_refcounted(reg->type);
+}
+
+static bool reg_is_refcounted_or_null(const struct bpf_reg_state *reg)
+{
+	return type_is_refcounted_or_null(reg->type);
+}
+
+static bool arg_type_is_refcounted(enum bpf_arg_type type)
+{
+	return type == ARG_PTR_TO_SOCKET;
+}
+
+/* Determine whether the function releases some resources allocated by another
+ * function call. The first reference type argument will be assumed to be
+ * released by release_reference().
+ */
+static bool is_release_function(enum bpf_func_id func_id)
+{
+	return false;
 }
 
 /* string representation of 'enum bpf_reg_type' */
@@ -385,6 +434,12 @@ static void print_verifier_state(struct bpf_verifier_env *env,
 		else
 			verbose(env, "=%s", types_buf);
 	}
+	if (state->acquired_refs && state->refs[0].id) {
+		verbose(env, " refs=%d", state->refs[0].id);
+		for (i = 1; i < state->acquired_refs; i++)
+			if (state->refs[i].id)
+				verbose(env, ",%d", state->refs[i].id);
+	}
 	verbose(env, "\n");
 }
 
@@ -403,6 +458,8 @@ static int copy_##NAME##_state(struct bpf_func_state *dst,		\
 	       sizeof(*src->FIELD) * (src->COUNT / SIZE));		\
 	return 0;							\
 }
+/* copy_reference_state() */
+COPY_STATE_FN(reference, acquired_refs, refs, 1)
 /* copy_stack_state() */
 COPY_STATE_FN(stack, allocated_stack, stack, BPF_REG_SIZE)
 #undef COPY_STATE_FN
@@ -441,6 +498,8 @@ static int realloc_##NAME##_state(struct bpf_func_state *state, int size, \
 	state->FIELD = new_##FIELD;					\
 	return 0;							\
 }
+/* realloc_reference_state() */
+REALLOC_STATE_FN(reference, acquired_refs, refs, 1)
 /* realloc_stack_state() */
 REALLOC_STATE_FN(stack, allocated_stack, stack, BPF_REG_SIZE)
 #undef REALLOC_STATE_FN
@@ -452,16 +511,89 @@ REALLOC_STATE_FN(stack, allocated_stack, stack, BPF_REG_SIZE)
  * which realloc_stack_state() copies over. It points to previous
  * bpf_verifier_state which is never reallocated.
  */
-static int realloc_func_state(struct bpf_func_state *state, int size,
-			      bool copy_old)
+static int realloc_func_state(struct bpf_func_state *state, int stack_size,
+			      int refs_size, bool copy_old)
 {
-	return realloc_stack_state(state, size, copy_old);
+	int err = realloc_reference_state(state, refs_size, copy_old);
+	if (err)
+		return err;
+	return realloc_stack_state(state, stack_size, copy_old);
+}
+
+/* Acquire a pointer id from the env and update the state->refs to include
+ * this new pointer reference.
+ * On success, returns a valid pointer id to associate with the register
+ * On failure, returns a negative errno.
+ */
+static int acquire_reference_state(struct bpf_verifier_env *env, int insn_idx)
+{
+	struct bpf_func_state *state = cur_func(env);
+	int new_ofs = state->acquired_refs;
+	int id, err;
+
+	err = realloc_reference_state(state, state->acquired_refs + 1, true);
+	if (err)
+		return err;
+	id = ++env->id_gen;
+	state->refs[new_ofs].id = id;
+	state->refs[new_ofs].insn_idx = insn_idx;
+
+	return id;
+}
+
+/* release function corresponding to acquire_reference_state(). Idempotent. */
+static int __release_reference_state(struct bpf_func_state *state, int ptr_id)
+{
+	int i, last_idx;
+
+	if (!ptr_id)
+		return -EFAULT;
+
+	last_idx = state->acquired_refs - 1;
+	for (i = 0; i < state->acquired_refs; i++) {
+		if (state->refs[i].id == ptr_id) {
+			if (last_idx && i != last_idx)
+				memcpy(&state->refs[i], &state->refs[last_idx],
+				       sizeof(*state->refs));
+			memset(&state->refs[last_idx], 0, sizeof(*state->refs));
+			state->acquired_refs--;
+			return 0;
+		}
+	}
+	return -EFAULT;
+}
+
+/* variation on the above for cases where we expect that there must be an
+ * outstanding reference for the specified ptr_id.
+ */
+static int release_reference_state(struct bpf_verifier_env *env, int ptr_id)
+{
+	struct bpf_func_state *state = cur_func(env);
+	int err;
+
+	err = __release_reference_state(state, ptr_id);
+	if (WARN_ON_ONCE(err != 0))
+		verbose(env, "verifier internal error: can't release reference\n");
+	return err;
+}
+
+static int transfer_reference_state(struct bpf_func_state *dst,
+				    struct bpf_func_state *src)
+{
+	int err = realloc_reference_state(dst, src->acquired_refs, false);
+	if (err)
+		return err;
+	err = copy_reference_state(dst, src);
+	if (err)
+		return err;
+	return 0;
 }
 
 static void free_func_state(struct bpf_func_state *state)
 {
 	if (!state)
 		return;
+	kfree(state->refs);
 	kfree(state->stack);
 	kfree(state);
 }
@@ -487,10 +619,14 @@ static int copy_func_state(struct bpf_func_state *dst,
 {
 	int err;
 
-	err = realloc_func_state(dst, src->allocated_stack, false);
+	err = realloc_func_state(dst, src->allocated_stack, src->acquired_refs,
+				 false);
 	if (err)
 		return err;
-	memcpy(dst, src, offsetof(struct bpf_func_state, allocated_stack));
+	memcpy(dst, src, offsetof(struct bpf_func_state, acquired_refs));
+	err = copy_reference_state(dst, src);
+	if (err)
+		return err;
 	return copy_stack_state(dst, src);
 }
 
@@ -1015,7 +1151,7 @@ static int check_stack_write(struct bpf_verifier_env *env,
 	enum bpf_reg_type type;
 
 	err = realloc_func_state(state, round_up(slot + 1, BPF_REG_SIZE),
-				 true);
+				 state->acquired_refs, true);
 	if (err)
 		return err;
 	/* caller checked that off % size == 0 and -MAX_BPF_STACK <= off < 0,
@@ -1399,7 +1535,8 @@ static bool is_ctx_reg(struct bpf_verifier_env *env, int regno)
 {
 	const struct bpf_reg_state *reg = cur_regs(env) + regno;
 
-	return reg->type == PTR_TO_CTX;
+	return reg->type == PTR_TO_CTX ||
+	       reg->type == PTR_TO_SOCKET;
 }
 
 static bool is_pkt_reg(struct bpf_verifier_env *env, int regno)
@@ -2003,6 +2140,12 @@ static int check_func_arg(struct bpf_verifier_env *env, u32 regno,
 		expected_type = PTR_TO_SOCKET;
 		if (type != expected_type)
 			goto err_type;
+		if (meta->ptr_id || !reg->id) {
+			verbose(env, "verifier internal error: mismatched references meta=%d, reg=%d\n",
+				meta->ptr_id, reg->id);
+			return -EFAULT;
+		}
+		meta->ptr_id = reg->id;
 	} else if (arg_type_is_mem_ptr(arg_type)) {
 		expected_type = PTR_TO_STACK;
 		/* One exception here. In case function allows for NULL to be
@@ -2292,10 +2435,32 @@ static bool check_arg_pair_ok(const struct bpf_func_proto *fn)
 	return true;
 }
 
+static bool check_refcount_ok(const struct bpf_func_proto *fn)
+{
+	int count = 0;
+
+	if (arg_type_is_refcounted(fn->arg1_type))
+		count++;
+	if (arg_type_is_refcounted(fn->arg2_type))
+		count++;
+	if (arg_type_is_refcounted(fn->arg3_type))
+		count++;
+	if (arg_type_is_refcounted(fn->arg4_type))
+		count++;
+	if (arg_type_is_refcounted(fn->arg5_type))
+		count++;
+
+	/* We only support one arg being unreferenced at the moment,
+	 * which is sufficient for the helper functions we have right now.
+	 */
+	return count <= 1;
+}
+
 static int check_func_proto(const struct bpf_func_proto *fn)
 {
 	return check_raw_mode_ok(fn) &&
-	       check_arg_pair_ok(fn) ? 0 : -EINVAL;
+	       check_arg_pair_ok(fn) &&
+	       check_refcount_ok(fn) ? 0 : -EINVAL;
 }
 
 /* Packet data might have moved, any old PTR_TO_PACKET[_META,_END]
@@ -2328,12 +2493,45 @@ static void clear_all_pkt_pointers(struct bpf_verifier_env *env)
 		__clear_all_pkt_pointers(env, vstate->frame[i]);
 }
 
+static void release_reg_references(struct bpf_verifier_env *env,
+				   struct bpf_func_state *state, int id)
+{
+	struct bpf_reg_state *regs = state->regs, *reg;
+	int i;
+
+	for (i = 0; i < MAX_BPF_REG; i++)
+		if (regs[i].id == id)
+			mark_reg_unknown(env, regs, i);
+
+	bpf_for_each_spilled_reg(i, state, reg) {
+		if (!reg)
+			continue;
+		if (reg_is_refcounted(reg) && reg->id == id)
+			__mark_reg_unknown(reg);
+	}
+}
+
+/* The pointer with the specified id has released its reference to kernel
+ * resources. Identify all copies of the same pointer and clear the reference.
+ */
+static int release_reference(struct bpf_verifier_env *env,
+			     struct bpf_call_arg_meta *meta)
+{
+	struct bpf_verifier_state *vstate = env->cur_state;
+	int i;
+
+	for (i = 0; i <= vstate->curframe; i++)
+		release_reg_references(env, vstate->frame[i], meta->ptr_id);
+
+	return release_reference_state(env, meta->ptr_id);
+}
+
 static int check_func_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
 			   int *insn_idx)
 {
 	struct bpf_verifier_state *state = env->cur_state;
 	struct bpf_func_state *caller, *callee;
-	int i, subprog, target_insn;
+	int i, err, subprog, target_insn;
 
 	if (state->curframe + 1 >= MAX_CALL_FRAMES) {
 		verbose(env, "the call stack of %d frames is too deep\n",
@@ -2371,6 +2569,11 @@ static int check_func_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
 			state->curframe + 1 /* frameno within this callchain */,
 			subprog /* subprog number within this prog */);
 
+	/* Transfer references to the callee */
+	err = transfer_reference_state(callee, caller);
+	if (err)
+		return err;
+
 	/* copy r1 - r5 args that callee can access.  The copy includes parent
 	 * pointers, which connects us up to the liveness chain
 	 */
@@ -2403,6 +2606,7 @@ static int prepare_func_exit(struct bpf_verifier_env *env, int *insn_idx)
 	struct bpf_verifier_state *state = env->cur_state;
 	struct bpf_func_state *caller, *callee;
 	struct bpf_reg_state *r0;
+	int err;
 
 	callee = state->frame[state->curframe];
 	r0 = &callee->regs[BPF_REG_0];
@@ -2422,6 +2626,11 @@ static int prepare_func_exit(struct bpf_verifier_env *env, int *insn_idx)
 	/* return to the caller whatever r0 had in the callee */
 	caller->regs[BPF_REG_0] = *r0;
 
+	/* Transfer references to the caller */
+	err = transfer_reference_state(caller, callee);
+	if (err)
+		return err;
+
 	*insn_idx = callee->callsite + 1;
 	if (env->log.level) {
 		verbose(env, "returning from callee:\n");
@@ -2478,6 +2687,18 @@ record_func_map(struct bpf_verifier_env *env, struct bpf_call_arg_meta *meta,
 	return 0;
 }
 
+static int check_reference_leak(struct bpf_verifier_env *env)
+{
+	struct bpf_func_state *state = cur_func(env);
+	int i;
+
+	for (i = 0; i < state->acquired_refs; i++) {
+		verbose(env, "Unreleased reference id=%d alloc_insn=%d\n",
+			state->refs[i].id, state->refs[i].insn_idx);
+	}
+	return state->acquired_refs ? -EINVAL : 0;
+}
+
 static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn_idx)
 {
 	const struct bpf_func_proto *fn = NULL;
@@ -2556,6 +2777,18 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
 			return err;
 	}
 
+	if (func_id == BPF_FUNC_tail_call) {
+		err = check_reference_leak(env);
+		if (err) {
+			verbose(env, "tail_call would lead to reference leak\n");
+			return err;
+		}
+	} else if (is_release_function(func_id)) {
+		err = release_reference(env, &meta);
+		if (err)
+			return err;
+	}
+
 	regs = cur_regs(env);
 
 	/* check that flags argument in get_local_storage(map, flags) is 0,
@@ -2599,9 +2832,12 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
 		regs[BPF_REG_0].map_ptr = meta.map_ptr;
 		regs[BPF_REG_0].id = ++env->id_gen;
 	} else if (fn->ret_type == RET_PTR_TO_SOCKET_OR_NULL) {
+		int id = acquire_reference_state(env, insn_idx);
+		if (id < 0)
+			return id;
 		mark_reg_known_zero(env, regs, BPF_REG_0);
 		regs[BPF_REG_0].type = PTR_TO_SOCKET_OR_NULL;
-		regs[BPF_REG_0].id = ++env->id_gen;
+		regs[BPF_REG_0].id = id;
 	} else {
 		verbose(env, "unknown return type %d of func %s#%d\n",
 			fn->ret_type, func_id_name(func_id), func_id);
@@ -3665,7 +3901,8 @@ static void reg_combine_min_max(struct bpf_reg_state *true_src,
 	}
 }
 
-static void mark_ptr_or_null_reg(struct bpf_reg_state *reg, u32 id,
+static void mark_ptr_or_null_reg(struct bpf_func_state *state,
+				 struct bpf_reg_state *reg, u32 id,
 				 bool is_null)
 {
 	if (reg_type_may_be_null(reg->type) && reg->id == id) {
@@ -3691,11 +3928,13 @@ static void mark_ptr_or_null_reg(struct bpf_reg_state *reg, u32 id,
 		} else if (reg->type == PTR_TO_SOCKET_OR_NULL) {
 			reg->type = PTR_TO_SOCKET;
 		}
-		/* We don't need id from this point onwards anymore, thus we
-		 * should better reset it, so that state pruning has chances
-		 * to take effect.
-		 */
-		reg->id = 0;
+		if (is_null || !reg_is_refcounted(reg)) {
+			/* We don't need id from this point onwards anymore,
+			 * thus we should better reset it, so that state
+			 * pruning has chances to take effect.
+			 */
+			reg->id = 0;
+		}
 	}
 }
 
@@ -3710,15 +3949,18 @@ static void mark_ptr_or_null_regs(struct bpf_verifier_state *vstate, u32 regno,
 	u32 id = regs[regno].id;
 	int i, j;
 
+	if (reg_is_refcounted_or_null(&regs[regno]) && is_null)
+		__release_reference_state(state, id);
+
 	for (i = 0; i < MAX_BPF_REG; i++)
-		mark_ptr_or_null_reg(&regs[i], id, is_null);
+		mark_ptr_or_null_reg(state, &regs[i], id, is_null);
 
 	for (j = 0; j <= vstate->curframe; j++) {
 		state = vstate->frame[j];
 		bpf_for_each_spilled_reg(i, state, reg) {
 			if (!reg)
 				continue;
-			mark_ptr_or_null_reg(reg, id, is_null);
+			mark_ptr_or_null_reg(state, reg, id, is_null);
 		}
 	}
 }
@@ -4050,6 +4292,16 @@ static int check_ld_abs(struct bpf_verifier_env *env, struct bpf_insn *insn)
 	if (err)
 		return err;
 
+	/* Disallow usage of BPF_LD_[ABS|IND] with reference tracking, as
+	 * gen_ld_abs() may terminate the program at runtime, leading to
+	 * reference leak.
+	 */
+	err = check_reference_leak(env);
+	if (err) {
+		verbose(env, "BPF_LD_[ABS|IND] cannot be mixed with socket references\n");
+		return err;
+	}
+
 	if (regs[BPF_REG_6].type != PTR_TO_CTX) {
 		verbose(env,
 			"at the time of BPF_LD_ABS|IND R6 != pointer to skb\n");
@@ -4542,6 +4794,14 @@ static bool stacksafe(struct bpf_func_state *old,
 	return true;
 }
 
+static bool refsafe(struct bpf_func_state *old, struct bpf_func_state *cur)
+{
+	if (old->acquired_refs != cur->acquired_refs)
+		return false;
+	return !memcmp(old->refs, cur->refs,
+		       sizeof(*old->refs) * old->acquired_refs);
+}
+
 /* compare two verifier states
  *
  * all states stored in state_list are known to be valid, since
@@ -4587,6 +4847,9 @@ static bool func_states_equal(struct bpf_func_state *old,
 
 	if (!stacksafe(old, cur, idmap))
 		goto out_free;
+
+	if (!refsafe(old, cur))
+		goto out_free;
 	ret = true;
 out_free:
 	kfree(idmap);
@@ -4868,6 +5131,7 @@ static int do_check(struct bpf_verifier_env *env)
 
 		regs = cur_regs(env);
 		env->insn_aux_data[insn_idx].seen = true;
+
 		if (class == BPF_ALU || class == BPF_ALU64) {
 			err = check_alu_op(env, insn);
 			if (err)
@@ -5032,6 +5296,10 @@ static int do_check(struct bpf_verifier_env *env)
 					continue;
 				}
 
+				err = check_reference_leak(env);
+				if (err)
+					return err;
+
 				/* eBPF calling convetion is such that R0 is used
 				 * to return the value from eBPF program.
 				 * Make sure that it's readable at this time