mptcp: Add key generation and token tree
Generate the local keys, IDSN, and token when creating a new socket.
Introduce the token tree, a radix tree indexed by the MPTCP token itself,
to track all tokens in use.
Override the rebuild_header callback in inet_connection_sock_af_ops for
creating the local key on a new outgoing connection.
Override the init_req callback of tcp_request_sock_ops for creating the
local key on a new incoming connection.
The token tree will be used to obtain the MPTCP parent socket to handle
incoming joins.
Co-developed-by: Davide Caratti <dcaratti@redhat.com>
Signed-off-by: Davide Caratti <dcaratti@redhat.com>
Co-developed-by: Florian Westphal <fw@strlen.de>
Signed-off-by: Florian Westphal <fw@strlen.de>
Co-developed-by: Paolo Abeni <pabeni@redhat.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
Signed-off-by: Peter Krystad <peter.krystad@linux.intel.com>
Signed-off-by: Christoph Paasch <cpaasch@apple.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
diff --git a/net/mptcp/protocol.c b/net/mptcp/protocol.c
index e08a25e..3f66b6a 100644
--- a/net/mptcp/protocol.c
+++ b/net/mptcp/protocol.c
@@ -201,6 +201,7 @@ static void mptcp_close(struct sock *sk, long timeout)
struct mptcp_subflow_context *subflow, *tmp;
struct mptcp_sock *msk = mptcp_sk(sk);
+ mptcp_token_destroy(msk->token);
inet_sk_state_store(sk, TCP_CLOSE);
lock_sock(sk);
@@ -281,8 +282,10 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
msk = mptcp_sk(new_mptcp_sock);
msk->remote_key = subflow->remote_key;
msk->local_key = subflow->local_key;
+ msk->token = subflow->token;
msk->subflow = NULL;
+ mptcp_token_update_accept(newsk, new_mptcp_sock);
newsk = new_mptcp_sock;
mptcp_copy_inaddrs(newsk, ssk);
list_add(&subflow->node, &msk->conn_list);
@@ -299,6 +302,10 @@ static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
return newsk;
}
+static void mptcp_destroy(struct sock *sk)
+{
+}
+
static int mptcp_get_port(struct sock *sk, unsigned short snum)
{
struct mptcp_sock *msk = mptcp_sk(sk);
@@ -331,6 +338,7 @@ void mptcp_finish_connect(struct sock *ssk)
*/
WRITE_ONCE(msk->remote_key, subflow->remote_key);
WRITE_ONCE(msk->local_key, subflow->local_key);
+ WRITE_ONCE(msk->token, subflow->token);
}
static void mptcp_sock_graft(struct sock *sk, struct socket *parent)
@@ -349,6 +357,7 @@ static struct proto mptcp_prot = {
.close = mptcp_close,
.accept = mptcp_accept,
.shutdown = tcp_shutdown,
+ .destroy = mptcp_destroy,
.sendmsg = mptcp_sendmsg,
.recvmsg = mptcp_recvmsg,
.hash = inet_hash,
@@ -568,6 +577,12 @@ void __init mptcp_init(void)
static struct proto_ops mptcp_v6_stream_ops;
static struct proto mptcp_v6_prot;
+static void mptcp_v6_destroy(struct sock *sk)
+{
+ mptcp_destroy(sk);
+ inet6_destroy_sock(sk);
+}
+
static struct inet_protosw mptcp_v6_protosw = {
.type = SOCK_STREAM,
.protocol = IPPROTO_MPTCP,
@@ -583,6 +598,7 @@ int mptcpv6_init(void)
mptcp_v6_prot = mptcp_prot;
strcpy(mptcp_v6_prot.name, "MPTCPv6");
mptcp_v6_prot.slab = NULL;
+ mptcp_v6_prot.destroy = mptcp_v6_destroy;
mptcp_v6_prot.obj_size = sizeof(struct mptcp_sock) +
sizeof(struct ipv6_pinfo);