xfrm: policy: store inexact policies in a tree ordered by destination address

This adds inexact lists per destination network, stored in a search tree.

Inexact lookups now return two 'candidate lists', the 'any' policies
('any' destionations), and a list of policies that share same
daddr/prefix.

Next patch will add a second search tree for 'saddr:any' policies
so we can avoid placing those on the 'any:any' list too.

Signed-off-by: Florian Westphal <fw@strlen.de>
Acked-by: David S. Miller <davem@davemloft.net>
Signed-off-by: Steffen Klassert <steffen.klassert@secunet.com>
diff --git a/include/net/xfrm.h b/include/net/xfrm.h
index 9df6dca..fa4b3c87 100644
--- a/include/net/xfrm.h
+++ b/include/net/xfrm.h
@@ -590,6 +590,7 @@ struct xfrm_policy {
 	struct xfrm_lifetime_cur curlft;
 	struct xfrm_policy_walk_entry walk;
 	struct xfrm_policy_queue polq;
+	bool                    bydst_reinsert;
 	u8			type;
 	u8			action;
 	u8			flags;
diff --git a/net/xfrm/xfrm_policy.c b/net/xfrm/xfrm_policy.c
index 4eb12e9..81447d5 100644
--- a/net/xfrm/xfrm_policy.c
+++ b/net/xfrm/xfrm_policy.c
@@ -49,6 +49,38 @@ struct xfrm_flo {
 /* prefixes smaller than this are stored in lists, not trees. */
 #define INEXACT_PREFIXLEN_IPV4	16
 #define INEXACT_PREFIXLEN_IPV6	48
+
+struct xfrm_pol_inexact_node {
+	struct rb_node node;
+	union {
+		xfrm_address_t addr;
+		struct rcu_head rcu;
+	};
+	u8 prefixlen;
+
+	/* the policies matching this node, can be empty list */
+	struct hlist_head hhead;
+};
+
+/* xfrm inexact policy search tree:
+ * xfrm_pol_inexact_bin = hash(dir,type,family,if_id);
+ *  |
+ * +---- root_d: sorted by daddr:prefix
+ * |                 |
+ * |        xfrm_pol_inexact_node
+ * |                 |
+ * |                 +- coarse policies and all any:daddr policies
+ * |
+ * +---- coarse policies and all any:any policies
+ *
+ * Lookups return two candidate lists:
+ * 1. any:any list from top-level xfrm_pol_inexact_bin
+ * 2. any:daddr list from daddr tree
+ *
+ * This result set then needs to be searched for the policy with
+ * the lowest priority.  If two results have same prio, youngest one wins.
+ */
+
 struct xfrm_pol_inexact_key {
 	possible_net_t net;
 	u32 if_id;
@@ -62,12 +94,17 @@ struct xfrm_pol_inexact_bin {
 	/* list containing '*:*' policies */
 	struct hlist_head hhead;
 
+	seqcount_t count;
+	/* tree sorted by daddr/prefix */
+	struct rb_root root_d;
+
 	/* slow path below */
 	struct list_head inexact_bins;
 	struct rcu_head rcu;
 };
 
 enum xfrm_pol_inexact_candidate_type {
+	XFRM_POL_CAND_DADDR,
 	XFRM_POL_CAND_ANY,
 
 	XFRM_POL_CAND_MAX,
@@ -658,6 +695,8 @@ xfrm_policy_inexact_alloc_bin(const struct xfrm_policy *pol, u8 dir)
 
 	bin->k = k;
 	INIT_HLIST_HEAD(&bin->hhead);
+	bin->root_d = RB_ROOT;
+	seqcount_init(&bin->count);
 
 	prev = rhashtable_lookup_get_insert_key(&xfrm_policy_inexact_table,
 						&bin->k, &bin->head,
@@ -708,9 +747,211 @@ xfrm_policy_inexact_insert_use_any_list(const struct xfrm_policy *policy)
 	return saddr_any && daddr_any;
 }
 
+static void xfrm_pol_inexact_node_init(struct xfrm_pol_inexact_node *node,
+				       const xfrm_address_t *addr, u8 prefixlen)
+{
+	node->addr = *addr;
+	node->prefixlen = prefixlen;
+}
+
+static struct xfrm_pol_inexact_node *
+xfrm_pol_inexact_node_alloc(const xfrm_address_t *addr, u8 prefixlen)
+{
+	struct xfrm_pol_inexact_node *node;
+
+	node = kzalloc(sizeof(*node), GFP_ATOMIC);
+	if (node)
+		xfrm_pol_inexact_node_init(node, addr, prefixlen);
+
+	return node;
+}
+
+static int xfrm_policy_addr_delta(const xfrm_address_t *a,
+				  const xfrm_address_t *b,
+				  u8 prefixlen, u16 family)
+{
+	unsigned int pdw, pbi;
+	int delta = 0;
+
+	switch (family) {
+	case AF_INET:
+		if (sizeof(long) == 4 && prefixlen == 0)
+			return ntohl(a->a4) - ntohl(b->a4);
+		return (ntohl(a->a4) & ((~0UL << (32 - prefixlen)))) -
+		       (ntohl(b->a4) & ((~0UL << (32 - prefixlen))));
+	case AF_INET6:
+		pdw = prefixlen >> 5;
+		pbi = prefixlen & 0x1f;
+
+		if (pdw) {
+			delta = memcmp(a->a6, b->a6, pdw << 2);
+			if (delta)
+				return delta;
+		}
+		if (pbi) {
+			u32 mask = ~0u << (32 - pbi);
+
+			delta = (ntohl(a->a6[pdw]) & mask) -
+				(ntohl(b->a6[pdw]) & mask);
+		}
+		break;
+	default:
+		break;
+	}
+
+	return delta;
+}
+
+static void xfrm_policy_inexact_list_reinsert(struct net *net,
+					      struct xfrm_pol_inexact_node *n,
+					      u16 family)
+{
+	struct hlist_node *newpos = NULL;
+	struct xfrm_policy *policy, *p;
+
+	list_for_each_entry_reverse(policy, &net->xfrm.policy_all, walk.all) {
+		if (!policy->bydst_reinsert)
+			continue;
+
+		WARN_ON_ONCE(policy->family != family);
+
+		policy->bydst_reinsert = false;
+		hlist_for_each_entry(p, &n->hhead, bydst) {
+			if (policy->priority >= p->priority)
+				newpos = &p->bydst;
+			else
+				break;
+		}
+
+		if (newpos)
+			hlist_add_behind(&policy->bydst, newpos);
+		else
+			hlist_add_head(&policy->bydst, &n->hhead);
+	}
+}
+
+/* merge nodes v and n */
+static void xfrm_policy_inexact_node_merge(struct net *net,
+					   struct xfrm_pol_inexact_node *v,
+					   struct xfrm_pol_inexact_node *n,
+					   u16 family)
+{
+	struct xfrm_policy *tmp;
+
+	hlist_for_each_entry(tmp, &v->hhead, bydst)
+		tmp->bydst_reinsert = true;
+	hlist_for_each_entry(tmp, &n->hhead, bydst)
+		tmp->bydst_reinsert = true;
+
+	INIT_HLIST_HEAD(&n->hhead);
+	xfrm_policy_inexact_list_reinsert(net, n, family);
+}
+
+static struct xfrm_pol_inexact_node *
+xfrm_policy_inexact_insert_node(struct net *net,
+				struct rb_root *root,
+				xfrm_address_t *addr,
+				u16 family, u8 prefixlen, u8 dir)
+{
+	struct xfrm_pol_inexact_node *cached = NULL;
+	struct rb_node **p, *parent = NULL;
+	struct xfrm_pol_inexact_node *node;
+
+	p = &root->rb_node;
+	while (*p) {
+		int delta;
+
+		parent = *p;
+		node = rb_entry(*p, struct xfrm_pol_inexact_node, node);
+
+		delta = xfrm_policy_addr_delta(addr, &node->addr,
+					       node->prefixlen,
+					       family);
+		if (delta == 0 && prefixlen >= node->prefixlen) {
+			WARN_ON_ONCE(cached); /* ipsec policies got lost */
+			return node;
+		}
+
+		if (delta < 0)
+			p = &parent->rb_left;
+		else
+			p = &parent->rb_right;
+
+		if (prefixlen < node->prefixlen) {
+			delta = xfrm_policy_addr_delta(addr, &node->addr,
+						       prefixlen,
+						       family);
+			if (delta)
+				continue;
+
+			/* This node is a subnet of the new prefix. It needs
+			 * to be removed and re-inserted with the smaller
+			 * prefix and all nodes that are now also covered
+			 * by the reduced prefixlen.
+			 */
+			rb_erase(&node->node, root);
+
+			if (!cached) {
+				xfrm_pol_inexact_node_init(node, addr,
+							   prefixlen);
+				cached = node;
+			} else {
+				/* This node also falls within the new
+				 * prefixlen. Merge the to-be-reinserted
+				 * node and this one.
+				 */
+				xfrm_policy_inexact_node_merge(net, node,
+							       cached, family);
+				kfree_rcu(node, rcu);
+			}
+
+			/* restart */
+			p = &root->rb_node;
+			parent = NULL;
+		}
+	}
+
+	node = cached;
+	if (!node) {
+		node = xfrm_pol_inexact_node_alloc(addr, prefixlen);
+		if (!node)
+			return NULL;
+	}
+
+	rb_link_node_rcu(&node->node, parent, p);
+	rb_insert_color(&node->node, root);
+
+	return node;
+}
+
+static void xfrm_policy_inexact_gc_tree(struct rb_root *r, bool rm)
+{
+	struct xfrm_pol_inexact_node *node;
+	struct rb_node *rn = rb_first(r);
+
+	while (rn) {
+		node = rb_entry(rn, struct xfrm_pol_inexact_node, node);
+
+		rn = rb_next(rn);
+
+		if (!hlist_empty(&node->hhead)) {
+			WARN_ON_ONCE(rm);
+			continue;
+		}
+
+		rb_erase(&node->node, r);
+		kfree_rcu(node, rcu);
+	}
+}
+
 static void __xfrm_policy_inexact_prune_bin(struct xfrm_pol_inexact_bin *b, bool net_exit)
 {
-	if (!hlist_empty(&b->hhead)) {
+	write_seqcount_begin(&b->count);
+	xfrm_policy_inexact_gc_tree(&b->root_d, net_exit);
+	write_seqcount_end(&b->count);
+
+	if (!RB_EMPTY_ROOT(&b->root_d) ||
+	    !hlist_empty(&b->hhead)) {
 		WARN_ON_ONCE(net_exit);
 		return;
 	}
@@ -741,6 +982,37 @@ static void __xfrm_policy_inexact_flush(struct net *net)
 		__xfrm_policy_inexact_prune_bin(bin, false);
 }
 
+static struct hlist_head *
+xfrm_policy_inexact_alloc_chain(struct xfrm_pol_inexact_bin *bin,
+				struct xfrm_policy *policy, u8 dir)
+{
+	struct xfrm_pol_inexact_node *n;
+	struct net *net;
+
+	net = xp_net(policy);
+	lockdep_assert_held(&net->xfrm.xfrm_policy_lock);
+
+	if (xfrm_policy_inexact_insert_use_any_list(policy))
+		return &bin->hhead;
+
+	if (xfrm_pol_inexact_addr_use_any_list(&policy->selector.daddr,
+					       policy->family,
+					       policy->selector.prefixlen_d))
+		return &bin->hhead;
+
+	/* daddr is fixed */
+	write_seqcount_begin(&bin->count);
+	n = xfrm_policy_inexact_insert_node(net,
+					    &bin->root_d,
+					    &policy->selector.daddr,
+					    policy->family,
+					    policy->selector.prefixlen_d, dir);
+	write_seqcount_end(&bin->count);
+	if (!n)
+		return NULL;
+	return &n->hhead;
+}
+
 static struct xfrm_policy *
 xfrm_policy_inexact_insert(struct xfrm_policy *policy, u8 dir, int excl)
 {
@@ -756,13 +1028,12 @@ xfrm_policy_inexact_insert(struct xfrm_policy *policy, u8 dir, int excl)
 	net = xp_net(policy);
 	lockdep_assert_held(&net->xfrm.xfrm_policy_lock);
 
-	if (xfrm_policy_inexact_insert_use_any_list(policy)) {
-		chain = &bin->hhead;
-		goto insert_to_list;
+	chain = xfrm_policy_inexact_alloc_chain(bin, policy, dir);
+	if (!chain) {
+		__xfrm_policy_inexact_prune_bin(bin, false);
+		return ERR_PTR(-ENOMEM);
 	}
 
-	chain = &bin->hhead;
-insert_to_list:
 	delpol = xfrm_policy_insert_list(chain, policy, excl);
 	if (delpol && excl) {
 		__xfrm_policy_inexact_prune_bin(bin, false);
@@ -843,6 +1114,9 @@ static void xfrm_hash_rebuild(struct work_struct *work)
 		bin = xfrm_policy_inexact_alloc_bin(policy, dir);
 		if (!bin)
 			goto out_unlock;
+
+		if (!xfrm_policy_inexact_alloc_chain(bin, policy, dir))
+			goto out_unlock;
 	}
 
 	/* reset the bydst and inexact table in all directions */
@@ -1462,17 +1736,64 @@ static int xfrm_policy_match(const struct xfrm_policy *pol,
 	return ret;
 }
 
+static struct xfrm_pol_inexact_node *
+xfrm_policy_lookup_inexact_addr(const struct rb_root *r,
+				seqcount_t *count,
+				const xfrm_address_t *addr, u16 family)
+{
+	const struct rb_node *parent;
+	int seq;
+
+again:
+	seq = read_seqcount_begin(count);
+
+	parent = rcu_dereference_raw(r->rb_node);
+	while (parent) {
+		struct xfrm_pol_inexact_node *node;
+		int delta;
+
+		node = rb_entry(parent, struct xfrm_pol_inexact_node, node);
+
+		delta = xfrm_policy_addr_delta(addr, &node->addr,
+					       node->prefixlen, family);
+		if (delta < 0) {
+			parent = rcu_dereference_raw(parent->rb_left);
+			continue;
+		} else if (delta > 0) {
+			parent = rcu_dereference_raw(parent->rb_right);
+			continue;
+		}
+
+		return node;
+	}
+
+	if (read_seqcount_retry(count, seq))
+		goto again;
+
+	return NULL;
+}
+
 static bool
 xfrm_policy_find_inexact_candidates(struct xfrm_pol_inexact_candidates *cand,
 				    struct xfrm_pol_inexact_bin *b,
 				    const xfrm_address_t *saddr,
 				    const xfrm_address_t *daddr)
 {
+	struct xfrm_pol_inexact_node *n;
+	u16 family;
+
 	if (!b)
 		return false;
 
+	family = b->k.family;
 	memset(cand, 0, sizeof(*cand));
 	cand->res[XFRM_POL_CAND_ANY] = &b->hhead;
+
+	n = xfrm_policy_lookup_inexact_addr(&b->root_d, &b->count, daddr,
+					    family);
+	if (n)
+		cand->res[XFRM_POL_CAND_DADDR] = &n->hhead;
+
 	return true;
 }