netns xfrm: per-netns xfrm_state_all list

This is done in order to:
a) get a simple "something leaked" check;
b) cover possible DoSes when another netns puts many, many xfrm_states
   onto the list;
c) not miss an "alien xfrm_state" check in some list iterator in the future.

Signed-off-by: Alexey Dobriyan <adobriyan@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/xfrm/xfrm_state.c b/net/xfrm/xfrm_state.c
index 81bde76..85bb8548 100644
--- a/net/xfrm/xfrm_state.c
+++ b/net/xfrm/xfrm_state.c
@@ -50,7 +50,6 @@
  * Main use is finding SA after policy selected tunnel or transport mode.
  * Also, it can be used by ah/esp icmp error handler to find offending SA.
  */
-static LIST_HEAD(xfrm_state_all);
 static struct hlist_head *xfrm_state_bydst __read_mostly;
 static struct hlist_head *xfrm_state_bysrc __read_mostly;
 static struct hlist_head *xfrm_state_byspi __read_mostly;
@@ -855,7 +854,7 @@
 
 		if (km_query(x, tmpl, pol) == 0) {
 			x->km.state = XFRM_STATE_ACQ;
-			list_add(&x->km.all, &xfrm_state_all);
+			list_add(&x->km.all, &init_net.xfrm.state_all);
 			hlist_add_head(&x->bydst, xfrm_state_bydst+h);
 			h = xfrm_src_hash(daddr, saddr, family);
 			hlist_add_head(&x->bysrc, xfrm_state_bysrc+h);
@@ -924,7 +923,7 @@
 
 	x->genid = ++xfrm_state_genid;
 
-	list_add(&x->km.all, &xfrm_state_all);
+	list_add(&x->km.all, &init_net.xfrm.state_all);
 
 	h = xfrm_dst_hash(&x->id.daddr, &x->props.saddr,
 			  x->props.reqid, x->props.family);
@@ -1053,7 +1052,7 @@
 		xfrm_state_hold(x);
 		x->timer.expires = jiffies + sysctl_xfrm_acq_expires*HZ;
 		add_timer(&x->timer);
-		list_add(&x->km.all, &xfrm_state_all);
+		list_add(&x->km.all, &init_net.xfrm.state_all);
 		hlist_add_head(&x->bydst, xfrm_state_bydst+h);
 		h = xfrm_src_hash(daddr, saddr, family);
 		hlist_add_head(&x->bysrc, xfrm_state_bysrc+h);
@@ -1559,10 +1558,10 @@
 
 	spin_lock_bh(&xfrm_state_lock);
 	if (list_empty(&walk->all))
-		x = list_first_entry(&xfrm_state_all, struct xfrm_state_walk, all);
+		x = list_first_entry(&init_net.xfrm.state_all, struct xfrm_state_walk, all);
 	else
 		x = list_entry(&walk->all, struct xfrm_state_walk, all);
-	list_for_each_entry_from(x, &xfrm_state_all, all) {
+	list_for_each_entry_from(x, &init_net.xfrm.state_all, all) {
 		if (x->state == XFRM_STATE_DEAD)
 			continue;
 		state = container_of(x, struct xfrm_state, km);
@@ -2085,6 +2084,8 @@
 {
 	unsigned int sz;
 
+	INIT_LIST_HEAD(&net->xfrm.state_all);
+
 	sz = sizeof(struct hlist_head) * 8;
 
 	xfrm_state_bydst = xfrm_hash_alloc(sz);
@@ -2100,6 +2101,7 @@
 
 void xfrm_state_fini(struct net *net)
 {
+	WARN_ON(!list_empty(&net->xfrm.state_all));
 }
 
 #ifdef CONFIG_AUDITSYSCALL