112072652SChuck Lever // SPDX-License-Identifier: GPL-2.0-only
212072652SChuck Lever /*
312072652SChuck Lever * Copyright (c) 2021, 2022 Oracle. All rights reserved.
412072652SChuck Lever *
512072652SChuck Lever * The AUTH_TLS credential is used only to probe a remote peer
612072652SChuck Lever * for RPC-over-TLS support.
712072652SChuck Lever */
812072652SChuck Lever
912072652SChuck Lever #include <linux/types.h>
1012072652SChuck Lever #include <linux/module.h>
1112072652SChuck Lever #include <linux/sunrpc/clnt.h>
1212072652SChuck Lever
/*
 * Opaque body of the AUTH_NONE verifier that tls_validate() requires in
 * the probe reply. The length is derived from the token itself so the
 * two can never disagree.
 */
static const char starttls_token[] = "STARTTLS";
static const size_t starttls_len = sizeof(starttls_token) - 1;
1512072652SChuck Lever
1612072652SChuck Lever static struct rpc_auth tls_auth;
1712072652SChuck Lever static struct rpc_cred tls_cred;
1812072652SChuck Lever
tls_encode_probe(struct rpc_rqst * rqstp,struct xdr_stream * xdr,const void * obj)1912072652SChuck Lever static void tls_encode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
2012072652SChuck Lever const void *obj)
2112072652SChuck Lever {
2212072652SChuck Lever }
2312072652SChuck Lever
tls_decode_probe(struct rpc_rqst * rqstp,struct xdr_stream * xdr,void * obj)2412072652SChuck Lever static int tls_decode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
2512072652SChuck Lever void *obj)
2612072652SChuck Lever {
2712072652SChuck Lever return 0;
2812072652SChuck Lever }
2912072652SChuck Lever
3012072652SChuck Lever static const struct rpc_procinfo rpcproc_tls_probe = {
3112072652SChuck Lever .p_encode = tls_encode_probe,
3212072652SChuck Lever .p_decode = tls_decode_probe,
3312072652SChuck Lever };
3412072652SChuck Lever
rpc_tls_probe_call_prepare(struct rpc_task * task,void * data)3512072652SChuck Lever static void rpc_tls_probe_call_prepare(struct rpc_task *task, void *data)
3612072652SChuck Lever {
3712072652SChuck Lever task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT;
3812072652SChuck Lever rpc_call_start(task);
3912072652SChuck Lever }
4012072652SChuck Lever
rpc_tls_probe_call_done(struct rpc_task * task,void * data)4112072652SChuck Lever static void rpc_tls_probe_call_done(struct rpc_task *task, void *data)
4212072652SChuck Lever {
4312072652SChuck Lever }
4412072652SChuck Lever
4512072652SChuck Lever static const struct rpc_call_ops rpc_tls_probe_ops = {
4612072652SChuck Lever .rpc_call_prepare = rpc_tls_probe_call_prepare,
4712072652SChuck Lever .rpc_call_done = rpc_tls_probe_call_done,
4812072652SChuck Lever };
4912072652SChuck Lever
tls_probe(struct rpc_clnt * clnt)5012072652SChuck Lever static int tls_probe(struct rpc_clnt *clnt)
5112072652SChuck Lever {
5212072652SChuck Lever struct rpc_message msg = {
5312072652SChuck Lever .rpc_proc = &rpcproc_tls_probe,
5412072652SChuck Lever };
5512072652SChuck Lever struct rpc_task_setup task_setup_data = {
5612072652SChuck Lever .rpc_client = clnt,
5712072652SChuck Lever .rpc_message = &msg,
5812072652SChuck Lever .rpc_op_cred = &tls_cred,
5912072652SChuck Lever .callback_ops = &rpc_tls_probe_ops,
6012072652SChuck Lever .flags = RPC_TASK_SOFT | RPC_TASK_SOFTCONN,
6112072652SChuck Lever };
6212072652SChuck Lever struct rpc_task *task;
6312072652SChuck Lever int status;
6412072652SChuck Lever
6512072652SChuck Lever task = rpc_run_task(&task_setup_data);
6612072652SChuck Lever if (IS_ERR(task))
6712072652SChuck Lever return PTR_ERR(task);
6812072652SChuck Lever status = task->tk_status;
6912072652SChuck Lever rpc_put_task(task);
7012072652SChuck Lever return status;
7112072652SChuck Lever }
7212072652SChuck Lever
tls_create(const struct rpc_auth_create_args * args,struct rpc_clnt * clnt)7312072652SChuck Lever static struct rpc_auth *tls_create(const struct rpc_auth_create_args *args,
7412072652SChuck Lever struct rpc_clnt *clnt)
7512072652SChuck Lever {
7612072652SChuck Lever refcount_inc(&tls_auth.au_count);
7712072652SChuck Lever return &tls_auth;
7812072652SChuck Lever }
7912072652SChuck Lever
tls_destroy(struct rpc_auth * auth)8012072652SChuck Lever static void tls_destroy(struct rpc_auth *auth)
8112072652SChuck Lever {
8212072652SChuck Lever }
8312072652SChuck Lever
tls_lookup_cred(struct rpc_auth * auth,struct auth_cred * acred,int flags)8412072652SChuck Lever static struct rpc_cred *tls_lookup_cred(struct rpc_auth *auth,
8512072652SChuck Lever struct auth_cred *acred, int flags)
8612072652SChuck Lever {
8712072652SChuck Lever return get_rpccred(&tls_cred);
8812072652SChuck Lever }
8912072652SChuck Lever
tls_destroy_cred(struct rpc_cred * cred)9012072652SChuck Lever static void tls_destroy_cred(struct rpc_cred *cred)
9112072652SChuck Lever {
9212072652SChuck Lever }
9312072652SChuck Lever
tls_match(struct auth_cred * acred,struct rpc_cred * cred,int taskflags)9412072652SChuck Lever static int tls_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags)
9512072652SChuck Lever {
9612072652SChuck Lever return 1;
9712072652SChuck Lever }
9812072652SChuck Lever
tls_marshal(struct rpc_task * task,struct xdr_stream * xdr)9912072652SChuck Lever static int tls_marshal(struct rpc_task *task, struct xdr_stream *xdr)
10012072652SChuck Lever {
10112072652SChuck Lever __be32 *p;
10212072652SChuck Lever
10312072652SChuck Lever p = xdr_reserve_space(xdr, 4 * XDR_UNIT);
10412072652SChuck Lever if (!p)
10512072652SChuck Lever return -EMSGSIZE;
10612072652SChuck Lever /* Credential */
10712072652SChuck Lever *p++ = rpc_auth_tls;
10812072652SChuck Lever *p++ = xdr_zero;
10912072652SChuck Lever /* Verifier */
11012072652SChuck Lever *p++ = rpc_auth_null;
11112072652SChuck Lever *p = xdr_zero;
11212072652SChuck Lever return 0;
11312072652SChuck Lever }
11412072652SChuck Lever
tls_refresh(struct rpc_task * task)11512072652SChuck Lever static int tls_refresh(struct rpc_task *task)
11612072652SChuck Lever {
11712072652SChuck Lever set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags);
11812072652SChuck Lever return 0;
11912072652SChuck Lever }
12012072652SChuck Lever
tls_validate(struct rpc_task * task,struct xdr_stream * xdr)12112072652SChuck Lever static int tls_validate(struct rpc_task *task, struct xdr_stream *xdr)
12212072652SChuck Lever {
12312072652SChuck Lever __be32 *p;
12412072652SChuck Lever void *str;
12512072652SChuck Lever
12612072652SChuck Lever p = xdr_inline_decode(xdr, XDR_UNIT);
12712072652SChuck Lever if (!p)
12812072652SChuck Lever return -EIO;
12912072652SChuck Lever if (*p != rpc_auth_null)
13012072652SChuck Lever return -EIO;
13112072652SChuck Lever if (xdr_stream_decode_opaque_inline(xdr, &str, starttls_len) != starttls_len)
132*5623ecfcSChuck Lever return -EPROTONOSUPPORT;
13312072652SChuck Lever if (memcmp(str, starttls_token, starttls_len))
134*5623ecfcSChuck Lever return -EPROTONOSUPPORT;
13512072652SChuck Lever return 0;
13612072652SChuck Lever }
13712072652SChuck Lever
13812072652SChuck Lever const struct rpc_authops authtls_ops = {
13912072652SChuck Lever .owner = THIS_MODULE,
14012072652SChuck Lever .au_flavor = RPC_AUTH_TLS,
14112072652SChuck Lever .au_name = "NULL",
14212072652SChuck Lever .create = tls_create,
14312072652SChuck Lever .destroy = tls_destroy,
14412072652SChuck Lever .lookup_cred = tls_lookup_cred,
14512072652SChuck Lever .ping = tls_probe,
14612072652SChuck Lever };
14712072652SChuck Lever
14812072652SChuck Lever static struct rpc_auth tls_auth = {
14912072652SChuck Lever .au_cslack = NUL_CALLSLACK,
15012072652SChuck Lever .au_rslack = NUL_REPLYSLACK,
15112072652SChuck Lever .au_verfsize = NUL_REPLYSLACK,
15212072652SChuck Lever .au_ralign = NUL_REPLYSLACK,
15312072652SChuck Lever .au_ops = &authtls_ops,
15412072652SChuck Lever .au_flavor = RPC_AUTH_TLS,
15512072652SChuck Lever .au_count = REFCOUNT_INIT(1),
15612072652SChuck Lever };
15712072652SChuck Lever
15812072652SChuck Lever static const struct rpc_credops tls_credops = {
15912072652SChuck Lever .cr_name = "AUTH_TLS",
16012072652SChuck Lever .crdestroy = tls_destroy_cred,
16112072652SChuck Lever .crmatch = tls_match,
16212072652SChuck Lever .crmarshal = tls_marshal,
16312072652SChuck Lever .crwrap_req = rpcauth_wrap_req_encode,
16412072652SChuck Lever .crrefresh = tls_refresh,
16512072652SChuck Lever .crvalidate = tls_validate,
16612072652SChuck Lever .crunwrap_resp = rpcauth_unwrap_resp_decode,
16712072652SChuck Lever };
16812072652SChuck Lever
16912072652SChuck Lever static struct rpc_cred tls_cred = {
17012072652SChuck Lever .cr_lru = LIST_HEAD_INIT(tls_cred.cr_lru),
17112072652SChuck Lever .cr_auth = &tls_auth,
17212072652SChuck Lever .cr_ops = &tls_credops,
17312072652SChuck Lever .cr_count = REFCOUNT_INIT(2),
17412072652SChuck Lever .cr_flags = 1UL << RPCAUTH_CRED_UPTODATE,
17512072652SChuck Lever };
176