/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/module.h>

#include <net/tcp.h>
#include <net/inet_common.h>
#include <linux/highmem.h>
#include <linux/netdevice.h>
#include <linux/sched/signal.h>
#include <linux/inetdevice.h>
#include <linux/inet_diag.h>

#include <net/snmp.h>
#include <net/tls.h>
#include <net/tls_toe.h>

#include "tls.h"

MODULE_AUTHOR("Mellanox Technologies");
MODULE_DESCRIPTION("Transport Layer Security Support");
MODULE_LICENSE("Dual BSD/GPL");
MODULE_ALIAS_TCP_ULP("tls");

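/* Userspace attaches this ULP and then provides key material via
 * setsockopt(). A minimal sketch (illustrative only; error handling
 * elided and key/IV values are placeholders from the TLS handshake)
 * of enabling TX offload for TLS 1.2 AES-GCM-128 on an established
 * TCP socket:
 *
 *	struct tls12_crypto_info_aes_gcm_128 ci = {
 *		.info.version = TLS_1_2_VERSION,
 *		.info.cipher_type = TLS_CIPHER_AES_GCM_128,
 *	};
 *
 *	memcpy(ci.key, key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
 *	memcpy(ci.iv, iv, TLS_CIPHER_AES_GCM_128_IV_SIZE);
 *	memcpy(ci.salt, salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
 *	memcpy(ci.rec_seq, rec_seq, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
 *
 *	setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls"));
 *	setsockopt(fd, SOL_TLS, TLS_TX, &ci, sizeof(ci));
 */
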
enum {
	TLSV4,
	TLSV6,
	TLS_NUM_PROTS,
};

#define CIPHER_SIZE_DESC(cipher) [cipher] = { \
	.iv = cipher ## _IV_SIZE, \
	.key = cipher ## _KEY_SIZE, \
	.salt = cipher ## _SALT_SIZE, \
	.tag = cipher ## _TAG_SIZE, \
	.rec_seq = cipher ## _REC_SEQ_SIZE, \
}

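/* For reference, CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_128) expands to a
 * designated initializer roughly like:
 *
 *	[TLS_CIPHER_AES_GCM_128] = {
 *		.iv      = TLS_CIPHER_AES_GCM_128_IV_SIZE,
 *		.key     = TLS_CIPHER_AES_GCM_128_KEY_SIZE,
 *		.salt    = TLS_CIPHER_AES_GCM_128_SALT_SIZE,
 *		.tag     = TLS_CIPHER_AES_GCM_128_TAG_SIZE,
 *		.rec_seq = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE,
 *	},
 *
 * so the array below is indexed directly by the TLS_CIPHER_* constants.
 */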
const struct tls_cipher_size_desc tls_cipher_size_desc[] = {
	CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_128),
	CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_256),
	CIPHER_SIZE_DESC(TLS_CIPHER_AES_CCM_128),
	CIPHER_SIZE_DESC(TLS_CIPHER_CHACHA20_POLY1305),
	CIPHER_SIZE_DESC(TLS_CIPHER_SM4_GCM),
	CIPHER_SIZE_DESC(TLS_CIPHER_SM4_CCM),
};

static const struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static const struct proto *saved_tcpv4_prot;
static DEFINE_MUTEX(tcpv4_prot_mutex);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 const struct proto *base);

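/* Swap in the proto and proto_ops variants matching the socket's current
 * TX/RX configuration. Readers access sk->sk_prot locklessly, hence the
 * WRITE_ONCE() pairing.
 */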
void update_sk_prot(struct sock *sk, struct tls_context *ctx)
{
	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;

	WRITE_ONCE(sk->sk_prot,
		   &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
	WRITE_ONCE(sk->sk_socket->ops,
		   &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
}

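/* Sleep until sk->sk_write_pending drains, the timeout expires (-EAGAIN)
 * or a signal arrives (sock_intr_errno()).
 */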
int wait_on_pending_writer(struct sock *sk, long *timeo)
{
	int rc = 0;
	DEFINE_WAIT_FUNC(wait, woken_wake_function);

	add_wait_queue(sk_sleep(sk), &wait);
	while (1) {
		if (!*timeo) {
			rc = -EAGAIN;
			break;
		}

		if (signal_pending(current)) {
			rc = sock_intr_errno(*timeo);
			break;
		}

		if (sk_wait_event(sk, timeo,
				  !READ_ONCE(sk->sk_write_pending), &wait))
			break;
	}
	remove_wait_queue(sk_sleep(sk), &wait);
	return rc;
}

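/* Transmit an encrypted record described by @sg, starting @first_offset
 * bytes into the first entry, via tcp_sendmsg_locked() with
 * MSG_SPLICE_PAGES. On a partial send the remaining state is stashed in
 * ctx->partially_sent_record/offset so a later call can resume.
 */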
int tls_push_sg(struct sock *sk,
		struct tls_context *ctx,
		struct scatterlist *sg,
		u16 first_offset,
		int flags)
{
	struct bio_vec bvec;
	struct msghdr msg = {
		.msg_flags = MSG_SPLICE_PAGES | flags,
	};
	int ret = 0;
	struct page *p;
	size_t size;
	int offset = first_offset;

	size = sg->length - offset;
	offset += sg->offset;

	ctx->splicing_pages = true;
	while (1) {
		/* is sending application-limited? */
		tcp_rate_check_app_limited(sk);
		p = sg_page(sg);
retry:
		bvec_set_page(&bvec, p, size, offset);
		iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bvec, 1, size);

		ret = tcp_sendmsg_locked(sk, &msg, size);

		if (ret != size) {
			if (ret > 0) {
				offset += ret;
				size -= ret;
				goto retry;
			}

			offset -= sg->offset;
			ctx->partially_sent_offset = offset;
			ctx->partially_sent_record = (void *)sg;
			ctx->splicing_pages = false;
			return ret;
		}

		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
		if (!sg)
			break;

		offset = sg->offset;
		size = sg->length;
	}

	ctx->splicing_pages = false;

	return 0;
}

static int tls_handle_open_record(struct sock *sk, int flags)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (tls_is_pending_open_record(ctx))
		return ctx->push_pending_record(sk, flags);

	return 0;
}

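/* Parse SOL_TLS control messages accompanying a sendmsg() call. Only
 * TLS_SET_RECORD_TYPE is recognized; it lets userspace send a record
 * other than application data. A minimal sketch (illustrative only,
 * error handling elided) of sending a TLS alert record from userspace:
 *
 *	struct msghdr msg = { 0 };
 *	char cbuf[CMSG_SPACE(sizeof(unsigned char))];
 *	unsigned char record_type = 21;	// TLS "alert" content type
 *	struct cmsghdr *cmsg;
 *
 *	msg.msg_control = cbuf;
 *	msg.msg_controllen = sizeof(cbuf);
 *	cmsg = CMSG_FIRSTHDR(&msg);
 *	cmsg->cmsg_level = SOL_TLS;
 *	cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
 *	cmsg->cmsg_len = CMSG_LEN(sizeof(record_type));
 *	*CMSG_DATA(cmsg) = record_type;
 *	msg.msg_controllen = cmsg->cmsg_len;
 *	// fill msg.msg_iov with the payload, then sendmsg(fd, &msg, 0)
 */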
int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
		     unsigned char *record_type)
{
	struct cmsghdr *cmsg;
	int rc = -EINVAL;

	for_each_cmsghdr(cmsg, msg) {
		if (!CMSG_OK(msg, cmsg))
			return -EINVAL;
		if (cmsg->cmsg_level != SOL_TLS)
			continue;

		switch (cmsg->cmsg_type) {
		case TLS_SET_RECORD_TYPE:
			if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
				return -EINVAL;

			if (msg->msg_flags & MSG_MORE)
				return -EINVAL;

			rc = tls_handle_open_record(sk, msg->msg_flags);
			if (rc)
				return rc;

			*record_type = *(unsigned char *)CMSG_DATA(cmsg);
			rc = 0;
			break;
		default:
			return -EINVAL;
		}
	}

	return rc;
}

int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
			    int flags)
{
	struct scatterlist *sg;
	u16 offset;

	sg = ctx->partially_sent_record;
	offset = ctx->partially_sent_offset;

	ctx->partially_sent_record = NULL;
	return tls_push_sg(sk, ctx, sg, offset, flags);
}

void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
{
	struct scatterlist *sg;

	for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
		put_page(sg_page(sg));
		sk_mem_uncharge(sk, sg->length);
	}
	ctx->partially_sent_record = NULL;
}

static void tls_write_space(struct sock *sk)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	/* If splicing_pages, call the lower protocol write space handler
	 * to ensure we wake up any waiting operations there. For example
	 * if splicing pages were to call sk_wait_event.
	 */
	if (ctx->splicing_pages) {
		ctx->sk_write_space(sk);
		return;
	}

#ifdef CONFIG_TLS_DEVICE
	if (ctx->tx_conf == TLS_HW)
		tls_device_write_space(sk, ctx);
	else
#endif
		tls_sw_write_space(sk, ctx);

	ctx->sk_write_space(sk);
}

/**
 * tls_ctx_free() - free TLS ULP context
 * @sk:  socket to which @ctx is attached
 * @ctx: TLS context structure
 *
 * Free TLS context. If @sk is %NULL caller guarantees that the socket
 * to which @ctx was attached has no outstanding references.
 */
void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
{
	if (!ctx)
		return;

	memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
	memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
	mutex_destroy(&ctx->tx_lock);

	if (sk)
		kfree_rcu(ctx, rcu);
	else
		kfree(ctx);
}

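/* Tear down the TX/RX state for whichever configurations are active,
 * flushing any open record first so no plaintext is left queued.
 */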
static void tls_sk_proto_cleanup(struct sock *sk,
				 struct tls_context *ctx, long timeo)
{
	if (unlikely(sk->sk_write_pending) &&
	    !wait_on_pending_writer(sk, &timeo))
		tls_handle_open_record(sk, 0);

	/* We need these for tls_sw_fallback handling of other packets */
	if (ctx->tx_conf == TLS_SW) {
		kfree(ctx->tx.rec_seq);
		kfree(ctx->tx.iv);
		tls_sw_release_resources_tx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
	} else if (ctx->tx_conf == TLS_HW) {
		tls_device_free_resources_tx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
	}

	if (ctx->rx_conf == TLS_SW) {
		tls_sw_release_resources_rx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
	} else if (ctx->rx_conf == TLS_HW) {
		tls_device_offload_cleanup_rx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
	}
}

static void tls_sk_proto_close(struct sock *sk, long timeout)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx = tls_get_ctx(sk);
	long timeo = sock_sndtimeo(sk, 0);
	bool free_ctx;

	if (ctx->tx_conf == TLS_SW)
		tls_sw_cancel_work_tx(ctx);

	lock_sock(sk);
	free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;

	if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
		tls_sk_proto_cleanup(sk, ctx, timeo);

	write_lock_bh(&sk->sk_callback_lock);
	if (free_ctx)
		rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
	WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
	if (sk->sk_write_space == tls_write_space)
		sk->sk_write_space = ctx->sk_write_space;
	write_unlock_bh(&sk->sk_callback_lock);
	release_sock(sk);
	if (ctx->tx_conf == TLS_SW)
		tls_sw_free_ctx_tx(ctx);
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
		tls_sw_strparser_done(ctx);
	if (ctx->rx_conf == TLS_SW)
		tls_sw_free_ctx_rx(ctx);
	ctx->sk_proto->close(sk, timeout);

	if (free_ctx)
		tls_ctx_free(sk, ctx);
}

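/* poll() for sockets with SW RX offload: start from tcp_poll(), then
 * clear EPOLLIN if no decrypted record is actually ready for userspace.
 */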
static __poll_t tls_sk_poll(struct file *file, struct socket *sock,
			    struct poll_table_struct *wait)
{
	struct tls_sw_context_rx *ctx;
	struct tls_context *tls_ctx;
	struct sock *sk = sock->sk;
	struct sk_psock *psock;
	__poll_t mask = 0;
	u8 shutdown;
	int state;

	mask = tcp_poll(file, sock, wait);

	state = inet_sk_state_load(sk);
	shutdown = READ_ONCE(sk->sk_shutdown);
	if (unlikely(state != TCP_ESTABLISHED || shutdown & RCV_SHUTDOWN))
		return mask;

	tls_ctx = tls_get_ctx(sk);
	ctx = tls_sw_ctx_rx(tls_ctx);
	psock = sk_psock_get(sk);

	if (skb_queue_empty_lockless(&ctx->rx_list) &&
	    !tls_strp_msg_ready(ctx) &&
	    sk_psock_queue_empty(psock))
		mask &= ~(EPOLLIN | EPOLLRDNORM);

	if (psock)
		sk_psock_put(sk, psock);

	return mask;
}

static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval,
				  int __user *optlen, int tx)
{
	int rc = 0;
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_crypto_info *crypto_info;
	struct cipher_context *cctx;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;

	if (!optval || (len < sizeof(*crypto_info))) {
		rc = -EINVAL;
		goto out;
	}

	if (!ctx) {
		rc = -EBUSY;
		goto out;
	}

	/* get user crypto info */
	if (tx) {
		crypto_info = &ctx->crypto_send.info;
		cctx = &ctx->tx;
	} else {
		crypto_info = &ctx->crypto_recv.info;
		cctx = &ctx->rx;
	}

	if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
		rc = -EBUSY;
		goto out;
	}

	if (len == sizeof(*crypto_info)) {
		if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
			rc = -EFAULT;
		goto out;
	}

	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128: {
		struct tls12_crypto_info_aes_gcm_128 *
		  crypto_info_aes_gcm_128 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aes_gcm_128,
			       info);

		if (len != sizeof(*crypto_info_aes_gcm_128)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(crypto_info_aes_gcm_128->iv,
		       cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
		       TLS_CIPHER_AES_GCM_128_IV_SIZE);
		memcpy(crypto_info_aes_gcm_128->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
		if (copy_to_user(optval,
				 crypto_info_aes_gcm_128,
				 sizeof(*crypto_info_aes_gcm_128)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_AES_GCM_256: {
		struct tls12_crypto_info_aes_gcm_256 *
		  crypto_info_aes_gcm_256 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aes_gcm_256,
			       info);

		if (len != sizeof(*crypto_info_aes_gcm_256)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(crypto_info_aes_gcm_256->iv,
		       cctx->iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
		       TLS_CIPHER_AES_GCM_256_IV_SIZE);
		memcpy(crypto_info_aes_gcm_256->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
		if (copy_to_user(optval,
				 crypto_info_aes_gcm_256,
				 sizeof(*crypto_info_aes_gcm_256)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_AES_CCM_128: {
		struct tls12_crypto_info_aes_ccm_128 *aes_ccm_128 =
			container_of(crypto_info,
				struct tls12_crypto_info_aes_ccm_128, info);

		if (len != sizeof(*aes_ccm_128)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(aes_ccm_128->iv,
		       cctx->iv + TLS_CIPHER_AES_CCM_128_SALT_SIZE,
		       TLS_CIPHER_AES_CCM_128_IV_SIZE);
		memcpy(aes_ccm_128->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE);
		if (copy_to_user(optval, aes_ccm_128, sizeof(*aes_ccm_128)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_CHACHA20_POLY1305: {
		struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305 =
			container_of(crypto_info,
				struct tls12_crypto_info_chacha20_poly1305,
				info);

		if (len != sizeof(*chacha20_poly1305)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(chacha20_poly1305->iv,
		       cctx->iv + TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE,
		       TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE);
		memcpy(chacha20_poly1305->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE);
		if (copy_to_user(optval, chacha20_poly1305,
				sizeof(*chacha20_poly1305)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_SM4_GCM: {
		struct tls12_crypto_info_sm4_gcm *sm4_gcm_info =
			container_of(crypto_info,
				struct tls12_crypto_info_sm4_gcm, info);

		if (len != sizeof(*sm4_gcm_info)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(sm4_gcm_info->iv,
		       cctx->iv + TLS_CIPHER_SM4_GCM_SALT_SIZE,
		       TLS_CIPHER_SM4_GCM_IV_SIZE);
		memcpy(sm4_gcm_info->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE);
		if (copy_to_user(optval, sm4_gcm_info, sizeof(*sm4_gcm_info)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_SM4_CCM: {
		struct tls12_crypto_info_sm4_ccm *sm4_ccm_info =
			container_of(crypto_info,
				struct tls12_crypto_info_sm4_ccm, info);

		if (len != sizeof(*sm4_ccm_info)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(sm4_ccm_info->iv,
		       cctx->iv + TLS_CIPHER_SM4_CCM_SALT_SIZE,
		       TLS_CIPHER_SM4_CCM_IV_SIZE);
		memcpy(sm4_ccm_info->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE);
		if (copy_to_user(optval, sm4_ccm_info, sizeof(*sm4_ccm_info)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_ARIA_GCM_128: {
		struct tls12_crypto_info_aria_gcm_128 *
		  crypto_info_aria_gcm_128 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aria_gcm_128,
			       info);

		if (len != sizeof(*crypto_info_aria_gcm_128)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(crypto_info_aria_gcm_128->iv,
		       cctx->iv + TLS_CIPHER_ARIA_GCM_128_SALT_SIZE,
		       TLS_CIPHER_ARIA_GCM_128_IV_SIZE);
		memcpy(crypto_info_aria_gcm_128->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE);
		if (copy_to_user(optval,
				 crypto_info_aria_gcm_128,
				 sizeof(*crypto_info_aria_gcm_128)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_ARIA_GCM_256: {
		struct tls12_crypto_info_aria_gcm_256 *
		  crypto_info_aria_gcm_256 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aria_gcm_256,
			       info);

		if (len != sizeof(*crypto_info_aria_gcm_256)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(crypto_info_aria_gcm_256->iv,
		       cctx->iv + TLS_CIPHER_ARIA_GCM_256_SALT_SIZE,
		       TLS_CIPHER_ARIA_GCM_256_IV_SIZE);
		memcpy(crypto_info_aria_gcm_256->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE);
		if (copy_to_user(optval,
				 crypto_info_aria_gcm_256,
				 sizeof(*crypto_info_aria_gcm_256)))
			rc = -EFAULT;
		break;
	}
	default:
		rc = -EINVAL;
	}

out:
	return rc;
}

static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval,
				   int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	unsigned int value;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;

	if (len != sizeof(value))
		return -EINVAL;

	value = ctx->zerocopy_sendfile;
	if (copy_to_user(optval, &value, sizeof(value)))
		return -EFAULT;

	return 0;
}

static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval,
				    int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	int value, len;

	if (ctx->prot_info.version != TLS_1_3_VERSION)
		return -EINVAL;

	if (get_user(len, optlen))
		return -EFAULT;
	if (len < sizeof(value))
		return -EINVAL;

	value = -EINVAL;
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
		value = ctx->rx_no_pad;
	if (value < 0)
		return value;

	if (put_user(sizeof(value), optlen))
		return -EFAULT;
	if (copy_to_user(optval, &value, sizeof(value)))
		return -EFAULT;

	return 0;
}

static int do_tls_getsockopt(struct sock *sk, int optname,
			     char __user *optval, int __user *optlen)
{
	int rc = 0;

	lock_sock(sk);

	switch (optname) {
	case TLS_TX:
	case TLS_RX:
		rc = do_tls_getsockopt_conf(sk, optval, optlen,
					    optname == TLS_TX);
		break;
	case TLS_TX_ZEROCOPY_RO:
		rc = do_tls_getsockopt_tx_zc(sk, optval, optlen);
		break;
	case TLS_RX_EXPECT_NO_PAD:
		rc = do_tls_getsockopt_no_pad(sk, optval, optlen);
		break;
	default:
		rc = -ENOPROTOOPT;
		break;
	}

	release_sock(sk);

	return rc;
}

static int tls_getsockopt(struct sock *sk, int level, int optname,
			  char __user *optval, int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (level != SOL_TLS)
		return ctx->sk_proto->getsockopt(sk, level,
						 optname, optval, optlen);

	return do_tls_getsockopt(sk, optname, optval, optlen);
}

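/* Handle TLS_TX/TLS_RX setsockopt(): validate the tls_crypto_info header
 * (version, cipher, size), copy in the full cipher-specific struct, then
 * try device offload first and fall back to the software implementation.
 */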
static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
				  unsigned int optlen, int tx)
{
	struct tls_crypto_info *crypto_info;
	struct tls_crypto_info *alt_crypto_info;
	struct tls_context *ctx = tls_get_ctx(sk);
	size_t optsize;
	int rc = 0;
	int conf;

	if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info)))
		return -EINVAL;

	if (tx) {
		crypto_info = &ctx->crypto_send.info;
		alt_crypto_info = &ctx->crypto_recv.info;
	} else {
		crypto_info = &ctx->crypto_recv.info;
		alt_crypto_info = &ctx->crypto_send.info;
	}

	/* Currently we don't support setting crypto info more than once */
	if (TLS_CRYPTO_INFO_READY(crypto_info))
		return -EBUSY;

	rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	/* check version */
	if (crypto_info->version != TLS_1_2_VERSION &&
	    crypto_info->version != TLS_1_3_VERSION) {
		rc = -EINVAL;
		goto err_crypto_info;
	}

	/* Ensure that TLS version and ciphers are the same in both directions */
	if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
		if (alt_crypto_info->version != crypto_info->version ||
		    alt_crypto_info->cipher_type != crypto_info->cipher_type) {
			rc = -EINVAL;
			goto err_crypto_info;
		}
	}

	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128:
		optsize = sizeof(struct tls12_crypto_info_aes_gcm_128);
		break;
	case TLS_CIPHER_AES_GCM_256:
		optsize = sizeof(struct tls12_crypto_info_aes_gcm_256);
		break;
	case TLS_CIPHER_AES_CCM_128:
		optsize = sizeof(struct tls12_crypto_info_aes_ccm_128);
		break;
	case TLS_CIPHER_CHACHA20_POLY1305:
		optsize = sizeof(struct tls12_crypto_info_chacha20_poly1305);
		break;
	case TLS_CIPHER_SM4_GCM:
		optsize = sizeof(struct tls12_crypto_info_sm4_gcm);
		break;
	case TLS_CIPHER_SM4_CCM:
		optsize = sizeof(struct tls12_crypto_info_sm4_ccm);
		break;
	case TLS_CIPHER_ARIA_GCM_128:
		if (crypto_info->version != TLS_1_2_VERSION) {
			rc = -EINVAL;
			goto err_crypto_info;
		}
		optsize = sizeof(struct tls12_crypto_info_aria_gcm_128);
		break;
	case TLS_CIPHER_ARIA_GCM_256:
		if (crypto_info->version != TLS_1_2_VERSION) {
			rc = -EINVAL;
			goto err_crypto_info;
		}
		optsize = sizeof(struct tls12_crypto_info_aria_gcm_256);
		break;
	default:
		rc = -EINVAL;
		goto err_crypto_info;
	}

	if (optlen != optsize) {
		rc = -EINVAL;
		goto err_crypto_info;
	}

	rc = copy_from_sockptr_offset(crypto_info + 1, optval,
				      sizeof(*crypto_info),
				      optlen - sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	if (tx) {
		rc = tls_set_device_offload(sk, ctx);
		conf = TLS_HW;
		if (!rc) {
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
		} else {
			rc = tls_set_sw_offload(sk, ctx, 1);
			if (rc)
				goto err_crypto_info;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
			conf = TLS_SW;
		}
	} else {
		rc = tls_set_device_offload_rx(sk, ctx);
		conf = TLS_HW;
		if (!rc) {
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
		} else {
			rc = tls_set_sw_offload(sk, ctx, 0);
			if (rc)
				goto err_crypto_info;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
			conf = TLS_SW;
		}
		tls_sw_strparser_arm(sk, ctx);
	}

	if (tx)
		ctx->tx_conf = conf;
	else
		ctx->rx_conf = conf;
	update_sk_prot(sk, ctx);
	if (tx) {
		ctx->sk_write_space = sk->sk_write_space;
		sk->sk_write_space = tls_write_space;
	} else {
		struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx);

		tls_strp_check_rcv(&rx_ctx->strp);
	}
	return 0;

err_crypto_info:
	memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
	return rc;
}

static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval,
				   unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	unsigned int value;

	if (sockptr_is_null(optval) || optlen != sizeof(value))
		return -EINVAL;

	if (copy_from_sockptr(&value, optval, sizeof(value)))
		return -EFAULT;

	if (value > 1)
		return -EINVAL;

	ctx->zerocopy_sendfile = value;

	return 0;
}

static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval,
				    unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	u32 val;
	int rc;

	if (ctx->prot_info.version != TLS_1_3_VERSION ||
	    sockptr_is_null(optval) || optlen < sizeof(val))
		return -EINVAL;

	rc = copy_from_sockptr(&val, optval, sizeof(val));
	if (rc)
		return -EFAULT;
	if (val > 1)
		return -EINVAL;
	rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val));
	if (rc < 1)
		return rc == 0 ? -EINVAL : rc;

	lock_sock(sk);
	rc = -EINVAL;
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) {
		ctx->rx_no_pad = val;
		tls_update_rx_zc_capable(ctx);
		rc = 0;
	}
	release_sock(sk);

	return rc;
}

static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
			     unsigned int optlen)
{
	int rc = 0;

	switch (optname) {
	case TLS_TX:
	case TLS_RX:
		lock_sock(sk);
		rc = do_tls_setsockopt_conf(sk, optval, optlen,
					    optname == TLS_TX);
		release_sock(sk);
		break;
	case TLS_TX_ZEROCOPY_RO:
		lock_sock(sk);
		rc = do_tls_setsockopt_tx_zc(sk, optval, optlen);
		release_sock(sk);
		break;
	case TLS_RX_EXPECT_NO_PAD:
		rc = do_tls_setsockopt_no_pad(sk, optval, optlen);
		break;
	default:
		rc = -ENOPROTOOPT;
		break;
	}
	return rc;
}

static int tls_setsockopt(struct sock *sk, int level, int optname,
			  sockptr_t optval, unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (level != SOL_TLS)
		return ctx->sk_proto->setsockopt(sk, level, optname, optval,
						 optlen);

	return do_tls_setsockopt(sk, optname, optval, optlen);
}

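/* Allocate and attach a TLS ULP context. GFP_ATOMIC because the caller
 * holds sk_callback_lock as a writer.
 */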
struct tls_context *tls_ctx_create(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx;

	ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
	if (!ctx)
		return NULL;

	mutex_init(&ctx->tx_lock);
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	ctx->sk_proto = READ_ONCE(sk->sk_prot);
	ctx->sk = sk;
	return ctx;
}

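/* Build the proto_ops matrix, indexed by [tx_conf][rx_conf]. Each entry
 * starts from the base TCP ops and overrides only what its configuration
 * needs.
 */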
static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			    const struct proto_ops *base)
{
	ops[TLS_BASE][TLS_BASE] = *base;

	ops[TLS_SW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_SW  ][TLS_BASE].splice_eof	= tls_sw_splice_eof;

	ops[TLS_BASE][TLS_SW  ] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_BASE][TLS_SW  ].splice_read	= tls_sw_splice_read;
	ops[TLS_BASE][TLS_SW  ].poll		= tls_sk_poll;
	ops[TLS_BASE][TLS_SW  ].read_sock	= tls_sw_read_sock;

	ops[TLS_SW  ][TLS_SW  ] = ops[TLS_SW  ][TLS_BASE];
	ops[TLS_SW  ][TLS_SW  ].splice_read	= tls_sw_splice_read;
	ops[TLS_SW  ][TLS_SW  ].poll		= tls_sk_poll;
	ops[TLS_SW  ][TLS_SW  ].read_sock	= tls_sw_read_sock;

#ifdef CONFIG_TLS_DEVICE
	ops[TLS_HW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];

	ops[TLS_HW  ][TLS_SW  ] = ops[TLS_BASE][TLS_SW  ];

	ops[TLS_BASE][TLS_HW  ] = ops[TLS_BASE][TLS_SW  ];

	ops[TLS_SW  ][TLS_HW  ] = ops[TLS_SW  ][TLS_SW  ];

	ops[TLS_HW  ][TLS_HW  ] = ops[TLS_HW  ][TLS_SW  ];
#endif
#ifdef CONFIG_TLS_TOE
	ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
#endif
}

static void tls_build_proto(struct sock *sk)
{
	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
	struct proto *prot = READ_ONCE(sk->sk_prot);

	/* Build IPv6 TLS whenever the address of tcpv6_prot changes */
	if (ip_ver == TLSV6 &&
	    unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
		mutex_lock(&tcpv6_prot_mutex);
		if (likely(prot != saved_tcpv6_prot)) {
			build_protos(tls_prots[TLSV6], prot);
			build_proto_ops(tls_proto_ops[TLSV6],
					sk->sk_socket->ops);
			smp_store_release(&saved_tcpv6_prot, prot);
		}
		mutex_unlock(&tcpv6_prot_mutex);
	}

	if (ip_ver == TLSV4 &&
	    unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
		mutex_lock(&tcpv4_prot_mutex);
		if (likely(prot != saved_tcpv4_prot)) {
			build_protos(tls_prots[TLSV4], prot);
			build_proto_ops(tls_proto_ops[TLSV4],
					sk->sk_socket->ops);
			smp_store_release(&saved_tcpv4_prot, prot);
		}
		mutex_unlock(&tcpv4_prot_mutex);
	}
}

static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 const struct proto *base)
{
	prot[TLS_BASE][TLS_BASE] = *base;
	prot[TLS_BASE][TLS_BASE].setsockopt	= tls_setsockopt;
	prot[TLS_BASE][TLS_BASE].getsockopt	= tls_getsockopt;
	prot[TLS_BASE][TLS_BASE].close		= tls_sk_proto_close;

	prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_SW][TLS_BASE].sendmsg		= tls_sw_sendmsg;
	prot[TLS_SW][TLS_BASE].splice_eof	= tls_sw_splice_eof;

	prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_BASE][TLS_SW].recvmsg		  = tls_sw_recvmsg;
	prot[TLS_BASE][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
	prot[TLS_BASE][TLS_SW].close		  = tls_sk_proto_close;

	prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
	prot[TLS_SW][TLS_SW].recvmsg		= tls_sw_recvmsg;
	prot[TLS_SW][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
	prot[TLS_SW][TLS_SW].close		= tls_sk_proto_close;

#ifdef CONFIG_TLS_DEVICE
	prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_HW][TLS_BASE].sendmsg		= tls_device_sendmsg;
	prot[TLS_HW][TLS_BASE].splice_eof	= tls_device_splice_eof;

	prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
	prot[TLS_HW][TLS_SW].sendmsg		= tls_device_sendmsg;
	prot[TLS_HW][TLS_SW].splice_eof		= tls_device_splice_eof;

	prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];

	prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];

	prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
#endif
#ifdef CONFIG_TLS_TOE
	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash		= tls_toe_hash;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash	= tls_toe_unhash;
#endif
}

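/* ->init() for the "tls" ULP: attach a TLS_BASE/TLS_BASE context to an
 * established TCP socket. Crypto state comes later via setsockopt().
 */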
static int tls_init(struct sock *sk)
{
	struct tls_context *ctx;
	int rc = 0;

	tls_build_proto(sk);

#ifdef CONFIG_TLS_TOE
	if (tls_toe_bypass(sk))
		return 0;
#endif

	/* The TLS ulp is currently supported only for TCP sockets
	 * in ESTABLISHED state.
	 * Supporting sockets in LISTEN state will require us
	 * to modify the accept implementation to clone rather than
	 * share the ulp context.
	 */
	if (sk->sk_state != TCP_ESTABLISHED)
		return -ENOTCONN;

	/* allocate tls context */
	write_lock_bh(&sk->sk_callback_lock);
	ctx = tls_ctx_create(sk);
	if (!ctx) {
		rc = -ENOMEM;
		goto out;
	}

	ctx->tx_conf = TLS_BASE;
	ctx->rx_conf = TLS_BASE;
	update_sk_prot(sk, ctx);
out:
	write_unlock_bh(&sk->sk_callback_lock);
	return rc;
}

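/* ->update() callback: another layer (e.g. sockmap) is replacing
 * sk->sk_prot, so remember the new underlying proto and write_space.
 */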
static void tls_update(struct sock *sk, struct proto *p,
		       void (*write_space)(struct sock *sk))
{
	struct tls_context *ctx;

	WARN_ON_ONCE(sk->sk_prot == p);

	ctx = tls_get_ctx(sk);
	if (likely(ctx)) {
		ctx->sk_write_space = write_space;
		ctx->sk_proto = p;
	} else {
		/* Pairs with lockless read in sk_clone_lock(). */
		WRITE_ONCE(sk->sk_prot, p);
		sk->sk_write_space = write_space;
	}
}

static u16 tls_user_config(struct tls_context *ctx, bool tx)
{
	u16 config = tx ? ctx->tx_conf : ctx->rx_conf;

	switch (config) {
	case TLS_BASE:
		return TLS_CONF_BASE;
	case TLS_SW:
		return TLS_CONF_SW;
	case TLS_HW:
		return TLS_CONF_HW;
	case TLS_HW_RECORD:
		return TLS_CONF_HW_RECORD;
	}
	return 0;
}

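/* ->get_info() callback: report TLS version, cipher and TX/RX config as
 * nested INET_ULP_INFO_TLS attributes for inet_diag (e.g. "ss").
 */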
static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
{
	u16 version, cipher_type;
	struct tls_context *ctx;
	struct nlattr *start;
	int err;

	start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
	if (!start)
		return -EMSGSIZE;

	rcu_read_lock();
	ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
	if (!ctx) {
		err = 0;
		goto nla_failure;
	}
	version = ctx->prot_info.version;
	if (version) {
		err = nla_put_u16(skb, TLS_INFO_VERSION, version);
		if (err)
			goto nla_failure;
	}
	cipher_type = ctx->prot_info.cipher_type;
	if (cipher_type) {
		err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
		if (err)
			goto nla_failure;
	}
	err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
	if (err)
		goto nla_failure;

	err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
	if (err)
		goto nla_failure;

	if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) {
		err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX);
		if (err)
			goto nla_failure;
	}
	if (ctx->rx_no_pad) {
		err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD);
		if (err)
			goto nla_failure;
	}

	rcu_read_unlock();
	nla_nest_end(skb, start);
	return 0;

nla_failure:
	rcu_read_unlock();
	nla_nest_cancel(skb, start);
	return err;
}

static size_t tls_get_info_size(const struct sock *sk)
{
	size_t size = 0;

	size += nla_total_size(0) +		/* INET_ULP_INFO_TLS */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_VERSION */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_CIPHER */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_RXCONF */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_TXCONF */
		nla_total_size(0) +		/* TLS_INFO_ZC_RO_TX */
		nla_total_size(0) +		/* TLS_INFO_RX_NO_PAD */
		0;

	return size;
}

static int __net_init tls_init_net(struct net *net)
{
	int err;

	net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib);
	if (!net->mib.tls_statistics)
		return -ENOMEM;

	err = tls_proc_init(net);
	if (err)
		goto err_free_stats;

	return 0;
err_free_stats:
	free_percpu(net->mib.tls_statistics);
	return err;
}

static void __net_exit tls_exit_net(struct net *net)
{
	tls_proc_fini(net);
	free_percpu(net->mib.tls_statistics);
}

static struct pernet_operations tls_proc_ops = {
	.init = tls_init_net,
	.exit = tls_exit_net,
};

static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
	.name			= "tls",
	.owner			= THIS_MODULE,
	.init			= tls_init,
	.update			= tls_update,
	.get_info		= tls_get_info,
	.get_info_size		= tls_get_info_size,
};

static int __init tls_register(void)
{
	int err;

	err = register_pernet_subsys(&tls_proc_ops);
	if (err)
		return err;

	err = tls_strp_dev_init();
	if (err)
		goto err_pernet;

	err = tls_device_init();
	if (err)
		goto err_strp;

	tcp_register_ulp(&tcp_tls_ulp_ops);

	return 0;
err_strp:
	tls_strp_dev_exit();
err_pernet:
	unregister_pernet_subsys(&tls_proc_ops);
	return err;
}

static void __exit tls_unregister(void)
{
	tcp_unregister_ulp(&tcp_tls_ulp_ops);
	tls_strp_dev_exit();
	tls_device_cleanup();
	unregister_pernet_subsys(&tls_proc_ops);
}

module_init(tls_register);
module_exit(tls_unregister);