1 /* 2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. 3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved. 4 * 5 * This software is available to you under a choice of one of two 6 * licenses. You may choose to be licensed under the terms of the GNU 7 * General Public License (GPL) Version 2, available from the file 8 * COPYING in the main directory of this source tree, or the 9 * OpenIB.org BSD license below: 10 * 11 * Redistribution and use in source and binary forms, with or 12 * without modification, are permitted provided that the following 13 * conditions are met: 14 * 15 * - Redistributions of source code must retain the above 16 * copyright notice, this list of conditions and the following 17 * disclaimer. 18 * 19 * - Redistributions in binary form must reproduce the above 20 * copyright notice, this list of conditions and the following 21 * disclaimer in the documentation and/or other materials 22 * provided with the distribution. 23 * 24 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 25 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 26 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 27 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS 28 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 29 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 30 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 31 * SOFTWARE. 32 */ 33 34 #include <linux/module.h> 35 36 #include <net/tcp.h> 37 #include <net/inet_common.h> 38 #include <linux/highmem.h> 39 #include <linux/netdevice.h> 40 #include <linux/sched/signal.h> 41 #include <linux/inetdevice.h> 42 #include <linux/inet_diag.h> 43 44 #include <net/snmp.h> 45 #include <net/tls.h> 46 #include <net/tls_toe.h> 47 48 #include "tls.h" 49 50 MODULE_AUTHOR("Mellanox Technologies"); 51 MODULE_DESCRIPTION("Transport Layer Security Support"); 52 MODULE_LICENSE("Dual BSD/GPL"); 53 MODULE_ALIAS_TCP_ULP("tls"); 54 55 enum { 56 TLSV4, 57 TLSV6, 58 TLS_NUM_PROTS, 59 }; 60 61 #define CHECK_CIPHER_DESC(cipher,ci) \ 62 static_assert(cipher ## _IV_SIZE <= MAX_IV_SIZE); \ 63 static_assert(cipher ## _REC_SEQ_SIZE <= TLS_MAX_REC_SEQ_SIZE); \ 64 static_assert(cipher ## _TAG_SIZE == TLS_TAG_SIZE); \ 65 static_assert(sizeof_field(struct ci, iv) == cipher ## _IV_SIZE); \ 66 static_assert(sizeof_field(struct ci, key) == cipher ## _KEY_SIZE); \ 67 static_assert(sizeof_field(struct ci, salt) == cipher ## _SALT_SIZE); \ 68 static_assert(sizeof_field(struct ci, rec_seq) == cipher ## _REC_SEQ_SIZE); 69 70 #define __CIPHER_DESC(ci) \ 71 .iv_offset = offsetof(struct ci, iv), \ 72 .key_offset = offsetof(struct ci, key), \ 73 .salt_offset = offsetof(struct ci, salt), \ 74 .rec_seq_offset = offsetof(struct ci, rec_seq), \ 75 .crypto_info = sizeof(struct ci) 76 77 #define CIPHER_DESC(cipher,ci,algname,_offloadable) [cipher - TLS_CIPHER_MIN] = { \ 78 .nonce = cipher ## _IV_SIZE, \ 79 .iv = cipher ## _IV_SIZE, \ 80 .key = cipher ## _KEY_SIZE, \ 81 .salt = cipher ## _SALT_SIZE, \ 82 .tag = cipher ## _TAG_SIZE, \ 83 .rec_seq = cipher ## _REC_SEQ_SIZE, \ 84 .cipher_name = algname, \ 85 .offloadable = _offloadable, \ 86 __CIPHER_DESC(ci), \ 87 } 88 89 #define CIPHER_DESC_NONCE0(cipher,ci,algname,_offloadable) [cipher - TLS_CIPHER_MIN] = { \ 90 .nonce = 0, \ 91 .iv = cipher ## _IV_SIZE, \ 92 .key = cipher ## _KEY_SIZE, \ 93 .salt = cipher ## _SALT_SIZE, \ 94 .tag = cipher ## _TAG_SIZE, \ 95 .rec_seq = cipher ## _REC_SEQ_SIZE, \ 96 .cipher_name = algname, \ 97 .offloadable = _offloadable, \ 98 __CIPHER_DESC(ci), \ 99 } 100 101 const struct tls_cipher_desc tls_cipher_desc[TLS_CIPHER_MAX + 1 - TLS_CIPHER_MIN] = { 102 CIPHER_DESC(TLS_CIPHER_AES_GCM_128, tls12_crypto_info_aes_gcm_128, "gcm(aes)", true), 103 CIPHER_DESC(TLS_CIPHER_AES_GCM_256, tls12_crypto_info_aes_gcm_256, "gcm(aes)", true), 104 CIPHER_DESC(TLS_CIPHER_AES_CCM_128, tls12_crypto_info_aes_ccm_128, "ccm(aes)", false), 105 CIPHER_DESC_NONCE0(TLS_CIPHER_CHACHA20_POLY1305, tls12_crypto_info_chacha20_poly1305, "rfc7539(chacha20,poly1305)", false), 106 CIPHER_DESC(TLS_CIPHER_SM4_GCM, tls12_crypto_info_sm4_gcm, "gcm(sm4)", false), 107 CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm, "ccm(sm4)", false), 108 CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128, "gcm(aria)", false), 109 CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256, "gcm(aria)", false), 110 }; 111 112 CHECK_CIPHER_DESC(TLS_CIPHER_AES_GCM_128, tls12_crypto_info_aes_gcm_128); 113 CHECK_CIPHER_DESC(TLS_CIPHER_AES_GCM_256, tls12_crypto_info_aes_gcm_256); 114 CHECK_CIPHER_DESC(TLS_CIPHER_AES_CCM_128, tls12_crypto_info_aes_ccm_128); 115 CHECK_CIPHER_DESC(TLS_CIPHER_CHACHA20_POLY1305, tls12_crypto_info_chacha20_poly1305); 116 CHECK_CIPHER_DESC(TLS_CIPHER_SM4_GCM, tls12_crypto_info_sm4_gcm); 117 CHECK_CIPHER_DESC(TLS_CIPHER_SM4_CCM, tls12_crypto_info_sm4_ccm); 118 CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_128, tls12_crypto_info_aria_gcm_128); 119 CHECK_CIPHER_DESC(TLS_CIPHER_ARIA_GCM_256, tls12_crypto_info_aria_gcm_256); 120 121 static const struct proto *saved_tcpv6_prot; 122 static DEFINE_MUTEX(tcpv6_prot_mutex); 123 static const struct proto *saved_tcpv4_prot; 124 static DEFINE_MUTEX(tcpv4_prot_mutex); 125 static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; 126 static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; 127 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], 128 const struct proto *base); 129 130 void update_sk_prot(struct sock *sk, struct tls_context *ctx) 131 { 132 int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; 133 134 WRITE_ONCE(sk->sk_prot, 135 &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]); 136 WRITE_ONCE(sk->sk_socket->ops, 137 &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]); 138 } 139 140 int wait_on_pending_writer(struct sock *sk, long *timeo) 141 { 142 DEFINE_WAIT_FUNC(wait, woken_wake_function); 143 int ret, rc = 0; 144 145 add_wait_queue(sk_sleep(sk), &wait); 146 while (1) { 147 if (!*timeo) { 148 rc = -EAGAIN; 149 break; 150 } 151 152 if (signal_pending(current)) { 153 rc = sock_intr_errno(*timeo); 154 break; 155 } 156 157 ret = sk_wait_event(sk, timeo, 158 !READ_ONCE(sk->sk_write_pending), &wait); 159 if (ret) { 160 if (ret < 0) 161 rc = ret; 162 break; 163 } 164 } 165 remove_wait_queue(sk_sleep(sk), &wait); 166 return rc; 167 } 168 169 int tls_push_sg(struct sock *sk, 170 struct tls_context *ctx, 171 struct scatterlist *sg, 172 u16 first_offset, 173 int flags) 174 { 175 struct bio_vec bvec; 176 struct msghdr msg = { 177 .msg_flags = MSG_SPLICE_PAGES | flags, 178 }; 179 int ret = 0; 180 struct page *p; 181 size_t size; 182 int offset = first_offset; 183 184 size = sg->length - offset; 185 offset += sg->offset; 186 187 ctx->splicing_pages = true; 188 while (1) { 189 /* is sending application-limited? */ 190 tcp_rate_check_app_limited(sk); 191 p = sg_page(sg); 192 retry: 193 bvec_set_page(&bvec, p, size, offset); 194 iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bvec, 1, size); 195 196 ret = tcp_sendmsg_locked(sk, &msg, size); 197 198 if (ret != size) { 199 if (ret > 0) { 200 offset += ret; 201 size -= ret; 202 goto retry; 203 } 204 205 offset -= sg->offset; 206 ctx->partially_sent_offset = offset; 207 ctx->partially_sent_record = (void *)sg; 208 ctx->splicing_pages = false; 209 return ret; 210 } 211 212 put_page(p); 213 sk_mem_uncharge(sk, sg->length); 214 sg = sg_next(sg); 215 if (!sg) 216 break; 217 218 offset = sg->offset; 219 size = sg->length; 220 } 221 222 ctx->splicing_pages = false; 223 224 return 0; 225 } 226 227 static int tls_handle_open_record(struct sock *sk, int flags) 228 { 229 struct tls_context *ctx = tls_get_ctx(sk); 230 231 if (tls_is_pending_open_record(ctx)) 232 return ctx->push_pending_record(sk, flags); 233 234 return 0; 235 } 236 237 int tls_process_cmsg(struct sock *sk, struct msghdr *msg, 238 unsigned char *record_type) 239 { 240 struct cmsghdr *cmsg; 241 int rc = -EINVAL; 242 243 for_each_cmsghdr(cmsg, msg) { 244 if (!CMSG_OK(msg, cmsg)) 245 return -EINVAL; 246 if (cmsg->cmsg_level != SOL_TLS) 247 continue; 248 249 switch (cmsg->cmsg_type) { 250 case TLS_SET_RECORD_TYPE: 251 if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type))) 252 return -EINVAL; 253 254 if (msg->msg_flags & MSG_MORE) 255 return -EINVAL; 256 257 rc = tls_handle_open_record(sk, msg->msg_flags); 258 if (rc) 259 return rc; 260 261 *record_type = *(unsigned char *)CMSG_DATA(cmsg); 262 rc = 0; 263 break; 264 default: 265 return -EINVAL; 266 } 267 } 268 269 return rc; 270 } 271 272 int tls_push_partial_record(struct sock *sk, struct tls_context *ctx, 273 int flags) 274 { 275 struct scatterlist *sg; 276 u16 offset; 277 278 sg = ctx->partially_sent_record; 279 offset = ctx->partially_sent_offset; 280 281 ctx->partially_sent_record = NULL; 282 return tls_push_sg(sk, ctx, sg, offset, flags); 283 } 284 285 void tls_free_partial_record(struct sock *sk, struct tls_context *ctx) 286 { 287 struct scatterlist *sg; 288 289 for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) { 290 put_page(sg_page(sg)); 291 sk_mem_uncharge(sk, sg->length); 292 } 293 ctx->partially_sent_record = NULL; 294 } 295 296 static void tls_write_space(struct sock *sk) 297 { 298 struct tls_context *ctx = tls_get_ctx(sk); 299 300 /* If splicing_pages call lower protocol write space handler 301 * to ensure we wake up any waiting operations there. For example 302 * if splicing pages where to call sk_wait_event. 303 */ 304 if (ctx->splicing_pages) { 305 ctx->sk_write_space(sk); 306 return; 307 } 308 309 #ifdef CONFIG_TLS_DEVICE 310 if (ctx->tx_conf == TLS_HW) 311 tls_device_write_space(sk, ctx); 312 else 313 #endif 314 tls_sw_write_space(sk, ctx); 315 316 ctx->sk_write_space(sk); 317 } 318 319 /** 320 * tls_ctx_free() - free TLS ULP context 321 * @sk: socket to with @ctx is attached 322 * @ctx: TLS context structure 323 * 324 * Free TLS context. If @sk is %NULL caller guarantees that the socket 325 * to which @ctx was attached has no outstanding references. 326 */ 327 void tls_ctx_free(struct sock *sk, struct tls_context *ctx) 328 { 329 if (!ctx) 330 return; 331 332 memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send)); 333 memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv)); 334 mutex_destroy(&ctx->tx_lock); 335 336 if (sk) 337 kfree_rcu(ctx, rcu); 338 else 339 kfree(ctx); 340 } 341 342 static void tls_sk_proto_cleanup(struct sock *sk, 343 struct tls_context *ctx, long timeo) 344 { 345 if (unlikely(sk->sk_write_pending) && 346 !wait_on_pending_writer(sk, &timeo)) 347 tls_handle_open_record(sk, 0); 348 349 /* We need these for tls_sw_fallback handling of other packets */ 350 if (ctx->tx_conf == TLS_SW) { 351 kfree(ctx->tx.rec_seq); 352 kfree(ctx->tx.iv); 353 tls_sw_release_resources_tx(sk); 354 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); 355 } else if (ctx->tx_conf == TLS_HW) { 356 tls_device_free_resources_tx(sk); 357 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); 358 } 359 360 if (ctx->rx_conf == TLS_SW) { 361 tls_sw_release_resources_rx(sk); 362 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); 363 } else if (ctx->rx_conf == TLS_HW) { 364 tls_device_offload_cleanup_rx(sk); 365 TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); 366 } 367 } 368 369 static void tls_sk_proto_close(struct sock *sk, long timeout) 370 { 371 struct inet_connection_sock *icsk = inet_csk(sk); 372 struct tls_context *ctx = tls_get_ctx(sk); 373 long timeo = sock_sndtimeo(sk, 0); 374 bool free_ctx; 375 376 if (ctx->tx_conf == TLS_SW) 377 tls_sw_cancel_work_tx(ctx); 378 379 lock_sock(sk); 380 free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW; 381 382 if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE) 383 tls_sk_proto_cleanup(sk, ctx, timeo); 384 385 write_lock_bh(&sk->sk_callback_lock); 386 if (free_ctx) 387 rcu_assign_pointer(icsk->icsk_ulp_data, NULL); 388 WRITE_ONCE(sk->sk_prot, ctx->sk_proto); 389 if (sk->sk_write_space == tls_write_space) 390 sk->sk_write_space = ctx->sk_write_space; 391 write_unlock_bh(&sk->sk_callback_lock); 392 release_sock(sk); 393 if (ctx->tx_conf == TLS_SW) 394 tls_sw_free_ctx_tx(ctx); 395 if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) 396 tls_sw_strparser_done(ctx); 397 if (ctx->rx_conf == TLS_SW) 398 tls_sw_free_ctx_rx(ctx); 399 ctx->sk_proto->close(sk, timeout); 400 401 if (free_ctx) 402 tls_ctx_free(sk, ctx); 403 } 404 405 static __poll_t tls_sk_poll(struct file *file, struct socket *sock, 406 struct poll_table_struct *wait) 407 { 408 struct tls_sw_context_rx *ctx; 409 struct tls_context *tls_ctx; 410 struct sock *sk = sock->sk; 411 struct sk_psock *psock; 412 __poll_t mask = 0; 413 u8 shutdown; 414 int state; 415 416 mask = tcp_poll(file, sock, wait); 417 418 state = inet_sk_state_load(sk); 419 shutdown = READ_ONCE(sk->sk_shutdown); 420 if (unlikely(state != TCP_ESTABLISHED || shutdown & RCV_SHUTDOWN)) 421 return mask; 422 423 tls_ctx = tls_get_ctx(sk); 424 ctx = tls_sw_ctx_rx(tls_ctx); 425 psock = sk_psock_get(sk); 426 427 if (skb_queue_empty_lockless(&ctx->rx_list) && 428 !tls_strp_msg_ready(ctx) && 429 sk_psock_queue_empty(psock)) 430 mask &= ~(EPOLLIN | EPOLLRDNORM); 431 432 if (psock) 433 sk_psock_put(sk, psock); 434 435 return mask; 436 } 437 438 static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval, 439 int __user *optlen, int tx) 440 { 441 int rc = 0; 442 const struct tls_cipher_desc *cipher_desc; 443 struct tls_context *ctx = tls_get_ctx(sk); 444 struct tls_crypto_info *crypto_info; 445 struct cipher_context *cctx; 446 int len; 447 448 if (get_user(len, optlen)) 449 return -EFAULT; 450 451 if (!optval || (len < sizeof(*crypto_info))) { 452 rc = -EINVAL; 453 goto out; 454 } 455 456 if (!ctx) { 457 rc = -EBUSY; 458 goto out; 459 } 460 461 /* get user crypto info */ 462 if (tx) { 463 crypto_info = &ctx->crypto_send.info; 464 cctx = &ctx->tx; 465 } else { 466 crypto_info = &ctx->crypto_recv.info; 467 cctx = &ctx->rx; 468 } 469 470 if (!TLS_CRYPTO_INFO_READY(crypto_info)) { 471 rc = -EBUSY; 472 goto out; 473 } 474 475 if (len == sizeof(*crypto_info)) { 476 if (copy_to_user(optval, crypto_info, sizeof(*crypto_info))) 477 rc = -EFAULT; 478 goto out; 479 } 480 481 cipher_desc = get_cipher_desc(crypto_info->cipher_type); 482 if (!cipher_desc || len != cipher_desc->crypto_info) { 483 rc = -EINVAL; 484 goto out; 485 } 486 487 memcpy(crypto_info_iv(crypto_info, cipher_desc), 488 cctx->iv + cipher_desc->salt, cipher_desc->iv); 489 memcpy(crypto_info_rec_seq(crypto_info, cipher_desc), 490 cctx->rec_seq, cipher_desc->rec_seq); 491 492 if (copy_to_user(optval, crypto_info, cipher_desc->crypto_info)) 493 rc = -EFAULT; 494 495 out: 496 return rc; 497 } 498 499 static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval, 500 int __user *optlen) 501 { 502 struct tls_context *ctx = tls_get_ctx(sk); 503 unsigned int value; 504 int len; 505 506 if (get_user(len, optlen)) 507 return -EFAULT; 508 509 if (len != sizeof(value)) 510 return -EINVAL; 511 512 value = ctx->zerocopy_sendfile; 513 if (copy_to_user(optval, &value, sizeof(value))) 514 return -EFAULT; 515 516 return 0; 517 } 518 519 static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval, 520 int __user *optlen) 521 { 522 struct tls_context *ctx = tls_get_ctx(sk); 523 int value, len; 524 525 if (ctx->prot_info.version != TLS_1_3_VERSION) 526 return -EINVAL; 527 528 if (get_user(len, optlen)) 529 return -EFAULT; 530 if (len < sizeof(value)) 531 return -EINVAL; 532 533 value = -EINVAL; 534 if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) 535 value = ctx->rx_no_pad; 536 if (value < 0) 537 return value; 538 539 if (put_user(sizeof(value), optlen)) 540 return -EFAULT; 541 if (copy_to_user(optval, &value, sizeof(value))) 542 return -EFAULT; 543 544 return 0; 545 } 546 547 static int do_tls_getsockopt(struct sock *sk, int optname, 548 char __user *optval, int __user *optlen) 549 { 550 int rc = 0; 551 552 lock_sock(sk); 553 554 switch (optname) { 555 case TLS_TX: 556 case TLS_RX: 557 rc = do_tls_getsockopt_conf(sk, optval, optlen, 558 optname == TLS_TX); 559 break; 560 case TLS_TX_ZEROCOPY_RO: 561 rc = do_tls_getsockopt_tx_zc(sk, optval, optlen); 562 break; 563 case TLS_RX_EXPECT_NO_PAD: 564 rc = do_tls_getsockopt_no_pad(sk, optval, optlen); 565 break; 566 default: 567 rc = -ENOPROTOOPT; 568 break; 569 } 570 571 release_sock(sk); 572 573 return rc; 574 } 575 576 static int tls_getsockopt(struct sock *sk, int level, int optname, 577 char __user *optval, int __user *optlen) 578 { 579 struct tls_context *ctx = tls_get_ctx(sk); 580 581 if (level != SOL_TLS) 582 return ctx->sk_proto->getsockopt(sk, level, 583 optname, optval, optlen); 584 585 return do_tls_getsockopt(sk, optname, optval, optlen); 586 } 587 588 static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval, 589 unsigned int optlen, int tx) 590 { 591 struct tls_crypto_info *crypto_info; 592 struct tls_crypto_info *alt_crypto_info; 593 struct tls_context *ctx = tls_get_ctx(sk); 594 const struct tls_cipher_desc *cipher_desc; 595 int rc = 0; 596 int conf; 597 598 if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info))) 599 return -EINVAL; 600 601 if (tx) { 602 crypto_info = &ctx->crypto_send.info; 603 alt_crypto_info = &ctx->crypto_recv.info; 604 } else { 605 crypto_info = &ctx->crypto_recv.info; 606 alt_crypto_info = &ctx->crypto_send.info; 607 } 608 609 /* Currently we don't support set crypto info more than one time */ 610 if (TLS_CRYPTO_INFO_READY(crypto_info)) 611 return -EBUSY; 612 613 rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info)); 614 if (rc) { 615 rc = -EFAULT; 616 goto err_crypto_info; 617 } 618 619 /* check version */ 620 if (crypto_info->version != TLS_1_2_VERSION && 621 crypto_info->version != TLS_1_3_VERSION) { 622 rc = -EINVAL; 623 goto err_crypto_info; 624 } 625 626 /* Ensure that TLS version and ciphers are same in both directions */ 627 if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) { 628 if (alt_crypto_info->version != crypto_info->version || 629 alt_crypto_info->cipher_type != crypto_info->cipher_type) { 630 rc = -EINVAL; 631 goto err_crypto_info; 632 } 633 } 634 635 cipher_desc = get_cipher_desc(crypto_info->cipher_type); 636 if (!cipher_desc) { 637 rc = -EINVAL; 638 goto err_crypto_info; 639 } 640 641 switch (crypto_info->cipher_type) { 642 case TLS_CIPHER_ARIA_GCM_128: 643 case TLS_CIPHER_ARIA_GCM_256: 644 if (crypto_info->version != TLS_1_2_VERSION) { 645 rc = -EINVAL; 646 goto err_crypto_info; 647 } 648 break; 649 } 650 651 if (optlen != cipher_desc->crypto_info) { 652 rc = -EINVAL; 653 goto err_crypto_info; 654 } 655 656 rc = copy_from_sockptr_offset(crypto_info + 1, optval, 657 sizeof(*crypto_info), 658 optlen - sizeof(*crypto_info)); 659 if (rc) { 660 rc = -EFAULT; 661 goto err_crypto_info; 662 } 663 664 if (tx) { 665 rc = tls_set_device_offload(sk, ctx); 666 conf = TLS_HW; 667 if (!rc) { 668 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE); 669 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE); 670 } else { 671 rc = tls_set_sw_offload(sk, ctx, 1); 672 if (rc) 673 goto err_crypto_info; 674 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW); 675 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW); 676 conf = TLS_SW; 677 } 678 } else { 679 rc = tls_set_device_offload_rx(sk, ctx); 680 conf = TLS_HW; 681 if (!rc) { 682 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE); 683 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE); 684 } else { 685 rc = tls_set_sw_offload(sk, ctx, 0); 686 if (rc) 687 goto err_crypto_info; 688 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW); 689 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW); 690 conf = TLS_SW; 691 } 692 tls_sw_strparser_arm(sk, ctx); 693 } 694 695 if (tx) 696 ctx->tx_conf = conf; 697 else 698 ctx->rx_conf = conf; 699 update_sk_prot(sk, ctx); 700 if (tx) { 701 ctx->sk_write_space = sk->sk_write_space; 702 sk->sk_write_space = tls_write_space; 703 } else { 704 struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx); 705 706 tls_strp_check_rcv(&rx_ctx->strp); 707 } 708 return 0; 709 710 err_crypto_info: 711 memzero_explicit(crypto_info, sizeof(union tls_crypto_context)); 712 return rc; 713 } 714 715 static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval, 716 unsigned int optlen) 717 { 718 struct tls_context *ctx = tls_get_ctx(sk); 719 unsigned int value; 720 721 if (sockptr_is_null(optval) || optlen != sizeof(value)) 722 return -EINVAL; 723 724 if (copy_from_sockptr(&value, optval, sizeof(value))) 725 return -EFAULT; 726 727 if (value > 1) 728 return -EINVAL; 729 730 ctx->zerocopy_sendfile = value; 731 732 return 0; 733 } 734 735 static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval, 736 unsigned int optlen) 737 { 738 struct tls_context *ctx = tls_get_ctx(sk); 739 u32 val; 740 int rc; 741 742 if (ctx->prot_info.version != TLS_1_3_VERSION || 743 sockptr_is_null(optval) || optlen < sizeof(val)) 744 return -EINVAL; 745 746 rc = copy_from_sockptr(&val, optval, sizeof(val)); 747 if (rc) 748 return -EFAULT; 749 if (val > 1) 750 return -EINVAL; 751 rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val)); 752 if (rc < 1) 753 return rc == 0 ? -EINVAL : rc; 754 755 lock_sock(sk); 756 rc = -EINVAL; 757 if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) { 758 ctx->rx_no_pad = val; 759 tls_update_rx_zc_capable(ctx); 760 rc = 0; 761 } 762 release_sock(sk); 763 764 return rc; 765 } 766 767 static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval, 768 unsigned int optlen) 769 { 770 int rc = 0; 771 772 switch (optname) { 773 case TLS_TX: 774 case TLS_RX: 775 lock_sock(sk); 776 rc = do_tls_setsockopt_conf(sk, optval, optlen, 777 optname == TLS_TX); 778 release_sock(sk); 779 break; 780 case TLS_TX_ZEROCOPY_RO: 781 lock_sock(sk); 782 rc = do_tls_setsockopt_tx_zc(sk, optval, optlen); 783 release_sock(sk); 784 break; 785 case TLS_RX_EXPECT_NO_PAD: 786 rc = do_tls_setsockopt_no_pad(sk, optval, optlen); 787 break; 788 default: 789 rc = -ENOPROTOOPT; 790 break; 791 } 792 return rc; 793 } 794 795 static int tls_setsockopt(struct sock *sk, int level, int optname, 796 sockptr_t optval, unsigned int optlen) 797 { 798 struct tls_context *ctx = tls_get_ctx(sk); 799 800 if (level != SOL_TLS) 801 return ctx->sk_proto->setsockopt(sk, level, optname, optval, 802 optlen); 803 804 return do_tls_setsockopt(sk, optname, optval, optlen); 805 } 806 807 struct tls_context *tls_ctx_create(struct sock *sk) 808 { 809 struct inet_connection_sock *icsk = inet_csk(sk); 810 struct tls_context *ctx; 811 812 ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC); 813 if (!ctx) 814 return NULL; 815 816 mutex_init(&ctx->tx_lock); 817 ctx->sk_proto = READ_ONCE(sk->sk_prot); 818 ctx->sk = sk; 819 /* Release semantic of rcu_assign_pointer() ensures that 820 * ctx->sk_proto is visible before changing sk->sk_prot in 821 * update_sk_prot(), and prevents reading uninitialized value in 822 * tls_{getsockopt, setsockopt}. Note that we do not need a 823 * read barrier in tls_{getsockopt,setsockopt} as there is an 824 * address dependency between sk->sk_proto->{getsockopt,setsockopt} 825 * and ctx->sk_proto. 826 */ 827 rcu_assign_pointer(icsk->icsk_ulp_data, ctx); 828 return ctx; 829 } 830 831 static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG], 832 const struct proto_ops *base) 833 { 834 ops[TLS_BASE][TLS_BASE] = *base; 835 836 ops[TLS_SW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE]; 837 ops[TLS_SW ][TLS_BASE].splice_eof = tls_sw_splice_eof; 838 839 ops[TLS_BASE][TLS_SW ] = ops[TLS_BASE][TLS_BASE]; 840 ops[TLS_BASE][TLS_SW ].splice_read = tls_sw_splice_read; 841 ops[TLS_BASE][TLS_SW ].poll = tls_sk_poll; 842 ops[TLS_BASE][TLS_SW ].read_sock = tls_sw_read_sock; 843 844 ops[TLS_SW ][TLS_SW ] = ops[TLS_SW ][TLS_BASE]; 845 ops[TLS_SW ][TLS_SW ].splice_read = tls_sw_splice_read; 846 ops[TLS_SW ][TLS_SW ].poll = tls_sk_poll; 847 ops[TLS_SW ][TLS_SW ].read_sock = tls_sw_read_sock; 848 849 #ifdef CONFIG_TLS_DEVICE 850 ops[TLS_HW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE]; 851 852 ops[TLS_HW ][TLS_SW ] = ops[TLS_BASE][TLS_SW ]; 853 854 ops[TLS_BASE][TLS_HW ] = ops[TLS_BASE][TLS_SW ]; 855 856 ops[TLS_SW ][TLS_HW ] = ops[TLS_SW ][TLS_SW ]; 857 858 ops[TLS_HW ][TLS_HW ] = ops[TLS_HW ][TLS_SW ]; 859 #endif 860 #ifdef CONFIG_TLS_TOE 861 ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base; 862 #endif 863 } 864 865 static void tls_build_proto(struct sock *sk) 866 { 867 int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; 868 struct proto *prot = READ_ONCE(sk->sk_prot); 869 870 /* Build IPv6 TLS whenever the address of tcpv6 _prot changes */ 871 if (ip_ver == TLSV6 && 872 unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) { 873 mutex_lock(&tcpv6_prot_mutex); 874 if (likely(prot != saved_tcpv6_prot)) { 875 build_protos(tls_prots[TLSV6], prot); 876 build_proto_ops(tls_proto_ops[TLSV6], 877 sk->sk_socket->ops); 878 smp_store_release(&saved_tcpv6_prot, prot); 879 } 880 mutex_unlock(&tcpv6_prot_mutex); 881 } 882 883 if (ip_ver == TLSV4 && 884 unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) { 885 mutex_lock(&tcpv4_prot_mutex); 886 if (likely(prot != saved_tcpv4_prot)) { 887 build_protos(tls_prots[TLSV4], prot); 888 build_proto_ops(tls_proto_ops[TLSV4], 889 sk->sk_socket->ops); 890 smp_store_release(&saved_tcpv4_prot, prot); 891 } 892 mutex_unlock(&tcpv4_prot_mutex); 893 } 894 } 895 896 static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], 897 const struct proto *base) 898 { 899 prot[TLS_BASE][TLS_BASE] = *base; 900 prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt; 901 prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt; 902 prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close; 903 904 prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; 905 prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg; 906 prot[TLS_SW][TLS_BASE].splice_eof = tls_sw_splice_eof; 907 908 prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE]; 909 prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg; 910 prot[TLS_BASE][TLS_SW].sock_is_readable = tls_sw_sock_is_readable; 911 prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close; 912 913 prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE]; 914 prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg; 915 prot[TLS_SW][TLS_SW].sock_is_readable = tls_sw_sock_is_readable; 916 prot[TLS_SW][TLS_SW].close = tls_sk_proto_close; 917 918 #ifdef CONFIG_TLS_DEVICE 919 prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; 920 prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg; 921 prot[TLS_HW][TLS_BASE].splice_eof = tls_device_splice_eof; 922 923 prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW]; 924 prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg; 925 prot[TLS_HW][TLS_SW].splice_eof = tls_device_splice_eof; 926 927 prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW]; 928 929 prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW]; 930 931 prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW]; 932 #endif 933 #ifdef CONFIG_TLS_TOE 934 prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base; 935 prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_toe_hash; 936 prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_toe_unhash; 937 #endif 938 } 939 940 static int tls_init(struct sock *sk) 941 { 942 struct tls_context *ctx; 943 int rc = 0; 944 945 tls_build_proto(sk); 946 947 #ifdef CONFIG_TLS_TOE 948 if (tls_toe_bypass(sk)) 949 return 0; 950 #endif 951 952 /* The TLS ulp is currently supported only for TCP sockets 953 * in ESTABLISHED state. 954 * Supporting sockets in LISTEN state will require us 955 * to modify the accept implementation to clone rather then 956 * share the ulp context. 957 */ 958 if (sk->sk_state != TCP_ESTABLISHED) 959 return -ENOTCONN; 960 961 /* allocate tls context */ 962 write_lock_bh(&sk->sk_callback_lock); 963 ctx = tls_ctx_create(sk); 964 if (!ctx) { 965 rc = -ENOMEM; 966 goto out; 967 } 968 969 ctx->tx_conf = TLS_BASE; 970 ctx->rx_conf = TLS_BASE; 971 update_sk_prot(sk, ctx); 972 out: 973 write_unlock_bh(&sk->sk_callback_lock); 974 return rc; 975 } 976 977 static void tls_update(struct sock *sk, struct proto *p, 978 void (*write_space)(struct sock *sk)) 979 { 980 struct tls_context *ctx; 981 982 WARN_ON_ONCE(sk->sk_prot == p); 983 984 ctx = tls_get_ctx(sk); 985 if (likely(ctx)) { 986 ctx->sk_write_space = write_space; 987 ctx->sk_proto = p; 988 } else { 989 /* Pairs with lockless read in sk_clone_lock(). */ 990 WRITE_ONCE(sk->sk_prot, p); 991 sk->sk_write_space = write_space; 992 } 993 } 994 995 static u16 tls_user_config(struct tls_context *ctx, bool tx) 996 { 997 u16 config = tx ? ctx->tx_conf : ctx->rx_conf; 998 999 switch (config) { 1000 case TLS_BASE: 1001 return TLS_CONF_BASE; 1002 case TLS_SW: 1003 return TLS_CONF_SW; 1004 case TLS_HW: 1005 return TLS_CONF_HW; 1006 case TLS_HW_RECORD: 1007 return TLS_CONF_HW_RECORD; 1008 } 1009 return 0; 1010 } 1011 1012 static int tls_get_info(struct sock *sk, struct sk_buff *skb) 1013 { 1014 u16 version, cipher_type; 1015 struct tls_context *ctx; 1016 struct nlattr *start; 1017 int err; 1018 1019 start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS); 1020 if (!start) 1021 return -EMSGSIZE; 1022 1023 rcu_read_lock(); 1024 ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data); 1025 if (!ctx) { 1026 err = 0; 1027 goto nla_failure; 1028 } 1029 version = ctx->prot_info.version; 1030 if (version) { 1031 err = nla_put_u16(skb, TLS_INFO_VERSION, version); 1032 if (err) 1033 goto nla_failure; 1034 } 1035 cipher_type = ctx->prot_info.cipher_type; 1036 if (cipher_type) { 1037 err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type); 1038 if (err) 1039 goto nla_failure; 1040 } 1041 err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true)); 1042 if (err) 1043 goto nla_failure; 1044 1045 err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false)); 1046 if (err) 1047 goto nla_failure; 1048 1049 if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) { 1050 err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX); 1051 if (err) 1052 goto nla_failure; 1053 } 1054 if (ctx->rx_no_pad) { 1055 err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD); 1056 if (err) 1057 goto nla_failure; 1058 } 1059 1060 rcu_read_unlock(); 1061 nla_nest_end(skb, start); 1062 return 0; 1063 1064 nla_failure: 1065 rcu_read_unlock(); 1066 nla_nest_cancel(skb, start); 1067 return err; 1068 } 1069 1070 static size_t tls_get_info_size(const struct sock *sk) 1071 { 1072 size_t size = 0; 1073 1074 size += nla_total_size(0) + /* INET_ULP_INFO_TLS */ 1075 nla_total_size(sizeof(u16)) + /* TLS_INFO_VERSION */ 1076 nla_total_size(sizeof(u16)) + /* TLS_INFO_CIPHER */ 1077 nla_total_size(sizeof(u16)) + /* TLS_INFO_RXCONF */ 1078 nla_total_size(sizeof(u16)) + /* TLS_INFO_TXCONF */ 1079 nla_total_size(0) + /* TLS_INFO_ZC_RO_TX */ 1080 nla_total_size(0) + /* TLS_INFO_RX_NO_PAD */ 1081 0; 1082 1083 return size; 1084 } 1085 1086 static int __net_init tls_init_net(struct net *net) 1087 { 1088 int err; 1089 1090 net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib); 1091 if (!net->mib.tls_statistics) 1092 return -ENOMEM; 1093 1094 err = tls_proc_init(net); 1095 if (err) 1096 goto err_free_stats; 1097 1098 return 0; 1099 err_free_stats: 1100 free_percpu(net->mib.tls_statistics); 1101 return err; 1102 } 1103 1104 static void __net_exit tls_exit_net(struct net *net) 1105 { 1106 tls_proc_fini(net); 1107 free_percpu(net->mib.tls_statistics); 1108 } 1109 1110 static struct pernet_operations tls_proc_ops = { 1111 .init = tls_init_net, 1112 .exit = tls_exit_net, 1113 }; 1114 1115 static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = { 1116 .name = "tls", 1117 .owner = THIS_MODULE, 1118 .init = tls_init, 1119 .update = tls_update, 1120 .get_info = tls_get_info, 1121 .get_info_size = tls_get_info_size, 1122 }; 1123 1124 static int __init tls_register(void) 1125 { 1126 int err; 1127 1128 err = register_pernet_subsys(&tls_proc_ops); 1129 if (err) 1130 return err; 1131 1132 err = tls_strp_dev_init(); 1133 if (err) 1134 goto err_pernet; 1135 1136 err = tls_device_init(); 1137 if (err) 1138 goto err_strp; 1139 1140 tcp_register_ulp(&tcp_tls_ulp_ops); 1141 1142 return 0; 1143 err_strp: 1144 tls_strp_dev_exit(); 1145 err_pernet: 1146 unregister_pernet_subsys(&tls_proc_ops); 1147 return err; 1148 } 1149 1150 static void __exit tls_unregister(void) 1151 { 1152 tcp_unregister_ulp(&tcp_tls_ulp_ops); 1153 tls_strp_dev_exit(); 1154 tls_device_cleanup(); 1155 unregister_pernet_subsys(&tls_proc_ops); 1156 } 1157 1158 module_init(tls_register); 1159 module_exit(tls_unregister); 1160