1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3 * Copyright (c) 2021, 2022 Oracle. All rights reserved.
4 *
5 * The AUTH_TLS credential is used only to probe a remote peer
6 * for RPC-over-TLS support.
7 */
8
9 #include <linux/types.h>
10 #include <linux/module.h>
11 #include <linux/sunrpc/clnt.h>
12
/*
 * Opaque body a server places in its AUTH_NULL reply verifier to say it
 * supports RPC-over-TLS. The length is derived from the literal (minus the
 * terminating NUL) so the two can never disagree.
 */
static const char starttls_token[] = "STARTTLS";
static const size_t starttls_len = sizeof(starttls_token) - 1;
15
/*
 * Tentative definitions: tls_probe(), tls_create() and tls_lookup_cred()
 * use these objects before their initialized definitions near the end of
 * this file.
 */
static struct rpc_auth tls_auth;
static struct rpc_cred tls_cred;
18
tls_encode_probe(struct rpc_rqst * rqstp,struct xdr_stream * xdr,const void * obj)19 static void tls_encode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
20 const void *obj)
21 {
22 }
23
tls_decode_probe(struct rpc_rqst * rqstp,struct xdr_stream * xdr,void * obj)24 static int tls_decode_probe(struct rpc_rqst *rqstp, struct xdr_stream *xdr,
25 void *obj)
26 {
27 return 0;
28 }
29
30 static const struct rpc_procinfo rpcproc_tls_probe = {
31 .p_encode = tls_encode_probe,
32 .p_decode = tls_decode_probe,
33 };
34
rpc_tls_probe_call_prepare(struct rpc_task * task,void * data)35 static void rpc_tls_probe_call_prepare(struct rpc_task *task, void *data)
36 {
37 task->tk_flags &= ~RPC_TASK_NO_RETRANS_TIMEOUT;
38 rpc_call_start(task);
39 }
40
rpc_tls_probe_call_done(struct rpc_task * task,void * data)41 static void rpc_tls_probe_call_done(struct rpc_task *task, void *data)
42 {
43 }
44
45 static const struct rpc_call_ops rpc_tls_probe_ops = {
46 .rpc_call_prepare = rpc_tls_probe_call_prepare,
47 .rpc_call_done = rpc_tls_probe_call_done,
48 };
49
tls_probe(struct rpc_clnt * clnt)50 static int tls_probe(struct rpc_clnt *clnt)
51 {
52 struct rpc_message msg = {
53 .rpc_proc = &rpcproc_tls_probe,
54 };
55 struct rpc_task_setup task_setup_data = {
56 .rpc_client = clnt,
57 .rpc_message = &msg,
58 .rpc_op_cred = &tls_cred,
59 .callback_ops = &rpc_tls_probe_ops,
60 .flags = RPC_TASK_SOFT | RPC_TASK_SOFTCONN,
61 };
62 struct rpc_task *task;
63 int status;
64
65 task = rpc_run_task(&task_setup_data);
66 if (IS_ERR(task))
67 return PTR_ERR(task);
68 status = task->tk_status;
69 rpc_put_task(task);
70 return status;
71 }
72
tls_create(const struct rpc_auth_create_args * args,struct rpc_clnt * clnt)73 static struct rpc_auth *tls_create(const struct rpc_auth_create_args *args,
74 struct rpc_clnt *clnt)
75 {
76 refcount_inc(&tls_auth.au_count);
77 return &tls_auth;
78 }
79
tls_destroy(struct rpc_auth * auth)80 static void tls_destroy(struct rpc_auth *auth)
81 {
82 }
83
tls_lookup_cred(struct rpc_auth * auth,struct auth_cred * acred,int flags)84 static struct rpc_cred *tls_lookup_cred(struct rpc_auth *auth,
85 struct auth_cred *acred, int flags)
86 {
87 return get_rpccred(&tls_cred);
88 }
89
tls_destroy_cred(struct rpc_cred * cred)90 static void tls_destroy_cred(struct rpc_cred *cred)
91 {
92 }
93
tls_match(struct auth_cred * acred,struct rpc_cred * cred,int taskflags)94 static int tls_match(struct auth_cred *acred, struct rpc_cred *cred, int taskflags)
95 {
96 return 1;
97 }
98
tls_marshal(struct rpc_task * task,struct xdr_stream * xdr)99 static int tls_marshal(struct rpc_task *task, struct xdr_stream *xdr)
100 {
101 __be32 *p;
102
103 p = xdr_reserve_space(xdr, 4 * XDR_UNIT);
104 if (!p)
105 return -EMSGSIZE;
106 /* Credential */
107 *p++ = rpc_auth_tls;
108 *p++ = xdr_zero;
109 /* Verifier */
110 *p++ = rpc_auth_null;
111 *p = xdr_zero;
112 return 0;
113 }
114
tls_refresh(struct rpc_task * task)115 static int tls_refresh(struct rpc_task *task)
116 {
117 set_bit(RPCAUTH_CRED_UPTODATE, &task->tk_rqstp->rq_cred->cr_flags);
118 return 0;
119 }
120
tls_validate(struct rpc_task * task,struct xdr_stream * xdr)121 static int tls_validate(struct rpc_task *task, struct xdr_stream *xdr)
122 {
123 __be32 *p;
124 void *str;
125
126 p = xdr_inline_decode(xdr, XDR_UNIT);
127 if (!p)
128 return -EIO;
129 if (*p != rpc_auth_null)
130 return -EIO;
131 if (xdr_stream_decode_opaque_inline(xdr, &str, starttls_len) != starttls_len)
132 return -EPROTONOSUPPORT;
133 if (memcmp(str, starttls_token, starttls_len))
134 return -EPROTONOSUPPORT;
135 return 0;
136 }
137
138 const struct rpc_authops authtls_ops = {
139 .owner = THIS_MODULE,
140 .au_flavor = RPC_AUTH_TLS,
141 .au_name = "NULL",
142 .create = tls_create,
143 .destroy = tls_destroy,
144 .lookup_cred = tls_lookup_cred,
145 .ping = tls_probe,
146 };
147
148 static struct rpc_auth tls_auth = {
149 .au_cslack = NUL_CALLSLACK,
150 .au_rslack = NUL_REPLYSLACK,
151 .au_verfsize = NUL_REPLYSLACK,
152 .au_ralign = NUL_REPLYSLACK,
153 .au_ops = &authtls_ops,
154 .au_flavor = RPC_AUTH_TLS,
155 .au_count = REFCOUNT_INIT(1),
156 };
157
158 static const struct rpc_credops tls_credops = {
159 .cr_name = "AUTH_TLS",
160 .crdestroy = tls_destroy_cred,
161 .crmatch = tls_match,
162 .crmarshal = tls_marshal,
163 .crwrap_req = rpcauth_wrap_req_encode,
164 .crrefresh = tls_refresh,
165 .crvalidate = tls_validate,
166 .crunwrap_resp = rpcauth_unwrap_resp_decode,
167 };
168
169 static struct rpc_cred tls_cred = {
170 .cr_lru = LIST_HEAD_INIT(tls_cred.cr_lru),
171 .cr_auth = &tls_auth,
172 .cr_ops = &tls_credops,
173 .cr_count = REFCOUNT_INIT(2),
174 .cr_flags = 1UL << RPCAUTH_CRED_UPTODATE,
175 };
176