bpf: Introduce function-by-function verification

New llvm, and old llvm with libbpf's help, produce BTF that distinguishes global
and static functions. Unlike the arguments of static functions, the arguments of
global functions cannot be removed or optimized away by llvm. The compiler has to
use exactly the arguments specified in a function prototype. The argument type
information allows the verifier to validate each global function independently.
For now, the only supported argument types are pointer to context and scalars.
In the future, pointers to structures, sizes, and pointers to packet data can be
supported as well. Consider the following example:

static int f1(int ...)
{
  ...
}

int f3(int b);

int f2(int a)
{
  f1(a) + f3(a);
}

int f3(int b)
{
  ...
}

int main(...)
{
  f1(...) + f2(...) + f3(...);
}

The verifier will start its safety checks from the first global function f2().
It will recursively descend into f1() because it's static. Then it will check
that the arguments match for the f3() invocation inside f2(). It will not descend
into f3(). It will finish f2(), which has to be successfully verified for all
possible values of 'a'. Then it will proceed with f3(). That function also has
to be safe for all possible values of 'b'. Then it will start subprog 0 (which
is the main() function). It will recursively descend into f1() and will skip the
full check of f2() and f3(), since they are global. The order of processing
global functions doesn't affect safety, since all global functions must be
proven safe based on their arguments only.

Such function-by-function verification can drastically improve the speed of
verification and reduce its complexity.

Note that the 512-byte stack limit still applies to the call chain regardless of
whether functions are static or global. The call nesting limit of 8 frames also
still applies. The same recursion prevention checks are in place as well.

The type information and static/global kind are preserved after verification,
hence in the above example the global functions f2() and f3() can later be
replaced by equivalent functions with the same types that are loaded and
verified separately, without affecting the safety of this main() program.
Such replacement (re-linking) of global functions is the subject of future
patches.

Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
Acked-by: Song Liu <songliubraving@fb.com>
Link: https://lore.kernel.org/bpf/20200110064124.1760511-3-ast@kernel.org
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index f5af759..ca17dccc 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -1122,10 +1122,6 @@ static void init_reg_state(struct bpf_verifier_env *env,
 	regs[BPF_REG_FP].type = PTR_TO_STACK;
 	mark_reg_known_zero(env, regs, BPF_REG_FP);
 	regs[BPF_REG_FP].frameno = state->frameno;
-
-	/* 1st arg to a function */
-	regs[BPF_REG_1].type = PTR_TO_CTX;
-	mark_reg_known_zero(env, regs, BPF_REG_1);
 }
 
 #define BPF_MAIN_FUNC (-1)
@@ -2739,8 +2735,8 @@ static int get_callee_stack_depth(struct bpf_verifier_env *env,
 }
 #endif
 
-static int check_ctx_reg(struct bpf_verifier_env *env,
-			 const struct bpf_reg_state *reg, int regno)
+int check_ctx_reg(struct bpf_verifier_env *env,
+		  const struct bpf_reg_state *reg, int regno)
 {
 	/* Access to ctx or passing it to a helper is only allowed in
 	 * its original, unmodified form.
@@ -3956,12 +3952,26 @@ static int release_reference(struct bpf_verifier_env *env,
 	return 0;
 }
 
+static void clear_caller_saved_regs(struct bpf_verifier_env *env,
+				    struct bpf_reg_state *regs)
+{
+	int i;
+
+	/* after the call registers r0 - r5 were scratched */
+	for (i = 0; i < CALLER_SAVED_REGS; i++) {
+		mark_reg_not_init(env, regs, caller_saved[i]);
+		check_reg_arg(env, caller_saved[i], DST_OP_NO_MARK);
+	}
+}
+
 static int check_func_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
 			   int *insn_idx)
 {
 	struct bpf_verifier_state *state = env->cur_state;
+	struct bpf_func_info_aux *func_info_aux;
 	struct bpf_func_state *caller, *callee;
 	int i, err, subprog, target_insn;
+	bool is_global = false;
 
 	if (state->curframe + 1 >= MAX_CALL_FRAMES) {
 		verbose(env, "the call stack of %d frames is too deep\n",
@@ -3984,6 +3994,32 @@ static int check_func_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
 		return -EFAULT;
 	}
 
+	func_info_aux = env->prog->aux->func_info_aux;
+	if (func_info_aux)
+		is_global = func_info_aux[subprog].linkage == BTF_FUNC_GLOBAL;
+	err = btf_check_func_arg_match(env, subprog, caller->regs);
+	if (err == -EFAULT)
+		return err;
+	if (is_global) {
+		if (err) {
+			verbose(env, "Caller passes invalid args into func#%d\n",
+				subprog);
+			return err;
+		} else {
+			if (env->log.level & BPF_LOG_LEVEL)
+				verbose(env,
+					"Func#%d is global and valid. Skipping.\n",
+					subprog);
+			clear_caller_saved_regs(env, caller->regs);
+
+			/* All global functions return a SCALAR_VALUE in R0 */
+			mark_reg_unknown(env, caller->regs, BPF_REG_0);
+
+			/* continue with next insn after call */
+			return 0;
+		}
+	}
+
 	callee = kzalloc(sizeof(*callee), GFP_KERNEL);
 	if (!callee)
 		return -ENOMEM;
@@ -4010,18 +4046,11 @@ static int check_func_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
 	for (i = BPF_REG_1; i <= BPF_REG_5; i++)
 		callee->regs[i] = caller->regs[i];
 
-	/* after the call registers r0 - r5 were scratched */
-	for (i = 0; i < CALLER_SAVED_REGS; i++) {
-		mark_reg_not_init(env, caller->regs, caller_saved[i]);
-		check_reg_arg(env, caller_saved[i], DST_OP_NO_MARK);
-	}
+	clear_caller_saved_regs(env, caller->regs);
 
 	/* only increment it after check_reg_arg() finished */
 	state->curframe++;
 
-	if (btf_check_func_arg_match(env, subprog))
-		return -EINVAL;
-
 	/* and go analyze first insn of the callee */
 	*insn_idx = target_insn;
 
@@ -6771,12 +6800,13 @@ static int check_btf_func(struct bpf_verifier_env *env,
 
 		/* check type_id */
 		type = btf_type_by_id(btf, krecord[i].type_id);
-		if (!type || BTF_INFO_KIND(type->info) != BTF_KIND_FUNC) {
+		if (!type || !btf_type_is_func(type)) {
 			verbose(env, "invalid type id %d in func info",
 				krecord[i].type_id);
 			ret = -EINVAL;
 			goto err_free;
 		}
+		info_aux[i].linkage = BTF_INFO_VLEN(type->info);
 		prev_offset = krecord[i].insn_off;
 		urecord += urec_size;
 	}
@@ -7756,35 +7786,13 @@ static bool reg_type_mismatch(enum bpf_reg_type src, enum bpf_reg_type prev)
 
 static int do_check(struct bpf_verifier_env *env)
 {
-	struct bpf_verifier_state *state;
+	struct bpf_verifier_state *state = env->cur_state;
 	struct bpf_insn *insns = env->prog->insnsi;
 	struct bpf_reg_state *regs;
 	int insn_cnt = env->prog->len;
 	bool do_print_state = false;
 	int prev_insn_idx = -1;
 
-	env->prev_linfo = NULL;
-
-	state = kzalloc(sizeof(struct bpf_verifier_state), GFP_KERNEL);
-	if (!state)
-		return -ENOMEM;
-	state->curframe = 0;
-	state->speculative = false;
-	state->branches = 1;
-	state->frame[0] = kzalloc(sizeof(struct bpf_func_state), GFP_KERNEL);
-	if (!state->frame[0]) {
-		kfree(state);
-		return -ENOMEM;
-	}
-	env->cur_state = state;
-	init_func_state(env, state->frame[0],
-			BPF_MAIN_FUNC /* callsite */,
-			0 /* frameno */,
-			0 /* subprogno, zero == main subprog */);
-
-	if (btf_check_func_arg_match(env, 0))
-		return -EINVAL;
-
 	for (;;) {
 		struct bpf_insn *insn;
 		u8 class;
@@ -7862,7 +7870,7 @@ static int do_check(struct bpf_verifier_env *env)
 		}
 
 		regs = cur_regs(env);
-		env->insn_aux_data[env->insn_idx].seen = true;
+		env->insn_aux_data[env->insn_idx].seen = env->pass_cnt;
 		prev_insn_idx = env->insn_idx;
 
 		if (class == BPF_ALU || class == BPF_ALU64) {
@@ -8082,7 +8090,7 @@ static int do_check(struct bpf_verifier_env *env)
 					return err;
 
 				env->insn_idx++;
-				env->insn_aux_data[env->insn_idx].seen = true;
+				env->insn_aux_data[env->insn_idx].seen = env->pass_cnt;
 			} else {
 				verbose(env, "invalid BPF_LD mode\n");
 				return -EINVAL;
@@ -8095,7 +8103,6 @@ static int do_check(struct bpf_verifier_env *env)
 		env->insn_idx++;
 	}
 
-	env->prog->aux->stack_depth = env->subprog_info[0].stack_depth;
 	return 0;
 }
 
@@ -8372,7 +8379,7 @@ static int adjust_insn_aux_data(struct bpf_verifier_env *env,
 	memcpy(new_data + off + cnt - 1, old_data + off,
 	       sizeof(struct bpf_insn_aux_data) * (prog_len - off - cnt + 1));
 	for (i = off; i < off + cnt - 1; i++) {
-		new_data[i].seen = true;
+		new_data[i].seen = env->pass_cnt;
 		new_data[i].zext_dst = insn_has_def32(env, insn + i);
 	}
 	env->insn_aux_data = new_data;
@@ -9484,6 +9491,7 @@ static void free_states(struct bpf_verifier_env *env)
 		kfree(sl);
 		sl = sln;
 	}
+	env->free_list = NULL;
 
 	if (!env->explored_states)
 		return;
@@ -9497,11 +9505,159 @@ static void free_states(struct bpf_verifier_env *env)
 			kfree(sl);
 			sl = sln;
 		}
+		env->explored_states[i] = NULL;
+	}
+}
+
+/* The verifier uses insn_aux_data[] to store temporary data during
+ * verification and to store information for passes that run after
+ * verification, like dead code sanitization. do_check_common() for subprogram
+ * N may analyze many other subprograms. When do_check_common() finds that
+ * subprogram N cannot be verified independently, sanitize_insn_aux_data()
+ * clears the temporary data of the LDX/STX insns touched by that pass.
+ * pass_cnt counts the number of times do_check_common() was run, and
+ * insn_aux_data[i].seen records the pass that last touched entry i; the two
+ * are compared so only data from the failed pass is cleared. For testing,
+ * do_check_common() can be rerun even after a prior attempt failed.
+ */
+static void sanitize_insn_aux_data(struct bpf_verifier_env *env)
+{
+	struct bpf_insn *insn = env->prog->insnsi;
+	struct bpf_insn_aux_data *aux;
+	int i, class;
+
+	for (i = 0; i < env->prog->len; i++) {
+		class = BPF_CLASS(insn[i].code);
+		if (class != BPF_LDX && class != BPF_STX)
+			continue;
+		aux = &env->insn_aux_data[i];
+		if (aux->seen != env->pass_cnt)
+			continue;
+		memset(aux, 0, offsetof(typeof(*aux), orig_idx));
+	}
+}
+
+static int do_check_common(struct bpf_verifier_env *env, int subprog)
+{
+	struct bpf_verifier_state *state;
+	struct bpf_reg_state *regs;
+	int ret, i;
+
+	env->prev_linfo = NULL;
+	env->pass_cnt++;
+
+	state = kzalloc(sizeof(struct bpf_verifier_state), GFP_KERNEL);
+	if (!state)
+		return -ENOMEM;
+	state->curframe = 0;
+	state->speculative = false;
+	state->branches = 1;
+	state->frame[0] = kzalloc(sizeof(struct bpf_func_state), GFP_KERNEL);
+	if (!state->frame[0]) {
+		kfree(state);
+		return -ENOMEM;
+	}
+	env->cur_state = state;
+	init_func_state(env, state->frame[0],
+			BPF_MAIN_FUNC /* callsite */,
+			0 /* frameno */,
+			subprog);
+
+	regs = state->frame[state->curframe]->regs;
+	if (subprog) {
+		ret = btf_prepare_func_args(env, subprog, regs);
+		if (ret)
+			goto out;
+		for (i = BPF_REG_1; i <= BPF_REG_5; i++) {
+			if (regs[i].type == PTR_TO_CTX)
+				mark_reg_known_zero(env, regs, i);
+			else if (regs[i].type == SCALAR_VALUE)
+				mark_reg_unknown(env, regs, i);
+		}
+	} else {
+		/* main() gets the program context as its 1st arg (R1) */
+		regs[BPF_REG_1].type = PTR_TO_CTX;
+		mark_reg_known_zero(env, regs, BPF_REG_1);
+		ret = btf_check_func_arg_match(env, subprog, regs);
+		if (ret == -EFAULT)
+			/* unlikely verifier bug. abort.
+			 * ret == 0 and other errors are sadly acceptable for
+			 * the main() function due to backward compatibility.
+			 * E.g. a socket filter program may be written as:
+			 * int bpf_prog(struct pt_regs *ctx)
+			 * and never dereference that ctx in the program.
+			 * 'struct pt_regs' is a type mismatch for a socket
+			 * filter, which should be using 'struct __sk_buff'.
+			 */
+			goto out;
 	}
 
-	kvfree(env->explored_states);
+	ret = do_check(env);
+out:
+	free_verifier_state(env->cur_state, true);
+	env->cur_state = NULL;
+	while (!pop_stack(env, NULL, NULL));
+	free_states(env);
+	if (ret)
+		/* clean aux data in case subprog was rejected */
+		sanitize_insn_aux_data(env);
+	return ret;
 }
 
+/* Verify all global functions in a BPF program one by one based on their BTF.
+ * All global functions must pass verification, otherwise the whole program is
+ * rejected. Consider:
+ * int bar(int);
+ * int foo(int f)
+ * {
+ *    return bar(f);
+ * }
+ * int bar(int b)
+ * {
+ *    ...
+ * }
+ * foo() will be verified first for R1=any_scalar_value. During verification it
+ * will be assumed that bar() has already been verified, and the call to bar()
+ * from foo() will be checked for type match only. Later bar() will be verified
+ * independently to check that it's safe for R1=any_scalar_value.
+ */
+static int do_check_subprogs(struct bpf_verifier_env *env)
+{
+	struct bpf_prog_aux *aux = env->prog->aux;
+	int i, ret;
+
+	if (!aux->func_info)
+		return 0;
+
+	for (i = 1; i < env->subprog_cnt; i++) {
+		if (aux->func_info_aux[i].linkage != BTF_FUNC_GLOBAL)
+			continue;
+		env->insn_idx = env->subprog_info[i].start;
+		WARN_ON_ONCE(env->insn_idx == 0);
+		ret = do_check_common(env, i);
+		if (ret) {
+			return ret;
+		} else if (env->log.level & BPF_LOG_LEVEL) {
+			verbose(env,
+				"Func#%d is safe for any args that match its prototype\n",
+				i);
+		}
+	}
+	return 0;
+}
+
+static int do_check_main(struct bpf_verifier_env *env)
+{
+	int ret;
+
+	env->insn_idx = 0;
+	ret = do_check_common(env, 0);
+	if (!ret)
+		env->prog->aux->stack_depth = env->subprog_info[0].stack_depth;
+	return ret;
+}
+
+
 static void print_verification_stats(struct bpf_verifier_env *env)
 {
 	int i;
@@ -9849,18 +10005,14 @@ int bpf_check(struct bpf_prog **prog, union bpf_attr *attr,
 	if (ret < 0)
 		goto skip_full_check;
 
-	ret = do_check(env);
-	if (env->cur_state) {
-		free_verifier_state(env->cur_state, true);
-		env->cur_state = NULL;
-	}
+	ret = do_check_subprogs(env);
+	ret = ret ?: do_check_main(env);
 
 	if (ret == 0 && bpf_prog_is_dev_bound(env->prog->aux))
 		ret = bpf_prog_offload_finalize(env);
 
 skip_full_check:
-	while (!pop_stack(env, NULL, NULL));
-	free_states(env);
+	kvfree(env->explored_states);
 
 	if (ret == 0)
 		ret = check_max_stack_depth(env);