xref: /openbmc/linux/net/sunrpc/auth_tls.c (revision 5623ecfc)
112072652SChuck Lever // SPDX-License-Identifier: GPL-2.0-only
212072652SChuck Lever /*
312072652SChuck Lever  * Copyright (c) 2021, 2022 Oracle.  All rights reserved.
412072652SChuck Lever  *
512072652SChuck Lever  * The AUTH_TLS credential is used only to probe a remote peer
612072652SChuck Lever  * for RPC-over-TLS support.
712072652SChuck Lever  */
812072652SChuck Lever 
912072652SChuck Lever #include <linux/types.h>
1012072652SChuck Lever #include <linux/module.h>
1112072652SChuck Lever #include <linux/sunrpc/clnt.h>
1212072652SChuck Lever 
1312072652SChuck Lever static const char *starttls_token = "STARTTLS";
1412072652SChuck Lever static const size_t starttls_len = 8;
1512072652SChuck Lever 
1612072652SChuck Lever static struct rpc_auth tls_auth;
1712072652SChuck Lever static struct rpc_cred tls_cred;
1812072652SChuck Lever 
tls_encode_probe(struct rpc_rqst * rqstp,struct xdr_stream * xdr,const void * obj)1912072652SChuck Lever static void tls_encode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
2012072652SChuck Lever 			     const void *obj)
2112072652SChuck Lever {
2212072652SChuck Lever }
2312072652SChuck Lever 
tls_decode_probe(struct rpc_rqst * rqstp,struct xdr_stream * xdr,void * obj)2412072652SChuck Lever static int tls_decode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
2512072652SChuck Lever 			    void *obj)
2612072652SChuck Lever {
2712072652SChuck Lever 	return 0;
2812072652SChuck Lever }
2912072652SChuck Lever 
3012072652SChuck Lever static const struct rpc_procinfo rpcproc_tls_probe = {
3112072652SChuck Lever 	.p_encode	= tls_encode_probe,
3212072652SChuck Lever 	.p_decode	= tls_decode_probe,
3312072652SChuck Lever };
3412072652SChuck Lever 
rpc_tls_probe_call_prepare(struct rpc_task * task,void * data)3512072652SChuck Lever static void rpc_tls_probe_call_prepare(struct rpc_task *task, void *data)
3612072652SChuck Lever {
3712072652SChuck Lever 	task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT;
3812072652SChuck Lever 	rpc_call_start(task);
3912072652SChuck Lever }
4012072652SChuck Lever 
rpc_tls_probe_call_done(struct rpc_task * task,void * data)4112072652SChuck Lever static void rpc_tls_probe_call_done(struct rpc_task *task, void *data)
4212072652SChuck Lever {
4312072652SChuck Lever }
4412072652SChuck Lever 
4512072652SChuck Lever static const struct rpc_call_ops rpc_tls_probe_ops = {
4612072652SChuck Lever 	.rpc_call_prepare	= rpc_tls_probe_call_prepare,
4712072652SChuck Lever 	.rpc_call_done		= rpc_tls_probe_call_done,
4812072652SChuck Lever };
4912072652SChuck Lever 
tls_probe(struct rpc_clnt * clnt)5012072652SChuck Lever static int tls_probe(struct rpc_clnt *clnt)
5112072652SChuck Lever {
5212072652SChuck Lever 	struct rpc_message msg = {
5312072652SChuck Lever 		.rpc_proc	= &rpcproc_tls_probe,
5412072652SChuck Lever 	};
5512072652SChuck Lever 	struct rpc_task_setup task_setup_data = {
5612072652SChuck Lever 		.rpc_client	= clnt,
5712072652SChuck Lever 		.rpc_message	= &msg,
5812072652SChuck Lever 		.rpc_op_cred	= &tls_cred,
5912072652SChuck Lever 		.callback_ops	= &rpc_tls_probe_ops,
6012072652SChuck Lever 		.flags		= RPC_TASK_SOFT | RPC_TASK_SOFTCONN,
6112072652SChuck Lever 	};
6212072652SChuck Lever 	struct rpc_task	*task;
6312072652SChuck Lever 	int status;
6412072652SChuck Lever 
6512072652SChuck Lever 	task = rpc_run_task(&task_setup_data);
6612072652SChuck Lever 	if (IS_ERR(task))
6712072652SChuck Lever 		return PTR_ERR(task);
6812072652SChuck Lever 	status = task->tk_status;
6912072652SChuck Lever 	rpc_put_task(task);
7012072652SChuck Lever 	return status;
7112072652SChuck Lever }
7212072652SChuck Lever 
tls_create(const struct rpc_auth_create_args * args,struct rpc_clnt * clnt)7312072652SChuck Lever static struct rpc_auth *tls_create(const struct rpc_auth_create_args *args,
7412072652SChuck Lever 				   struct rpc_clnt *clnt)
7512072652SChuck Lever {
7612072652SChuck Lever 	refcount_inc(&tls_auth.au_count);
7712072652SChuck Lever 	return &tls_auth;
7812072652SChuck Lever }
7912072652SChuck Lever 
tls_destroy(struct rpc_auth * auth)8012072652SChuck Lever static void tls_destroy(struct rpc_auth *auth)
8112072652SChuck Lever {
8212072652SChuck Lever }
8312072652SChuck Lever 
tls_lookup_cred(struct rpc_auth * auth,struct auth_cred * acred,int flags)8412072652SChuck Lever static struct rpc_cred *tls_lookup_cred(struct rpc_auth *auth,
8512072652SChuck Lever 					struct auth_cred *acred, int flags)
8612072652SChuck Lever {
8712072652SChuck Lever 	return get_rpccred(&tls_cred);
8812072652SChuck Lever }
8912072652SChuck Lever 
tls_destroy_cred(struct rpc_cred * cred)9012072652SChuck Lever static void tls_destroy_cred(struct rpc_cred *cred)
9112072652SChuck Lever {
9212072652SChuck Lever }
9312072652SChuck Lever 
tls_match(struct auth_cred * acred,struct rpc_cred * cred,int taskflags)9412072652SChuck Lever static int tls_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags)
9512072652SChuck Lever {
9612072652SChuck Lever 	return 1;
9712072652SChuck Lever }
9812072652SChuck Lever 
tls_marshal(struct rpc_task * task,struct xdr_stream * xdr)9912072652SChuck Lever static int tls_marshal(struct rpc_task *task, struct xdr_stream *xdr)
10012072652SChuck Lever {
10112072652SChuck Lever 	__be32 *p;
10212072652SChuck Lever 
10312072652SChuck Lever 	p = xdr_reserve_space(xdr, 4 * XDR_UNIT);
10412072652SChuck Lever 	if (!p)
10512072652SChuck Lever 		return -EMSGSIZE;
10612072652SChuck Lever 	/* Credential */
10712072652SChuck Lever 	*p++ = rpc_auth_tls;
10812072652SChuck Lever 	*p++ = xdr_zero;
10912072652SChuck Lever 	/* Verifier */
11012072652SChuck Lever 	*p++ = rpc_auth_null;
11112072652SChuck Lever 	*p   = xdr_zero;
11212072652SChuck Lever 	return 0;
11312072652SChuck Lever }
11412072652SChuck Lever 
tls_refresh(struct rpc_task * task)11512072652SChuck Lever static int tls_refresh(struct rpc_task *task)
11612072652SChuck Lever {
11712072652SChuck Lever 	set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags);
11812072652SChuck Lever 	return 0;
11912072652SChuck Lever }
12012072652SChuck Lever 
tls_validate(struct rpc_task * task,struct xdr_stream * xdr)12112072652SChuck Lever static int tls_validate(struct rpc_task *task, struct xdr_stream *xdr)
12212072652SChuck Lever {
12312072652SChuck Lever 	__be32 *p;
12412072652SChuck Lever 	void *str;
12512072652SChuck Lever 
12612072652SChuck Lever 	p = xdr_inline_decode(xdr, XDR_UNIT);
12712072652SChuck Lever 	if (!p)
12812072652SChuck Lever 		return -EIO;
12912072652SChuck Lever 	if (*p != rpc_auth_null)
13012072652SChuck Lever 		return -EIO;
13112072652SChuck Lever 	if (xdr_stream_decode_opaque_inline(xdr, &str, starttls_len) != starttls_len)
132*5623ecfcSChuck Lever 		return -EPROTONOSUPPORT;
13312072652SChuck Lever 	if (memcmp(str, starttls_token, starttls_len))
134*5623ecfcSChuck Lever 		return -EPROTONOSUPPORT;
13512072652SChuck Lever 	return 0;
13612072652SChuck Lever }
13712072652SChuck Lever 
13812072652SChuck Lever const struct rpc_authops authtls_ops = {
13912072652SChuck Lever 	.owner		= THIS_MODULE,
14012072652SChuck Lever 	.au_flavor	= RPC_AUTH_TLS,
14112072652SChuck Lever 	.au_name	= "NULL",
14212072652SChuck Lever 	.create		= tls_create,
14312072652SChuck Lever 	.destroy	= tls_destroy,
14412072652SChuck Lever 	.lookup_cred	= tls_lookup_cred,
14512072652SChuck Lever 	.ping		= tls_probe,
14612072652SChuck Lever };
14712072652SChuck Lever 
14812072652SChuck Lever static struct rpc_auth tls_auth = {
14912072652SChuck Lever 	.au_cslack	= NUL_CALLSLACK,
15012072652SChuck Lever 	.au_rslack	= NUL_REPLYSLACK,
15112072652SChuck Lever 	.au_verfsize	= NUL_REPLYSLACK,
15212072652SChuck Lever 	.au_ralign	= NUL_REPLYSLACK,
15312072652SChuck Lever 	.au_ops		= &authtls_ops,
15412072652SChuck Lever 	.au_flavor	= RPC_AUTH_TLS,
15512072652SChuck Lever 	.au_count	= REFCOUNT_INIT(1),
15612072652SChuck Lever };
15712072652SChuck Lever 
15812072652SChuck Lever static const struct rpc_credops tls_credops = {
15912072652SChuck Lever 	.cr_name	= "AUTH_TLS",
16012072652SChuck Lever 	.crdestroy	= tls_destroy_cred,
16112072652SChuck Lever 	.crmatch	= tls_match,
16212072652SChuck Lever 	.crmarshal	= tls_marshal,
16312072652SChuck Lever 	.crwrap_req	= rpcauth_wrap_req_encode,
16412072652SChuck Lever 	.crrefresh	= tls_refresh,
16512072652SChuck Lever 	.crvalidate	= tls_validate,
16612072652SChuck Lever 	.crunwrap_resp	= rpcauth_unwrap_resp_decode,
16712072652SChuck Lever };
16812072652SChuck Lever 
16912072652SChuck Lever static struct rpc_cred tls_cred = {
17012072652SChuck Lever 	.cr_lru		= LIST_HEAD_INIT(tls_cred.cr_lru),
17112072652SChuck Lever 	.cr_auth	= &tls_auth,
17212072652SChuck Lever 	.cr_ops		= &tls_credops,
17312072652SChuck Lever 	.cr_count	= REFCOUNT_INIT(2),
17412072652SChuck Lever 	.cr_flags	= 1UL << RPCAUTH_CRED_UPTODATE,
17512072652SChuck Lever };
176