| // SPDX-License-Identifier: GPL-2.0-only | 
 | /* | 
 |  * Copyright (c) 2021, 2022 Oracle.  All rights reserved. | 
 |  * | 
 |  * The AUTH_TLS credential is used only to probe a remote peer | 
 |  * for RPC-over-TLS support. | 
 |  */ | 
 |  | 
 | #include <linux/types.h> | 
 | #include <linux/module.h> | 
 | #include <linux/sunrpc/clnt.h> | 
 |  | 
 | static const char *starttls_token = "STARTTLS"; | 
 | static const size_t starttls_len = 8; | 
 |  | 
/*
 * Forward declarations. tls_auth and tls_cred are defined at the end of
 * this file, but tls_probe(), tls_create() and tls_lookup_cred() below
 * take their addresses before that point.
 */
static struct rpc_auth tls_auth;
static struct rpc_cred tls_cred;
 |  | 
/*
 * p_encode callback for the probe procedure. The probe carries no
 * procedure arguments, so there is intentionally nothing to encode.
 */
static void tls_encode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
			     const void *obj)
{
}
 |  | 
/*
 * p_decode callback for the probe procedure. No result body is
 * decoded; this always reports success (0).
 */
static int tls_decode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
			    void *obj)
{
	return 0;
}
 |  | 
 | static const struct rpc_procinfo rpcproc_tls_probe = { | 
 | 	.p_encode	= tls_encode_probe, | 
 | 	.p_decode	= tls_decode_probe, | 
 | }; | 
 |  | 
/*
 * rpc_call_prepare callback for the probe task.
 *
 * Clears RPC_TASK_NO_RETRANS_TIMEOUT on the task before starting the
 * call. NOTE(review): presumably so the probe is retransmitted on
 * timeout even if the client was configured to suppress that — confirm.
 */
static void rpc_tls_probe_call_prepare(struct rpc_task *task, void *data)
{
	task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT;
	rpc_call_start(task);
}
 |  | 
/*
 * rpc_call_done callback for the probe task. Nothing to do here:
 * tls_probe() reads the task's tk_status itself after rpc_run_task()
 * returns.
 */
static void rpc_tls_probe_call_done(struct rpc_task *task, void *data)
{
}
 |  | 
 | static const struct rpc_call_ops rpc_tls_probe_ops = { | 
 | 	.rpc_call_prepare	= rpc_tls_probe_call_prepare, | 
 | 	.rpc_call_done		= rpc_tls_probe_call_done, | 
 | }; | 
 |  | 
 | static int tls_probe(struct rpc_clnt *clnt) | 
 | { | 
 | 	struct rpc_message msg = { | 
 | 		.rpc_proc	= &rpcproc_tls_probe, | 
 | 	}; | 
 | 	struct rpc_task_setup task_setup_data = { | 
 | 		.rpc_client	= clnt, | 
 | 		.rpc_message	= &msg, | 
 | 		.rpc_op_cred	= &tls_cred, | 
 | 		.callback_ops	= &rpc_tls_probe_ops, | 
 | 		.flags		= RPC_TASK_SOFT | RPC_TASK_SOFTCONN, | 
 | 	}; | 
 | 	struct rpc_task	*task; | 
 | 	int status; | 
 |  | 
 | 	task = rpc_run_task(&task_setup_data); | 
 | 	if (IS_ERR(task)) | 
 | 		return PTR_ERR(task); | 
 | 	status = task->tk_status; | 
 | 	rpc_put_task(task); | 
 | 	return status; | 
 | } | 
 |  | 
 | static struct rpc_auth *tls_create(const struct rpc_auth_create_args *args, | 
 | 				   struct rpc_clnt *clnt) | 
 | { | 
 | 	refcount_inc(&tls_auth.au_count); | 
 | 	return &tls_auth; | 
 | } | 
 |  | 
/*
 * ->destroy for AUTH_TLS. Intentionally empty: tls_auth is a static
 * object, so there is nothing to free.
 */
static void tls_destroy(struct rpc_auth *auth)
{
}
 |  | 
/*
 * ->lookup_cred for AUTH_TLS. Every lookup, regardless of @acred or
 * @flags, yields the single static tls_cred, with a reference taken
 * via get_rpccred().
 */
static struct rpc_cred *tls_lookup_cred(struct rpc_auth *auth,
					struct auth_cred *acred, int flags)
{
	return get_rpccred(&tls_cred);
}
 |  | 
/*
 * ->crdestroy for the AUTH_TLS credential. Intentionally empty:
 * tls_cred is statically allocated and must never be freed.
 */
static void tls_destroy_cred(struct rpc_cred *cred)
{
}
 |  | 
/*
 * ->crmatch for the AUTH_TLS credential. The single shared tls_cred
 * matches every request, so this always returns 1 (match).
 */
static int tls_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags)
{
	return 1;
}
 |  | 
 | static int tls_marshal(struct rpc_task *task, struct xdr_stream *xdr) | 
 | { | 
 | 	__be32 *p; | 
 |  | 
 | 	p = xdr_reserve_space(xdr, 4 * XDR_UNIT); | 
 | 	if (!p) | 
 | 		return -EMSGSIZE; | 
 | 	/* Credential */ | 
 | 	*p++ = rpc_auth_tls; | 
 | 	*p++ = xdr_zero; | 
 | 	/* Verifier */ | 
 | 	*p++ = rpc_auth_null; | 
 | 	*p   = xdr_zero; | 
 | 	return 0; | 
 | } | 
 |  | 
 | static int tls_refresh(struct rpc_task *task) | 
 | { | 
 | 	set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags); | 
 | 	return 0; | 
 | } | 
 |  | 
 | static int tls_validate(struct rpc_task *task, struct xdr_stream *xdr) | 
 | { | 
 | 	__be32 *p; | 
 | 	void *str; | 
 |  | 
 | 	p = xdr_inline_decode(xdr, XDR_UNIT); | 
 | 	if (!p) | 
 | 		return -EIO; | 
 | 	if (*p != rpc_auth_null) | 
 | 		return -EIO; | 
 | 	if (xdr_stream_decode_opaque_inline(xdr, &str, starttls_len) != starttls_len) | 
 | 		return -EPROTONOSUPPORT; | 
 | 	if (memcmp(str, starttls_token, starttls_len)) | 
 | 		return -EPROTONOSUPPORT; | 
 | 	return 0; | 
 | } | 
 |  | 
 | const struct rpc_authops authtls_ops = { | 
 | 	.owner		= THIS_MODULE, | 
 | 	.au_flavor	= RPC_AUTH_TLS, | 
 | 	.au_name	= "NULL", | 
 | 	.create		= tls_create, | 
 | 	.destroy	= tls_destroy, | 
 | 	.lookup_cred	= tls_lookup_cred, | 
 | 	.ping		= tls_probe, | 
 | }; | 
 |  | 
 | static struct rpc_auth tls_auth = { | 
 | 	.au_cslack	= NUL_CALLSLACK, | 
 | 	.au_rslack	= NUL_REPLYSLACK, | 
 | 	.au_verfsize	= NUL_REPLYSLACK, | 
 | 	.au_ralign	= NUL_REPLYSLACK, | 
 | 	.au_ops		= &authtls_ops, | 
 | 	.au_flavor	= RPC_AUTH_TLS, | 
 | 	.au_count	= REFCOUNT_INIT(1), | 
 | }; | 
 |  | 
 | static const struct rpc_credops tls_credops = { | 
 | 	.cr_name	= "AUTH_TLS", | 
 | 	.crdestroy	= tls_destroy_cred, | 
 | 	.crmatch	= tls_match, | 
 | 	.crmarshal	= tls_marshal, | 
 | 	.crwrap_req	= rpcauth_wrap_req_encode, | 
 | 	.crrefresh	= tls_refresh, | 
 | 	.crvalidate	= tls_validate, | 
 | 	.crunwrap_resp	= rpcauth_unwrap_resp_decode, | 
 | }; | 
 |  | 
 | static struct rpc_cred tls_cred = { | 
 | 	.cr_lru		= LIST_HEAD_INIT(tls_cred.cr_lru), | 
 | 	.cr_auth	= &tls_auth, | 
 | 	.cr_ops		= &tls_credops, | 
 | 	.cr_count	= REFCOUNT_INIT(2), | 
 | 	.cr_flags	= 1UL << RPCAUTH_CRED_UPTODATE, | 
 | }; |