mptcp: Add key generation and token tree

Generate the local keys, IDSN, and token when creating a new socket.
Introduce the token tree, which tracks all tokens in use in a radix
tree indexed by the MPTCP token itself.

Override the rebuild_header callback in inet_connection_sock_af_ops for
creating the local key on a new outgoing connection.

Override the init_req callback of tcp_request_sock_ops for creating the
local key on a new incoming connection.

The token tree will be used to obtain the MPTCP parent socket when
handling incoming joins.

Co-developed-by: Davide Caratti <dcaratti@redhat.com>
Signed-off-by: Davide Caratti <dcaratti@redhat.com>
Co-developed-by: Florian Westphal <fw@strlen.de>
Signed-off-by: Florian Westphal <fw@strlen.de>
Co-developed-by: Paolo Abeni <pabeni@redhat.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
Signed-off-by: Peter Krystad <peter.krystad@linux.intel.com>
Signed-off-by: Christoph Paasch <cpaasch@apple.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
index e08a25e..3f66b6a 100644
--- a/net/mptcp/protocol.c
+++ b/net/mptcp/protocol.c
@@ -201,6 +201,7 @@ static void mptcp_close(struct sock *sk, long timeout)
 	struct mptcp_subflow_context *subflow, *tmp;
 	struct mptcp_sock *msk = mptcp_sk(sk);
 
+	mptcp_token_destroy(msk->token);
 	inet_sk_state_store(sk, TCP_CLOSE);
 
 	lock_sock(sk);
@@ -281,8 +282,10 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
 		msk = mptcp_sk(new_mptcp_sock);
 		msk->remote_key = subflow->remote_key;
 		msk->local_key = subflow->local_key;
+		msk->token = subflow->token;
 		msk->subflow = NULL;
 
+		mptcp_token_update_accept(newsk, new_mptcp_sock);
 		newsk = new_mptcp_sock;
 		mptcp_copy_inaddrs(newsk, ssk);
 		list_add(&subflow->node, &msk->conn_list);
@@ -299,6 +302,10 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
 	return newsk;
 }
 
+static void mptcp_destroy(struct sock *sk)
+{
+}
+
 static int mptcp_get_port(struct sock *sk, unsigned short snum)
 {
 	struct mptcp_sock *msk = mptcp_sk(sk);
@@ -331,6 +338,7 @@ void mptcp_finish_connect(struct sock *ssk)
 	 */
 	WRITE_ONCE(msk->remote_key, subflow->remote_key);
 	WRITE_ONCE(msk->local_key, subflow->local_key);
+	WRITE_ONCE(msk->token, subflow->token);
 }
 
 static void mptcp_sock_graft(struct sock *sk, struct socket *parent)
@@ -349,6 +357,7 @@ static struct proto mptcp_prot = {
 	.close		= mptcp_close,
 	.accept		= mptcp_accept,
 	.shutdown	= tcp_shutdown,
+	.destroy	= mptcp_destroy,
 	.sendmsg	= mptcp_sendmsg,
 	.recvmsg	= mptcp_recvmsg,
 	.hash		= inet_hash,
@@ -568,6 +577,12 @@ void __init mptcp_init(void)
 static struct proto_ops mptcp_v6_stream_ops;
 static struct proto mptcp_v6_prot;
 
+static void mptcp_v6_destroy(struct sock *sk)
+{
+	mptcp_destroy(sk);
+	inet6_destroy_sock(sk);
+}
+
 static struct inet_protosw mptcp_v6_protosw = {
 	.type		= SOCK_STREAM,
 	.protocol	= IPPROTO_MPTCP,
@@ -583,6 +598,7 @@ int mptcpv6_init(void)
 	mptcp_v6_prot = mptcp_prot;
 	strcpy(mptcp_v6_prot.name, "MPTCPv6");
 	mptcp_v6_prot.slab = NULL;
+	mptcp_v6_prot.destroy = mptcp_v6_destroy;
 	mptcp_v6_prot.obj_size = sizeof(struct mptcp_sock) +
 				 sizeof(struct ipv6_pinfo);