netfilter: Pass socket pointer down through okfn().

On the output paths in particular, we sometimes have to deal with two
socket contexts.  The first, usually skb->sk, is the local socket that
generated the frame.

The second is potentially the socket used to control a tunnel, such as
one that encapsulates using UDP.

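We do not want to fix this by disassociating skb->sk when
encapsulating, because that would break socket memory accounting.
Instead, pass the socket pointer explicitly down through the netfilter
okfn() callbacks (and through dst_output_sk() and the ip_fragment()
output callback), so each stage operates on the socket it was handed
rather than assuming skb->sk.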

The most extreme case where this causes serious problems is an
AF_PACKET socket transmitting over a vxlan device.  We hit code
paths whose checks assume they are dealing with an IPv4 socket,
but they are actually operating on the AF_PACKET one.

Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/ipv4/arp.c b/net/ipv4/arp.c
index c6e67aa..933a928 100644
--- a/net/ipv4/arp.c
+++ b/net/ipv4/arp.c
@@ -591,7 +591,8 @@
 void arp_xmit(struct sk_buff *skb)
 {
 	/* Send it off, maybe filter it using firewalling first.  */
-	NF_HOOK(NFPROTO_ARP, NF_ARP_OUT, skb, NULL, skb->dev, dev_queue_xmit);
+	NF_HOOK(NFPROTO_ARP, NF_ARP_OUT, NULL, skb,
+		NULL, skb->dev, dev_queue_xmit_sk);
 }
 EXPORT_SYMBOL(arp_xmit);
 
@@ -625,7 +626,7 @@
  *	Process an arp request.
  */
 
-static int arp_process(struct sk_buff *skb)
+static int arp_process(struct sock *sk, struct sk_buff *skb)
 {
 	struct net_device *dev = skb->dev;
 	struct in_device *in_dev = __in_dev_get_rcu(dev);
@@ -846,7 +847,7 @@
 
 static void parp_redo(struct sk_buff *skb)
 {
-	arp_process(skb);
+	arp_process(NULL, skb);
 }
 
 
@@ -879,7 +880,8 @@
 
 	memset(NEIGH_CB(skb), 0, sizeof(struct neighbour_cb));
 
-	return NF_HOOK(NFPROTO_ARP, NF_ARP_IN, skb, dev, NULL, arp_process);
+	return NF_HOOK(NFPROTO_ARP, NF_ARP_IN, NULL, skb,
+		       dev, NULL, arp_process);
 
 consumeskb:
 	consume_skb(skb);
diff --git a/net/ipv4/ip_forward.c b/net/ipv4/ip_forward.c
index d9bc28a..939992c 100644
--- a/net/ipv4/ip_forward.c
+++ b/net/ipv4/ip_forward.c
@@ -57,7 +57,7 @@
 }
 
 
-static int ip_forward_finish(struct sk_buff *skb)
+static int ip_forward_finish(struct sock *sk, struct sk_buff *skb)
 {
 	struct ip_options *opt	= &(IPCB(skb)->opt);
 
@@ -68,7 +68,7 @@
 		ip_forward_options(skb);
 
 	skb_sender_cpu_clear(skb);
-	return dst_output(skb);
+	return dst_output_sk(sk, skb);
 }
 
 int ip_forward(struct sk_buff *skb)
@@ -136,8 +136,8 @@
 
 	skb->priority = rt_tos2priority(iph->tos);
 
-	return NF_HOOK(NFPROTO_IPV4, NF_INET_FORWARD, skb, skb->dev,
-		       rt->dst.dev, ip_forward_finish);
+	return NF_HOOK(NFPROTO_IPV4, NF_INET_FORWARD, NULL, skb,
+		       skb->dev, rt->dst.dev, ip_forward_finish);
 
 sr_failed:
 	/*
diff --git a/net/ipv4/ip_input.c b/net/ipv4/ip_input.c
index 2e0410e..2db4c87 100644
--- a/net/ipv4/ip_input.c
+++ b/net/ipv4/ip_input.c
@@ -187,7 +187,7 @@
 	return false;
 }
 
-static int ip_local_deliver_finish(struct sk_buff *skb)
+static int ip_local_deliver_finish(struct sock *sk, struct sk_buff *skb)
 {
 	struct net *net = dev_net(skb->dev);
 
@@ -253,7 +253,8 @@
 			return 0;
 	}
 
-	return NF_HOOK(NFPROTO_IPV4, NF_INET_LOCAL_IN, skb, skb->dev, NULL,
+	return NF_HOOK(NFPROTO_IPV4, NF_INET_LOCAL_IN, NULL, skb,
+		       skb->dev, NULL,
 		       ip_local_deliver_finish);
 }
 
@@ -309,7 +310,7 @@
 int sysctl_ip_early_demux __read_mostly = 1;
 EXPORT_SYMBOL(sysctl_ip_early_demux);
 
-static int ip_rcv_finish(struct sk_buff *skb)
+static int ip_rcv_finish(struct sock *sk, struct sk_buff *skb)
 {
 	const struct iphdr *iph = ip_hdr(skb);
 	struct rtable *rt;
@@ -451,7 +452,8 @@
 	/* Must drop socket now because of tproxy. */
 	skb_orphan(skb);
 
-	return NF_HOOK(NFPROTO_IPV4, NF_INET_PRE_ROUTING, skb, dev, NULL,
+	return NF_HOOK(NFPROTO_IPV4, NF_INET_PRE_ROUTING, NULL, skb,
+		       dev, NULL,
 		       ip_rcv_finish);
 
 csum_error:
diff --git a/net/ipv4/ip_output.c b/net/ipv4/ip_output.c
index 26f6f79..5da4d15 100644
--- a/net/ipv4/ip_output.c
+++ b/net/ipv4/ip_output.c
@@ -91,14 +91,19 @@
 }
 EXPORT_SYMBOL(ip_send_check);
 
-int __ip_local_out(struct sk_buff *skb)
+int __ip_local_out_sk(struct sock *sk, struct sk_buff *skb)
 {
 	struct iphdr *iph = ip_hdr(skb);
 
 	iph->tot_len = htons(skb->len);
 	ip_send_check(iph);
-	return nf_hook(NFPROTO_IPV4, NF_INET_LOCAL_OUT, skb, NULL,
-		       skb_dst(skb)->dev, dst_output);
+	return nf_hook(NFPROTO_IPV4, NF_INET_LOCAL_OUT, sk, skb, NULL,
+		       skb_dst(skb)->dev, dst_output_sk);
+}
+
+int __ip_local_out(struct sk_buff *skb)
+{
+	return __ip_local_out_sk(skb->sk, skb);
 }
 
 int ip_local_out_sk(struct sock *sk, struct sk_buff *skb)
@@ -163,7 +168,7 @@
 }
 EXPORT_SYMBOL_GPL(ip_build_and_send_pkt);
 
-static inline int ip_finish_output2(struct sk_buff *skb)
+static inline int ip_finish_output2(struct sock *sk, struct sk_buff *skb)
 {
 	struct dst_entry *dst = skb_dst(skb);
 	struct rtable *rt = (struct rtable *)dst;
@@ -211,7 +216,7 @@
 	return -EINVAL;
 }
 
-static int ip_finish_output_gso(struct sk_buff *skb)
+static int ip_finish_output_gso(struct sock *sk, struct sk_buff *skb)
 {
 	netdev_features_t features;
 	struct sk_buff *segs;
@@ -220,7 +225,7 @@
 	/* common case: locally created skb or seglen is <= mtu */
 	if (((IPCB(skb)->flags & IPSKB_FORWARDED) == 0) ||
 	      skb_gso_network_seglen(skb) <= ip_skb_dst_mtu(skb))
-		return ip_finish_output2(skb);
+		return ip_finish_output2(sk, skb);
 
 	/* Slowpath -  GSO segment length is exceeding the dst MTU.
 	 *
@@ -243,7 +248,7 @@
 		int err;
 
 		segs->next = NULL;
-		err = ip_fragment(segs, ip_finish_output2);
+		err = ip_fragment(sk, segs, ip_finish_output2);
 
 		if (err && ret == 0)
 			ret = err;
@@ -253,22 +258,22 @@
 	return ret;
 }
 
-static int ip_finish_output(struct sk_buff *skb)
+static int ip_finish_output(struct sock *sk, struct sk_buff *skb)
 {
 #if defined(CONFIG_NETFILTER) && defined(CONFIG_XFRM)
 	/* Policy lookup after SNAT yielded a new policy */
 	if (skb_dst(skb)->xfrm) {
 		IPCB(skb)->flags |= IPSKB_REROUTED;
-		return dst_output(skb);
+		return dst_output_sk(sk, skb);
 	}
 #endif
 	if (skb_is_gso(skb))
-		return ip_finish_output_gso(skb);
+		return ip_finish_output_gso(sk, skb);
 
 	if (skb->len > ip_skb_dst_mtu(skb))
-		return ip_fragment(skb, ip_finish_output2);
+		return ip_fragment(sk, skb, ip_finish_output2);
 
-	return ip_finish_output2(skb);
+	return ip_finish_output2(sk, skb);
 }
 
 int ip_mc_output(struct sock *sk, struct sk_buff *skb)
@@ -307,7 +312,7 @@
 			struct sk_buff *newskb = skb_clone(skb, GFP_ATOMIC);
 			if (newskb)
 				NF_HOOK(NFPROTO_IPV4, NF_INET_POST_ROUTING,
-					newskb, NULL, newskb->dev,
+					sk, newskb, NULL, newskb->dev,
 					dev_loopback_xmit);
 		}
 
@@ -322,11 +327,11 @@
 	if (rt->rt_flags&RTCF_BROADCAST) {
 		struct sk_buff *newskb = skb_clone(skb, GFP_ATOMIC);
 		if (newskb)
-			NF_HOOK(NFPROTO_IPV4, NF_INET_POST_ROUTING, newskb,
+			NF_HOOK(NFPROTO_IPV4, NF_INET_POST_ROUTING, sk, newskb,
 				NULL, newskb->dev, dev_loopback_xmit);
 	}
 
-	return NF_HOOK_COND(NFPROTO_IPV4, NF_INET_POST_ROUTING, skb, NULL,
+	return NF_HOOK_COND(NFPROTO_IPV4, NF_INET_POST_ROUTING, sk, skb, NULL,
 			    skb->dev, ip_finish_output,
 			    !(IPCB(skb)->flags & IPSKB_REROUTED));
 }
@@ -340,7 +345,8 @@
 	skb->dev = dev;
 	skb->protocol = htons(ETH_P_IP);
 
-	return NF_HOOK_COND(NFPROTO_IPV4, NF_INET_POST_ROUTING, skb, NULL, dev,
+	return NF_HOOK_COND(NFPROTO_IPV4, NF_INET_POST_ROUTING, sk, skb,
+			    NULL, dev,
 			    ip_finish_output,
 			    !(IPCB(skb)->flags & IPSKB_REROUTED));
 }
@@ -480,7 +486,8 @@
  *	single device frame, and queue such a frame for sending.
  */
 
-int ip_fragment(struct sk_buff *skb, int (*output)(struct sk_buff *))
+int ip_fragment(struct sock *sk, struct sk_buff *skb,
+		int (*output)(struct sock *, struct sk_buff *))
 {
 	struct iphdr *iph;
 	int ptr;
@@ -593,7 +600,7 @@
 				ip_send_check(iph);
 			}
 
-			err = output(skb);
+			err = output(sk, skb);
 
 			if (!err)
 				IP_INC_STATS(dev_net(dev), IPSTATS_MIB_FRAGCREATES);
@@ -730,7 +737,7 @@
 
 		ip_send_check(iph);
 
-		err = output(skb2);
+		err = output(sk, skb2);
 		if (err)
 			goto fail;
 
diff --git a/net/ipv4/ipmr.c b/net/ipv4/ipmr.c
index 5f17d0e..3a2c016 100644
--- a/net/ipv4/ipmr.c
+++ b/net/ipv4/ipmr.c
@@ -1679,7 +1679,7 @@
 	nf_reset(skb);
 }
 
-static inline int ipmr_forward_finish(struct sk_buff *skb)
+static inline int ipmr_forward_finish(struct sock *sk, struct sk_buff *skb)
 {
 	struct ip_options *opt = &(IPCB(skb)->opt);
 
@@ -1689,7 +1689,7 @@
 	if (unlikely(opt->optlen))
 		ip_forward_options(skb);
 
-	return dst_output(skb);
+	return dst_output_sk(sk, skb);
 }
 
 /*
@@ -1788,7 +1788,8 @@
 	 * not mrouter) cannot join to more than one interface - it will
 	 * result in receiving multiple packets.
 	 */
-	NF_HOOK(NFPROTO_IPV4, NF_INET_FORWARD, skb, skb->dev, dev,
+	NF_HOOK(NFPROTO_IPV4, NF_INET_FORWARD, NULL, skb,
+		skb->dev, dev,
 		ipmr_forward_finish);
 	return;
 
diff --git a/net/ipv4/raw.c b/net/ipv4/raw.c
index 6d0fa8f..c0bb648 100644
--- a/net/ipv4/raw.c
+++ b/net/ipv4/raw.c
@@ -412,8 +412,8 @@
 		icmp_out_count(net, ((struct icmphdr *)
 			skb_transport_header(skb))->type);
 
-	err = NF_HOOK(NFPROTO_IPV4, NF_INET_LOCAL_OUT, skb, NULL,
-		      rt->dst.dev, dst_output);
+	err = NF_HOOK(NFPROTO_IPV4, NF_INET_LOCAL_OUT, sk, skb,
+		      NULL, rt->dst.dev, dst_output_sk);
 	if (err > 0)
 		err = net_xmit_errno(err);
 	if (err)
diff --git a/net/ipv4/xfrm4_input.c b/net/ipv4/xfrm4_input.c
index cac7468..60b032f 100644
--- a/net/ipv4/xfrm4_input.c
+++ b/net/ipv4/xfrm4_input.c
@@ -22,7 +22,7 @@
 	return xfrm4_extract_header(skb);
 }
 
-static inline int xfrm4_rcv_encap_finish(struct sk_buff *skb)
+static inline int xfrm4_rcv_encap_finish(struct sock *sk, struct sk_buff *skb)
 {
 	if (!skb_dst(skb)) {
 		const struct iphdr *iph = ip_hdr(skb);
@@ -52,7 +52,8 @@
 	iph->tot_len = htons(skb->len);
 	ip_send_check(iph);
 
-	NF_HOOK(NFPROTO_IPV4, NF_INET_PRE_ROUTING, skb, skb->dev, NULL,
+	NF_HOOK(NFPROTO_IPV4, NF_INET_PRE_ROUTING, NULL, skb,
+		skb->dev, NULL,
 		xfrm4_rcv_encap_finish);
 	return 0;
 }
diff --git a/net/ipv4/xfrm4_output.c b/net/ipv4/xfrm4_output.c
index dab7381..2878dbf 100644
--- a/net/ipv4/xfrm4_output.c
+++ b/net/ipv4/xfrm4_output.c
@@ -69,7 +69,7 @@
 }
 EXPORT_SYMBOL(xfrm4_prepare_output);
 
-int xfrm4_output_finish(struct sk_buff *skb)
+int xfrm4_output_finish(struct sock *sk, struct sk_buff *skb)
 {
 	memset(IPCB(skb), 0, sizeof(*IPCB(skb)));
 
@@ -77,26 +77,26 @@
 	IPCB(skb)->flags |= IPSKB_XFRM_TRANSFORMED;
 #endif
 
-	return xfrm_output(skb);
+	return xfrm_output(sk, skb);
 }
 
-static int __xfrm4_output(struct sk_buff *skb)
+static int __xfrm4_output(struct sock *sk, struct sk_buff *skb)
 {
 	struct xfrm_state *x = skb_dst(skb)->xfrm;
 
 #ifdef CONFIG_NETFILTER
 	if (!x) {
 		IPCB(skb)->flags |= IPSKB_REROUTED;
-		return dst_output(skb);
+		return dst_output_sk(sk, skb);
 	}
 #endif
 
-	return x->outer_mode->afinfo->output_finish(skb);
+	return x->outer_mode->afinfo->output_finish(sk, skb);
 }
 
 int xfrm4_output(struct sock *sk, struct sk_buff *skb)
 {
-	return NF_HOOK_COND(NFPROTO_IPV4, NF_INET_POST_ROUTING, skb,
+	return NF_HOOK_COND(NFPROTO_IPV4, NF_INET_POST_ROUTING, sk, skb,
 			    NULL, skb_dst(skb)->dev, __xfrm4_output,
 			    !(IPCB(skb)->flags & IPSKB_REROUTED));
 }