udp: correct reuseport selection with connected sockets

UDP reuseport groups can hold a mix of unconnected and connected sockets.
Ensure that connected sockets receive all traffic to their 4-tuple, and
only that traffic.

Fast reuseport returns on the first reuseport match on the assumption
that all matches are equal. Only if connections are present, return to
the previous behavior of scoring all sockets.

Record if connections are present and if so (1) treat such connected
sockets as an independent match from the group, (2) only return
2-tuple matches from reuseport and (3) do not return on the first
2-tuple reuseport match to allow for a higher scoring match later.

New field has_conns is set without locks. No other fields in the
bitmap are modified at runtime and the field is only ever set
unconditionally, so an RMW cannot miss a change.

Fixes: e32ea7e74727 ("soreuseport: fast reuseport UDP socket selection")
Link: http://lkml.kernel.org/r/CA+FuTSfRP09aJNYRt04SS6qj22ViiOEWaWmLAwX0psk8-PGNxw@mail.gmail.com
Signed-off-by: Willem de Bruijn <willemb@google.com>
Acked-by: Paolo Abeni <pabeni@redhat.com>
Acked-by: Craig Gallek <kraig@google.com>
Signed-off-by: Willem de Bruijn <willemb@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/core/sock_reuseport.c b/net/core/sock_reuseport.c
index 9408f92..f3ceec9 100644
--- a/net/core/sock_reuseport.c
+++ b/net/core/sock_reuseport.c
@@ -295,8 +295,19 @@ struct sock *reuseport_select_sock(struct sock *sk,
 
 select_by_hash:
 		/* no bpf or invalid bpf result: fall back to hash usage */
-		if (!sk2)
-			sk2 = reuse->socks[reciprocal_scale(hash, socks)];
+		if (!sk2) {
+			int i, j;
+
+			i = j = reciprocal_scale(hash, socks);
+			while (reuse->socks[i]->sk_state == TCP_ESTABLISHED) {
+				i++;
+				if (i >= reuse->num_socks)
+					i = 0;
+				if (i == j)
+					goto out;
+			}
+			sk2 = reuse->socks[i];
+		}
 	}
 
 out:
diff --git a/net/ipv4/datagram.c b/net/ipv4/datagram.c
index 7bd29e6..9a0fe0c 100644
--- a/net/ipv4/datagram.c
+++ b/net/ipv4/datagram.c
@@ -15,6 +15,7 @@
 #include <net/sock.h>
 #include <net/route.h>
 #include <net/tcp_states.h>
+#include <net/sock_reuseport.h>
 
 int __ip4_datagram_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len)
 {
@@ -69,6 +70,7 @@ int __ip4_datagram_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len
 	}
 	inet->inet_daddr = fl4->daddr;
 	inet->inet_dport = usin->sin_port;
+	reuseport_has_conns(sk, true);
 	sk->sk_state = TCP_ESTABLISHED;
 	sk_set_txhash(sk);
 	inet->inet_id = jiffies;
diff --git a/net/ipv4/udp.c b/net/ipv4/udp.c
index d88821c..16486c8 100644
--- a/net/ipv4/udp.c
+++ b/net/ipv4/udp.c
@@ -423,12 +423,13 @@ static struct sock *udp4_lib_lookup2(struct net *net,
 		score = compute_score(sk, net, saddr, sport,
 				      daddr, hnum, dif, sdif);
 		if (score > badness) {
-			if (sk->sk_reuseport) {
+			if (sk->sk_reuseport &&
+			    sk->sk_state != TCP_ESTABLISHED) {
 				hash = udp_ehashfn(net, daddr, hnum,
 						   saddr, sport);
 				result = reuseport_select_sock(sk, hash, skb,
 							sizeof(struct udphdr));
-				if (result)
+				if (result && !reuseport_has_conns(sk, false))
 					return result;
 			}
 			badness = score;
diff --git a/net/ipv6/datagram.c b/net/ipv6/datagram.c
index 9ab897d..96f9392 100644
--- a/net/ipv6/datagram.c
+++ b/net/ipv6/datagram.c
@@ -27,6 +27,7 @@
 #include <net/ip6_route.h>
 #include <net/tcp_states.h>
 #include <net/dsfield.h>
+#include <net/sock_reuseport.h>
 
 #include <linux/errqueue.h>
 #include <linux/uaccess.h>
@@ -254,6 +255,7 @@ int __ip6_datagram_connect(struct sock *sk, struct sockaddr *uaddr,
 		goto out;
 	}
 
+	reuseport_has_conns(sk, true);
 	sk->sk_state = TCP_ESTABLISHED;
 	sk_set_txhash(sk);
 out:
diff --git a/net/ipv6/udp.c b/net/ipv6/udp.c
index 827fe73..5995fdc 100644
--- a/net/ipv6/udp.c
+++ b/net/ipv6/udp.c
@@ -158,13 +158,14 @@ static struct sock *udp6_lib_lookup2(struct net *net,
 		score = compute_score(sk, net, saddr, sport,
 				      daddr, hnum, dif, sdif);
 		if (score > badness) {
-			if (sk->sk_reuseport) {
+			if (sk->sk_reuseport &&
+			    sk->sk_state != TCP_ESTABLISHED) {
 				hash = udp6_ehashfn(net, daddr, hnum,
 						    saddr, sport);
 
 				result = reuseport_select_sock(sk, hash, skb,
 							sizeof(struct udphdr));
-				if (result)
+				if (result && !reuseport_has_conns(sk, false))
 					return result;
 			}
 			result = sk;