| // SPDX-License-Identifier: GPL-2.0-only |
| /* |
| * Handshake request lifetime events |
| * |
| * Author: Chuck Lever <chuck.lever@oracle.com> |
| * |
| * Copyright (c) 2023, Oracle and/or its affiliates. |
| */ |
| |
| #include <linux/types.h> |
| #include <linux/socket.h> |
| #include <linux/kernel.h> |
| #include <linux/module.h> |
| #include <linux/skbuff.h> |
| #include <linux/inet.h> |
| #include <linux/rhashtable.h> |
| |
| #include <net/sock.h> |
| #include <net/genetlink.h> |
| #include <net/netns/generic.h> |
| |
| #include <kunit/visibility.h> |
| |
| #include <uapi/linux/handshake.h> |
| #include "handshake.h" |
| |
| #include <trace/events/handshake.h> |
| |
| /* |
| * We need both a handshake_req -> sock mapping, and a sock -> |
| * handshake_req mapping. Both are one-to-one. |
| * |
| * To avoid adding another pointer field to struct sock, net/handshake |
| * maintains a hash table, indexed by the memory address of @sock, to |
| * find the struct handshake_req outstanding for that socket. The |
| * reverse direction uses a simple pointer field in the handshake_req |
| * struct. |
| */ |
| |
| static struct rhashtable handshake_rhashtbl ____cacheline_aligned_in_smp; |
| |
| static const struct rhashtable_params handshake_rhash_params = { |
| .key_len = sizeof_field(struct handshake_req, hr_sk), |
| .key_offset = offsetof(struct handshake_req, hr_sk), |
| .head_offset = offsetof(struct handshake_req, hr_rhash), |
| .automatic_shrinking = true, |
| }; |
| |
| int handshake_req_hash_init(void) |
| { |
| return rhashtable_init(&handshake_rhashtbl, &handshake_rhash_params); |
| } |
| |
| void handshake_req_hash_destroy(void) |
| { |
| rhashtable_destroy(&handshake_rhashtbl); |
| } |
| |
| struct handshake_req *handshake_req_hash_lookup(struct sock *sk) |
| { |
| return rhashtable_lookup_fast(&handshake_rhashtbl, &sk, |
| handshake_rhash_params); |
| } |
| EXPORT_SYMBOL_IF_KUNIT(handshake_req_hash_lookup); |
| |
| static bool handshake_req_hash_add(struct handshake_req *req) |
| { |
| int ret; |
| |
| ret = rhashtable_lookup_insert_fast(&handshake_rhashtbl, |
| &req->hr_rhash, |
| handshake_rhash_params); |
| return ret == 0; |
| } |
| |
| static void handshake_req_destroy(struct handshake_req *req) |
| { |
| if (req->hr_proto->hp_destroy) |
| req->hr_proto->hp_destroy(req); |
| rhashtable_remove_fast(&handshake_rhashtbl, &req->hr_rhash, |
| handshake_rhash_params); |
| kfree(req); |
| } |
| |
| static void handshake_sk_destruct(struct sock *sk) |
| { |
| void (*sk_destruct)(struct sock *sk); |
| struct handshake_req *req; |
| |
| req = handshake_req_hash_lookup(sk); |
| if (!req) |
| return; |
| |
| trace_handshake_destruct(sock_net(sk), req, sk); |
| sk_destruct = req->hr_odestruct; |
| handshake_req_destroy(req); |
| if (sk_destruct) |
| sk_destruct(sk); |
| } |
| |
| /** |
| * handshake_req_alloc - Allocate a handshake request |
| * @proto: security protocol |
| * @flags: memory allocation flags |
| * |
| * Returns an initialized handshake_req or NULL. |
| */ |
| struct handshake_req *handshake_req_alloc(const struct handshake_proto *proto, |
| gfp_t flags) |
| { |
| struct handshake_req *req; |
| |
| if (!proto) |
| return NULL; |
| if (proto->hp_handler_class <= HANDSHAKE_HANDLER_CLASS_NONE) |
| return NULL; |
| if (proto->hp_handler_class >= HANDSHAKE_HANDLER_CLASS_MAX) |
| return NULL; |
| if (!proto->hp_accept || !proto->hp_done) |
| return NULL; |
| |
| req = kzalloc(struct_size(req, hr_priv, proto->hp_privsize), flags); |
| if (!req) |
| return NULL; |
| |
| INIT_LIST_HEAD(&req->hr_list); |
| req->hr_proto = proto; |
| return req; |
| } |
| EXPORT_SYMBOL(handshake_req_alloc); |
| |
| /** |
| * handshake_req_private - Get per-handshake private data |
| * @req: handshake arguments |
| * |
| */ |
| void *handshake_req_private(struct handshake_req *req) |
| { |
| return (void *)&req->hr_priv; |
| } |
| EXPORT_SYMBOL(handshake_req_private); |
| |
| static bool __add_pending_locked(struct handshake_net *hn, |
| struct handshake_req *req) |
| { |
| if (WARN_ON_ONCE(!list_empty(&req->hr_list))) |
| return false; |
| hn->hn_pending++; |
| list_add_tail(&req->hr_list, &hn->hn_requests); |
| return true; |
| } |
| |
| static void __remove_pending_locked(struct handshake_net *hn, |
| struct handshake_req *req) |
| { |
| hn->hn_pending--; |
| list_del_init(&req->hr_list); |
| } |
| |
| /* |
| * Returns %true if the request was found on @net's pending list, |
| * otherwise %false. |
| * |
| * If @req was on a pending list, it has not yet been accepted. |
| */ |
| static bool remove_pending(struct handshake_net *hn, struct handshake_req *req) |
| { |
| bool ret = false; |
| |
| spin_lock(&hn->hn_lock); |
| if (!list_empty(&req->hr_list)) { |
| __remove_pending_locked(hn, req); |
| ret = true; |
| } |
| spin_unlock(&hn->hn_lock); |
| |
| return ret; |
| } |
| |
| struct handshake_req *handshake_req_next(struct handshake_net *hn, int class) |
| { |
| struct handshake_req *req, *pos; |
| |
| req = NULL; |
| spin_lock(&hn->hn_lock); |
| list_for_each_entry(pos, &hn->hn_requests, hr_list) { |
| if (pos->hr_proto->hp_handler_class != class) |
| continue; |
| __remove_pending_locked(hn, pos); |
| req = pos; |
| break; |
| } |
| spin_unlock(&hn->hn_lock); |
| |
| return req; |
| } |
| EXPORT_SYMBOL_IF_KUNIT(handshake_req_next); |
| |
| /** |
| * handshake_req_submit - Submit a handshake request |
| * @sock: open socket on which to perform the handshake |
| * @req: handshake arguments |
| * @flags: memory allocation flags |
| * |
| * Return values: |
| * %0: Request queued |
| * %-EINVAL: Invalid argument |
| * %-EBUSY: A handshake is already under way for this socket |
| * %-ESRCH: No handshake agent is available |
| * %-EAGAIN: Too many pending handshake requests |
| * %-ENOMEM: Failed to allocate memory |
| * %-EMSGSIZE: Failed to construct notification message |
| * %-EOPNOTSUPP: Handshake module not initialized |
| * |
| * A zero return value from handshake_req_submit() means that |
| * exactly one subsequent completion callback is guaranteed. |
| * |
| * A negative return value from handshake_req_submit() means that |
| * no completion callback will be done and that @req has been |
| * destroyed. |
| */ |
| int handshake_req_submit(struct socket *sock, struct handshake_req *req, |
| gfp_t flags) |
| { |
| struct handshake_net *hn; |
| struct net *net; |
| int ret; |
| |
| if (!sock || !req || !sock->file) { |
| kfree(req); |
| return -EINVAL; |
| } |
| |
| req->hr_sk = sock->sk; |
| if (!req->hr_sk) { |
| kfree(req); |
| return -EINVAL; |
| } |
| req->hr_odestruct = req->hr_sk->sk_destruct; |
| req->hr_sk->sk_destruct = handshake_sk_destruct; |
| |
| ret = -EOPNOTSUPP; |
| net = sock_net(req->hr_sk); |
| hn = handshake_pernet(net); |
| if (!hn) |
| goto out_err; |
| |
| ret = -EAGAIN; |
| if (READ_ONCE(hn->hn_pending) >= hn->hn_pending_max) |
| goto out_err; |
| |
| spin_lock(&hn->hn_lock); |
| ret = -EOPNOTSUPP; |
| if (test_bit(HANDSHAKE_F_NET_DRAINING, &hn->hn_flags)) |
| goto out_unlock; |
| ret = -EBUSY; |
| if (!handshake_req_hash_add(req)) |
| goto out_unlock; |
| if (!__add_pending_locked(hn, req)) |
| goto out_unlock; |
| spin_unlock(&hn->hn_lock); |
| |
| ret = handshake_genl_notify(net, req->hr_proto, flags); |
| if (ret) { |
| trace_handshake_notify_err(net, req, req->hr_sk, ret); |
| if (remove_pending(hn, req)) |
| goto out_err; |
| } |
| |
| /* Prevent socket release while a handshake request is pending */ |
| sock_hold(req->hr_sk); |
| |
| trace_handshake_submit(net, req, req->hr_sk); |
| return 0; |
| |
| out_unlock: |
| spin_unlock(&hn->hn_lock); |
| out_err: |
| trace_handshake_submit_err(net, req, req->hr_sk, ret); |
| handshake_req_destroy(req); |
| return ret; |
| } |
| EXPORT_SYMBOL(handshake_req_submit); |
| |
| void handshake_complete(struct handshake_req *req, unsigned int status, |
| struct genl_info *info) |
| { |
| struct sock *sk = req->hr_sk; |
| struct net *net = sock_net(sk); |
| |
| if (!test_and_set_bit(HANDSHAKE_F_REQ_COMPLETED, &req->hr_flags)) { |
| trace_handshake_complete(net, req, sk, status); |
| req->hr_proto->hp_done(req, status, info); |
| |
| /* Handshake request is no longer pending */ |
| sock_put(sk); |
| } |
| } |
| EXPORT_SYMBOL_IF_KUNIT(handshake_complete); |
| |
| /** |
| * handshake_req_cancel - Cancel an in-progress handshake |
| * @sk: socket on which there is an ongoing handshake |
| * |
| * Request cancellation races with request completion. To determine |
| * who won, callers examine the return value from this function. |
| * |
| * Return values: |
| * %true - Uncompleted handshake request was canceled |
| * %false - Handshake request already completed or not found |
| */ |
| bool handshake_req_cancel(struct sock *sk) |
| { |
| struct handshake_req *req; |
| struct handshake_net *hn; |
| struct net *net; |
| |
| net = sock_net(sk); |
| req = handshake_req_hash_lookup(sk); |
| if (!req) { |
| trace_handshake_cancel_none(net, req, sk); |
| return false; |
| } |
| |
| hn = handshake_pernet(net); |
| if (hn && remove_pending(hn, req)) { |
| /* Request hadn't been accepted */ |
| goto out_true; |
| } |
| if (test_and_set_bit(HANDSHAKE_F_REQ_COMPLETED, &req->hr_flags)) { |
| /* Request already completed */ |
| trace_handshake_cancel_busy(net, req, sk); |
| return false; |
| } |
| |
| out_true: |
| trace_handshake_cancel(net, req, sk); |
| |
| /* Handshake request is no longer pending */ |
| sock_put(sk); |
| return true; |
| } |
| EXPORT_SYMBOL(handshake_req_cancel); |