net: ipv4: add second dif to inet socket lookups

Add a second device index, sdif, to inet socket lookups. sdif is the
index for ingress devices enslaved to an l3mdev. It allows the lookups
to consider the enslaved device as well as the L3 domain when searching
for a socket.

TCP moves the data in the cb. Prior to tcp_v4_rcv (e.g., early demux) the
ingress index is obtained from IPCB using inet_sdif; after the cb move
in tcp_v4_rcv, the tcp_v4_sdif helper is used.

Signed-off-by: David Ahern <dsahern@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/include/net/inet_hashtables.h b/include/net/inet_hashtables.h
index 5026b1f..2dbbbff 100644
--- a/include/net/inet_hashtables.h
+++ b/include/net/inet_hashtables.h
@@ -221,16 +221,16 @@ struct sock *__inet_lookup_listener(struct net *net,
 				    const __be32 saddr, const __be16 sport,
 				    const __be32 daddr,
 				    const unsigned short hnum,
-				    const int dif);
+				    const int dif, const int sdif);
 
 static inline struct sock *inet_lookup_listener(struct net *net,
 		struct inet_hashinfo *hashinfo,
 		struct sk_buff *skb, int doff,
 		__be32 saddr, __be16 sport,
-		__be32 daddr, __be16 dport, int dif)
+		__be32 daddr, __be16 dport, int dif, int sdif)
 {
 	return __inet_lookup_listener(net, hashinfo, skb, doff, saddr, sport,
-				      daddr, ntohs(dport), dif);
+				      daddr, ntohs(dport), dif, sdif);
 }
 
 /* Socket demux engine toys. */
@@ -262,22 +262,24 @@ static inline struct sock *inet_lookup_listener(struct net *net,
 				   (((__force __u64)(__be32)(__daddr)) << 32) | \
 				   ((__force __u64)(__be32)(__saddr)))
 #endif /* __BIG_ENDIAN */
-#define INET_MATCH(__sk, __net, __cookie, __saddr, __daddr, __ports, __dif)	\
+#define INET_MATCH(__sk, __net, __cookie, __saddr, __daddr, __ports, __dif, __sdif) \
 	(((__sk)->sk_portpair == (__ports))			&&	\
 	 ((__sk)->sk_addrpair == (__cookie))			&&	\
 	 (!(__sk)->sk_bound_dev_if	||				\
-	   ((__sk)->sk_bound_dev_if == (__dif))) 		&& 	\
+	   ((__sk)->sk_bound_dev_if == (__dif))			||	\
+	   ((__sk)->sk_bound_dev_if == (__sdif)))		&&	\
 	 net_eq(sock_net(__sk), (__net)))
 #else /* 32-bit arch */
 #define INET_ADDR_COOKIE(__name, __saddr, __daddr) \
 	const int __name __deprecated __attribute__((unused))
 
-#define INET_MATCH(__sk, __net, __cookie, __saddr, __daddr, __ports, __dif) \
+#define INET_MATCH(__sk, __net, __cookie, __saddr, __daddr, __ports, __dif, __sdif) \
 	(((__sk)->sk_portpair == (__ports))		&&		\
 	 ((__sk)->sk_daddr	== (__saddr))		&&		\
 	 ((__sk)->sk_rcv_saddr	== (__daddr))		&&		\
 	 (!(__sk)->sk_bound_dev_if	||				\
-	   ((__sk)->sk_bound_dev_if == (__dif))) 	&&		\
+	   ((__sk)->sk_bound_dev_if == (__dif))		||		\
+	   ((__sk)->sk_bound_dev_if == (__sdif)))	&&		\
 	 net_eq(sock_net(__sk), (__net)))
 #endif /* 64-bit arch */
 
@@ -288,7 +290,7 @@ struct sock *__inet_lookup_established(struct net *net,
 				       struct inet_hashinfo *hashinfo,
 				       const __be32 saddr, const __be16 sport,
 				       const __be32 daddr, const u16 hnum,
-				       const int dif);
+				       const int dif, const int sdif);
 
 static inline struct sock *
 	inet_lookup_established(struct net *net, struct inet_hashinfo *hashinfo,
@@ -297,7 +299,7 @@ static inline struct sock *
 				const int dif)
 {
 	return __inet_lookup_established(net, hashinfo, saddr, sport, daddr,
-					 ntohs(dport), dif);
+					 ntohs(dport), dif, 0);
 }
 
 static inline struct sock *__inet_lookup(struct net *net,
@@ -305,20 +307,20 @@ static inline struct sock *__inet_lookup(struct net *net,
 					 struct sk_buff *skb, int doff,
 					 const __be32 saddr, const __be16 sport,
 					 const __be32 daddr, const __be16 dport,
-					 const int dif,
+					 const int dif, const int sdif,
 					 bool *refcounted)
 {
 	u16 hnum = ntohs(dport);
 	struct sock *sk;
 
 	sk = __inet_lookup_established(net, hashinfo, saddr, sport,
-				       daddr, hnum, dif);
+				       daddr, hnum, dif, sdif);
 	*refcounted = true;
 	if (sk)
 		return sk;
 	*refcounted = false;
 	return __inet_lookup_listener(net, hashinfo, skb, doff, saddr,
-				      sport, daddr, hnum, dif);
+				      sport, daddr, hnum, dif, sdif);
 }
 
 static inline struct sock *inet_lookup(struct net *net,
@@ -332,7 +334,7 @@ static inline struct sock *inet_lookup(struct net *net,
 	bool refcounted;
 
 	sk = __inet_lookup(net, hashinfo, skb, doff, saddr, sport, daddr,
-			   dport, dif, &refcounted);
+			   dport, dif, 0, &refcounted);
 
 	if (sk && !refcounted && !refcount_inc_not_zero(&sk->sk_refcnt))
 		sk = NULL;
@@ -344,6 +346,7 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
 					     int doff,
 					     const __be16 sport,
 					     const __be16 dport,
+					     const int sdif,
 					     bool *refcounted)
 {
 	struct sock *sk = skb_steal_sock(skb);
@@ -355,7 +358,7 @@ static inline struct sock *__inet_lookup_skb(struct inet_hashinfo *hashinfo,
 
 	return __inet_lookup(dev_net(skb_dst(skb)->dev), hashinfo, skb,
 			     doff, iph->saddr, sport,
-			     iph->daddr, dport, inet_iif(skb),
+			     iph->daddr, dport, inet_iif(skb), sdif,
 			     refcounted);
 }
 
diff --git a/include/net/tcp.h b/include/net/tcp.h
index 5173fec..2b89f1a 100644
--- a/include/net/tcp.h
+++ b/include/net/tcp.h
@@ -840,6 +840,16 @@ static inline bool inet_exact_dif_match(struct net *net, struct sk_buff *skb)
 	return false;
 }
 
+/* TCP_SKB_CB reference means this cannot be used from early demux */
+static inline int tcp_v4_sdif(struct sk_buff *skb)
+{
+#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV)
+	if (skb && ipv4_l3mdev_skb(TCP_SKB_CB(skb)->header.h4.flags))
+		return TCP_SKB_CB(skb)->header.h4.iif;
+#endif
+	return 0;
+}
+
 /* Due to TSO, an SKB can be composed of multiple actual
  * packets.  To keep these tracked properly, we use this.
  */