/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses. You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/module.h>

#include <net/tcp.h>
#include <net/inet_common.h>
#include <linux/highmem.h>
#include <linux/netdevice.h>
#include <linux/sched/signal.h>
#include <linux/inetdevice.h>
#include <linux/inet_diag.h>

#include <net/snmp.h>
#include <net/tls.h>
#include <net/tls_toe.h>

#include "tls.h"

MODULE_AUTHOR("Mellanox Technologies");
MODULE_DESCRIPTION("Transport Layer Security Support");
MODULE_LICENSE("Dual BSD/GPL");
MODULE_ALIAS_TCP_ULP("tls");

enum {
        TLSV4,
        TLSV6,
        TLS_NUM_PROTS,
};

static const struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static const struct proto *saved_tcpv4_prot;
static DEFINE_MUTEX(tcpv4_prot_mutex);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
                         const struct proto *base);

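/* Point sk->sk_prot and the socket ops at the TLS variants that match the
 * current TX/RX configuration.  Other paths read sk->sk_prot locklessly,
 * hence the WRITE_ONCE() pairing with their READ_ONCE().
 */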
void update_sk_prot(struct sock *sk, struct tls_context *ctx)
{
        int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;

        WRITE_ONCE(sk->sk_prot,
                   &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
        WRITE_ONCE(sk->sk_socket->ops,
                   &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
}

int wait_on_pending_writer(struct sock *sk, long *timeo)
{
        int rc = 0;
        DEFINE_WAIT_FUNC(wait, woken_wake_function);

        add_wait_queue(sk_sleep(sk), &wait);
        while (1) {
                if (!*timeo) {
                        rc = -EAGAIN;
                        break;
                }

                if (signal_pending(current)) {
                        rc = sock_intr_errno(*timeo);
                        break;
                }

                if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait))
                        break;
        }
        remove_wait_queue(sk_sleep(sk), &wait);
        return rc;
}

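/* Transmit an encrypted record described by @sg via do_tcp_sendpages().
 * If TCP cannot take more data (e.g. -EAGAIN), the remaining scatterlist
 * and offset are parked in ctx->partially_sent_record so that
 * tls_push_partial_record() can resume later; fully sent entries are
 * released and uncharged here.
 */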
int tls_push_sg(struct sock *sk,
                struct tls_context *ctx,
                struct scatterlist *sg,
                u16 first_offset,
                int flags)
{
        int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST;
        int ret = 0;
        struct page *p;
        size_t size;
        int offset = first_offset;

        size = sg->length - offset;
        offset += sg->offset;

        ctx->in_tcp_sendpages = true;
        while (1) {
                if (sg_is_last(sg))
                        sendpage_flags = flags;

                /* is sending application-limited? */
                tcp_rate_check_app_limited(sk);
                p = sg_page(sg);
retry:
                ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags);

                if (ret != size) {
                        if (ret > 0) {
                                offset += ret;
                                size -= ret;
                                goto retry;
                        }

                        offset -= sg->offset;
                        ctx->partially_sent_offset = offset;
                        ctx->partially_sent_record = (void *)sg;
                        ctx->in_tcp_sendpages = false;
                        return ret;
                }

                put_page(p);
                sk_mem_uncharge(sk, sg->length);
                sg = sg_next(sg);
                if (!sg)
                        break;

                offset = sg->offset;
                size = sg->length;
        }

        ctx->in_tcp_sendpages = false;

        return 0;
}

static int tls_handle_open_record(struct sock *sk, int flags)
{
        struct tls_context *ctx = tls_get_ctx(sk);

        if (tls_is_pending_open_record(ctx))
                return ctx->push_pending_record(sk, flags);

        return 0;
}

int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
                     unsigned char *record_type)
{
        struct cmsghdr *cmsg;
        int rc = -EINVAL;

        for_each_cmsghdr(cmsg, msg) {
                if (!CMSG_OK(msg, cmsg))
                        return -EINVAL;
                if (cmsg->cmsg_level != SOL_TLS)
                        continue;

                switch (cmsg->cmsg_type) {
                case TLS_SET_RECORD_TYPE:
                        if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
                                return -EINVAL;

                        if (msg->msg_flags & MSG_MORE)
                                return -EINVAL;

                        rc = tls_handle_open_record(sk, msg->msg_flags);
                        if (rc)
                                return rc;

                        *record_type = *(unsigned char *)CMSG_DATA(cmsg);
                        rc = 0;
                        break;
                default:
                        return -EINVAL;
                }
        }

        return rc;
}

int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
                            int flags)
{
        struct scatterlist *sg;
        u16 offset;

        sg = ctx->partially_sent_record;
        offset = ctx->partially_sent_offset;

        ctx->partially_sent_record = NULL;
        return tls_push_sg(sk, ctx, sg, offset, flags);
}

void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
{
        struct scatterlist *sg;

        for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
                put_page(sg_page(sg));
                sk_mem_uncharge(sk, sg->length);
        }
        ctx->partially_sent_record = NULL;
}

static void tls_write_space(struct sock *sk)
{
        struct tls_context *ctx = tls_get_ctx(sk);

        /* If in_tcp_sendpages, call the lower protocol write space handler
         * to ensure we wake up any waiting operations there. For example
         * if do_tcp_sendpages were to call sk_wait_event.
         */
        if (ctx->in_tcp_sendpages) {
                ctx->sk_write_space(sk);
                return;
        }

#ifdef CONFIG_TLS_DEVICE
        if (ctx->tx_conf == TLS_HW)
                tls_device_write_space(sk, ctx);
        else
#endif
                tls_sw_write_space(sk, ctx);

        ctx->sk_write_space(sk);
}

/**
 * tls_ctx_free() - free TLS ULP context
 * @sk:  socket to which @ctx is attached
 * @ctx: TLS context structure
 *
 * Free TLS context. If @sk is %NULL caller guarantees that the socket
 * to which @ctx was attached has no outstanding references.
 */
void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
{
        if (!ctx)
                return;

        memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
        memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
        mutex_destroy(&ctx->tx_lock);

        if (sk)
                kfree_rcu(ctx, rcu);
        else
                kfree(ctx);
}

static void tls_sk_proto_cleanup(struct sock *sk,
                                 struct tls_context *ctx, long timeo)
{
        if (unlikely(sk->sk_write_pending) &&
            !wait_on_pending_writer(sk, &timeo))
                tls_handle_open_record(sk, 0);

        /* We need these for tls_sw_fallback handling of other packets */
        if (ctx->tx_conf == TLS_SW) {
                kfree(ctx->tx.rec_seq);
                kfree(ctx->tx.iv);
                tls_sw_release_resources_tx(sk);
                TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
        } else if (ctx->tx_conf == TLS_HW) {
                tls_device_free_resources_tx(sk);
                TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
        }

        if (ctx->rx_conf == TLS_SW) {
                tls_sw_release_resources_rx(sk);
                TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
        } else if (ctx->rx_conf == TLS_HW) {
                tls_device_offload_cleanup_rx(sk);
                TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
        }
}

static void tls_sk_proto_close(struct sock *sk, long timeout)
{
        struct inet_connection_sock *icsk = inet_csk(sk);
        struct tls_context *ctx = tls_get_ctx(sk);
        long timeo = sock_sndtimeo(sk, 0);
        bool free_ctx;

        if (ctx->tx_conf == TLS_SW)
                tls_sw_cancel_work_tx(ctx);

        lock_sock(sk);
        free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;

        if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
                tls_sk_proto_cleanup(sk, ctx, timeo);

        write_lock_bh(&sk->sk_callback_lock);
        if (free_ctx)
                rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
        WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
        if (sk->sk_write_space == tls_write_space)
                sk->sk_write_space = ctx->sk_write_space;
        write_unlock_bh(&sk->sk_callback_lock);
        release_sock(sk);
        if (ctx->tx_conf == TLS_SW)
                tls_sw_free_ctx_tx(ctx);
        if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
                tls_sw_strparser_done(ctx);
        if (ctx->rx_conf == TLS_SW)
                tls_sw_free_ctx_rx(ctx);
        ctx->sk_proto->close(sk, timeout);

        if (free_ctx)
                tls_ctx_free(sk, ctx);
}

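/* getsockopt(SOL_TLS, TLS_TX/TLS_RX): copy the installed crypto parameters
 * back to user space.  With len == sizeof(struct tls_crypto_info) only the
 * version and cipher_type are returned; a full cipher-specific struct
 * additionally receives the current IV and record sequence number.
 */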
static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval,
                                  int __user *optlen, int tx)
{
        int rc = 0;
        struct tls_context *ctx = tls_get_ctx(sk);
        struct tls_crypto_info *crypto_info;
        struct cipher_context *cctx;
        int len;

        if (get_user(len, optlen))
                return -EFAULT;

        if (!optval || (len < sizeof(*crypto_info))) {
                rc = -EINVAL;
                goto out;
        }

        if (!ctx) {
                rc = -EBUSY;
                goto out;
        }

        /* get user crypto info */
        if (tx) {
                crypto_info = &ctx->crypto_send.info;
                cctx = &ctx->tx;
        } else {
                crypto_info = &ctx->crypto_recv.info;
                cctx = &ctx->rx;
        }

        if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
                rc = -EBUSY;
                goto out;
        }

        if (len == sizeof(*crypto_info)) {
                if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
                        rc = -EFAULT;
                goto out;
        }

        switch (crypto_info->cipher_type) {
        case TLS_CIPHER_AES_GCM_128: {
                struct tls12_crypto_info_aes_gcm_128 *
                  crypto_info_aes_gcm_128 =
                  container_of(crypto_info,
                               struct tls12_crypto_info_aes_gcm_128,
                               info);

                if (len != sizeof(*crypto_info_aes_gcm_128)) {
                        rc = -EINVAL;
                        goto out;
                }
                lock_sock(sk);
                memcpy(crypto_info_aes_gcm_128->iv,
                       cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
                       TLS_CIPHER_AES_GCM_128_IV_SIZE);
                memcpy(crypto_info_aes_gcm_128->rec_seq, cctx->rec_seq,
                       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
                release_sock(sk);
                if (copy_to_user(optval,
                                 crypto_info_aes_gcm_128,
                                 sizeof(*crypto_info_aes_gcm_128)))
                        rc = -EFAULT;
                break;
        }
        case TLS_CIPHER_AES_GCM_256: {
                struct tls12_crypto_info_aes_gcm_256 *
                  crypto_info_aes_gcm_256 =
                  container_of(crypto_info,
                               struct tls12_crypto_info_aes_gcm_256,
                               info);

                if (len != sizeof(*crypto_info_aes_gcm_256)) {
                        rc = -EINVAL;
                        goto out;
                }
                lock_sock(sk);
                memcpy(crypto_info_aes_gcm_256->iv,
                       cctx->iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
                       TLS_CIPHER_AES_GCM_256_IV_SIZE);
                memcpy(crypto_info_aes_gcm_256->rec_seq, cctx->rec_seq,
                       TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
                release_sock(sk);
                if (copy_to_user(optval,
                                 crypto_info_aes_gcm_256,
                                 sizeof(*crypto_info_aes_gcm_256)))
                        rc = -EFAULT;
                break;
        }
        case TLS_CIPHER_AES_CCM_128: {
                struct tls12_crypto_info_aes_ccm_128 *aes_ccm_128 =
                        container_of(crypto_info,
                                     struct tls12_crypto_info_aes_ccm_128, info);

                if (len != sizeof(*aes_ccm_128)) {
                        rc = -EINVAL;
                        goto out;
                }
                lock_sock(sk);
                memcpy(aes_ccm_128->iv,
                       cctx->iv + TLS_CIPHER_AES_CCM_128_SALT_SIZE,
                       TLS_CIPHER_AES_CCM_128_IV_SIZE);
                memcpy(aes_ccm_128->rec_seq, cctx->rec_seq,
                       TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE);
                release_sock(sk);
                if (copy_to_user(optval, aes_ccm_128, sizeof(*aes_ccm_128)))
                        rc = -EFAULT;
                break;
        }
        case TLS_CIPHER_CHACHA20_POLY1305: {
                struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305 =
                        container_of(crypto_info,
                                     struct tls12_crypto_info_chacha20_poly1305,
                                     info);

                if (len != sizeof(*chacha20_poly1305)) {
                        rc = -EINVAL;
                        goto out;
                }
                lock_sock(sk);
                memcpy(chacha20_poly1305->iv,
                       cctx->iv + TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE,
                       TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE);
                memcpy(chacha20_poly1305->rec_seq, cctx->rec_seq,
                       TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE);
                release_sock(sk);
                if (copy_to_user(optval, chacha20_poly1305,
                                 sizeof(*chacha20_poly1305)))
                        rc = -EFAULT;
                break;
        }
        case TLS_CIPHER_SM4_GCM: {
                struct tls12_crypto_info_sm4_gcm *sm4_gcm_info =
                        container_of(crypto_info,
                                     struct tls12_crypto_info_sm4_gcm, info);

                if (len != sizeof(*sm4_gcm_info)) {
                        rc = -EINVAL;
                        goto out;
                }
                lock_sock(sk);
                memcpy(sm4_gcm_info->iv,
                       cctx->iv + TLS_CIPHER_SM4_GCM_SALT_SIZE,
                       TLS_CIPHER_SM4_GCM_IV_SIZE);
                memcpy(sm4_gcm_info->rec_seq, cctx->rec_seq,
                       TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE);
                release_sock(sk);
                if (copy_to_user(optval, sm4_gcm_info, sizeof(*sm4_gcm_info)))
                        rc = -EFAULT;
                break;
        }
        case TLS_CIPHER_SM4_CCM: {
                struct tls12_crypto_info_sm4_ccm *sm4_ccm_info =
                        container_of(crypto_info,
                                     struct tls12_crypto_info_sm4_ccm, info);

                if (len != sizeof(*sm4_ccm_info)) {
                        rc = -EINVAL;
                        goto out;
                }
                lock_sock(sk);
                memcpy(sm4_ccm_info->iv,
                       cctx->iv + TLS_CIPHER_SM4_CCM_SALT_SIZE,
                       TLS_CIPHER_SM4_CCM_IV_SIZE);
                memcpy(sm4_ccm_info->rec_seq, cctx->rec_seq,
                       TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE);
                release_sock(sk);
                if (copy_to_user(optval, sm4_ccm_info, sizeof(*sm4_ccm_info)))
                        rc = -EFAULT;
                break;
        }
        default:
                rc = -EINVAL;
        }

out:
        return rc;
}

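/* getsockopt(SOL_TLS, TLS_TX_ZEROCOPY_RO): report whether zerocopy
 * sendfile is enabled on this socket.
 */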
static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval,
                                   int __user *optlen)
{
        struct tls_context *ctx = tls_get_ctx(sk);
        unsigned int value;
        int len;

        if (get_user(len, optlen))
                return -EFAULT;

        if (len != sizeof(value))
                return -EINVAL;

        value = ctx->zerocopy_sendfile;
        if (copy_to_user(optval, &value, sizeof(value)))
                return -EFAULT;

        return 0;
}

static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval,
                                    int __user *optlen)
{
        struct tls_context *ctx = tls_get_ctx(sk);
        int value, len;

        if (ctx->prot_info.version != TLS_1_3_VERSION)
                return -EINVAL;

        if (get_user(len, optlen))
                return -EFAULT;
        if (len < sizeof(value))
                return -EINVAL;

        lock_sock(sk);
        value = -EINVAL;
        if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
                value = ctx->rx_no_pad;
        release_sock(sk);
        if (value < 0)
                return value;

        if (put_user(sizeof(value), optlen))
                return -EFAULT;
        if (copy_to_user(optval, &value, sizeof(value)))
                return -EFAULT;

        return 0;
}

static int do_tls_getsockopt(struct sock *sk, int optname,
                             char __user *optval, int __user *optlen)
{
        int rc = 0;

        switch (optname) {
        case TLS_TX:
        case TLS_RX:
                rc = do_tls_getsockopt_conf(sk, optval, optlen,
                                            optname == TLS_TX);
                break;
        case TLS_TX_ZEROCOPY_RO:
                rc = do_tls_getsockopt_tx_zc(sk, optval, optlen);
                break;
        case TLS_RX_EXPECT_NO_PAD:
                rc = do_tls_getsockopt_no_pad(sk, optval, optlen);
                break;
        default:
                rc = -ENOPROTOOPT;
                break;
        }
        return rc;
}

static int tls_getsockopt(struct sock *sk, int level, int optname,
                          char __user *optval, int __user *optlen)
{
        struct tls_context *ctx = tls_get_ctx(sk);

        if (level != SOL_TLS)
                return ctx->sk_proto->getsockopt(sk, level,
                                                 optname, optval, optlen);

        return do_tls_getsockopt(sk, optname, optval, optlen);
}

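/* setsockopt(SOL_TLS, TLS_TX/TLS_RX): install key material for one
 * direction.  Typical user-space sequence, sketched after
 * Documentation/networking/tls.rst:
 *
 *	struct tls12_crypto_info_aes_gcm_128 ci = {
 *		.info.version = TLS_1_2_VERSION,
 *		.info.cipher_type = TLS_CIPHER_AES_GCM_128,
 *	};
 *	// fill ci.key, ci.salt, ci.iv, ci.rec_seq from the handshake
 *	setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls"));
 *	setsockopt(fd, SOL_TLS, TLS_TX, &ci, sizeof(ci));
 *
 * Device offload is attempted first and silently falls back to the
 * software implementation if the NIC cannot take the connection.
 */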
static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
                                  unsigned int optlen, int tx)
{
        struct tls_crypto_info *crypto_info;
        struct tls_crypto_info *alt_crypto_info;
        struct tls_context *ctx = tls_get_ctx(sk);
        size_t optsize;
        int rc = 0;
        int conf;

        if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info)))
                return -EINVAL;

        if (tx) {
                crypto_info = &ctx->crypto_send.info;
                alt_crypto_info = &ctx->crypto_recv.info;
        } else {
                crypto_info = &ctx->crypto_recv.info;
                alt_crypto_info = &ctx->crypto_send.info;
        }

        /* Currently we don't support setting crypto info more than once */
        if (TLS_CRYPTO_INFO_READY(crypto_info))
                return -EBUSY;

        rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
        if (rc) {
                rc = -EFAULT;
                goto err_crypto_info;
        }

        /* check version */
        if (crypto_info->version != TLS_1_2_VERSION &&
            crypto_info->version != TLS_1_3_VERSION) {
                rc = -EINVAL;
                goto err_crypto_info;
        }

        /* Ensure that TLS version and ciphers are same in both directions */
        if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
                if (alt_crypto_info->version != crypto_info->version ||
                    alt_crypto_info->cipher_type != crypto_info->cipher_type) {
                        rc = -EINVAL;
                        goto err_crypto_info;
                }
        }

        switch (crypto_info->cipher_type) {
        case TLS_CIPHER_AES_GCM_128:
                optsize = sizeof(struct tls12_crypto_info_aes_gcm_128);
                break;
        case TLS_CIPHER_AES_GCM_256:
                optsize = sizeof(struct tls12_crypto_info_aes_gcm_256);
                break;
        case TLS_CIPHER_AES_CCM_128:
                optsize = sizeof(struct tls12_crypto_info_aes_ccm_128);
                break;
        case TLS_CIPHER_CHACHA20_POLY1305:
                optsize = sizeof(struct tls12_crypto_info_chacha20_poly1305);
                break;
        case TLS_CIPHER_SM4_GCM:
                optsize = sizeof(struct tls12_crypto_info_sm4_gcm);
                break;
        case TLS_CIPHER_SM4_CCM:
                optsize = sizeof(struct tls12_crypto_info_sm4_ccm);
                break;
        default:
                rc = -EINVAL;
                goto err_crypto_info;
        }

        if (optlen != optsize) {
                rc = -EINVAL;
                goto err_crypto_info;
        }

        rc = copy_from_sockptr_offset(crypto_info + 1, optval,
                                      sizeof(*crypto_info),
                                      optlen - sizeof(*crypto_info));
        if (rc) {
                rc = -EFAULT;
                goto err_crypto_info;
        }

        if (tx) {
                rc = tls_set_device_offload(sk, ctx);
                conf = TLS_HW;
                if (!rc) {
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
                } else {
                        rc = tls_set_sw_offload(sk, ctx, 1);
                        if (rc)
                                goto err_crypto_info;
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
                        conf = TLS_SW;
                }
        } else {
                rc = tls_set_device_offload_rx(sk, ctx);
                conf = TLS_HW;
                if (!rc) {
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
                } else {
                        rc = tls_set_sw_offload(sk, ctx, 0);
                        if (rc)
                                goto err_crypto_info;
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
                        TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
                        conf = TLS_SW;
                }
                tls_sw_strparser_arm(sk, ctx);
        }

        if (tx)
                ctx->tx_conf = conf;
        else
                ctx->rx_conf = conf;
        update_sk_prot(sk, ctx);
        if (tx) {
                ctx->sk_write_space = sk->sk_write_space;
                sk->sk_write_space = tls_write_space;
        } else {
                struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx);

                tls_strp_check_rcv(&rx_ctx->strp);
        }
        return 0;

err_crypto_info:
        memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
        return rc;
}

static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval,
                                   unsigned int optlen)
{
        struct tls_context *ctx = tls_get_ctx(sk);
        unsigned int value;

        if (sockptr_is_null(optval) || optlen != sizeof(value))
                return -EINVAL;

        if (copy_from_sockptr(&value, optval, sizeof(value)))
                return -EFAULT;

        if (value > 1)
                return -EINVAL;

        ctx->zerocopy_sendfile = value;

        return 0;
}

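/* setsockopt(SOL_TLS, TLS_RX_EXPECT_NO_PAD): TLS 1.3 only.  The user
 * asserts that the peer sends unpadded records, which lets the RX path
 * decrypt directly into user-space buffers (see tls_update_rx_zc_capable()).
 */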
static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval,
                                    unsigned int optlen)
{
        struct tls_context *ctx = tls_get_ctx(sk);
        u32 val;
        int rc;

        if (ctx->prot_info.version != TLS_1_3_VERSION ||
            sockptr_is_null(optval) || optlen < sizeof(val))
                return -EINVAL;

        rc = copy_from_sockptr(&val, optval, sizeof(val));
        if (rc)
                return -EFAULT;
        if (val > 1)
                return -EINVAL;
        rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val));
        if (rc < 1)
                return rc == 0 ? -EINVAL : rc;

        lock_sock(sk);
        rc = -EINVAL;
        if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) {
                ctx->rx_no_pad = val;
                tls_update_rx_zc_capable(ctx);
                rc = 0;
        }
        release_sock(sk);

        return rc;
}

static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
                             unsigned int optlen)
{
        int rc = 0;

        switch (optname) {
        case TLS_TX:
        case TLS_RX:
                lock_sock(sk);
                rc = do_tls_setsockopt_conf(sk, optval, optlen,
                                            optname == TLS_TX);
                release_sock(sk);
                break;
        case TLS_TX_ZEROCOPY_RO:
                lock_sock(sk);
                rc = do_tls_setsockopt_tx_zc(sk, optval, optlen);
                release_sock(sk);
                break;
        case TLS_RX_EXPECT_NO_PAD:
                rc = do_tls_setsockopt_no_pad(sk, optval, optlen);
                break;
        default:
                rc = -ENOPROTOOPT;
                break;
        }
        return rc;
}

static int tls_setsockopt(struct sock *sk, int level, int optname,
                          sockptr_t optval, unsigned int optlen)
{
        struct tls_context *ctx = tls_get_ctx(sk);

        if (level != SOL_TLS)
                return ctx->sk_proto->setsockopt(sk, level, optname, optval,
                                                 optlen);

        return do_tls_setsockopt(sk, optname, optval, optlen);
}

struct tls_context *tls_ctx_create(struct sock *sk)
{
        struct inet_connection_sock *icsk = inet_csk(sk);
        struct tls_context *ctx;

        ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
        if (!ctx)
                return NULL;

        mutex_init(&ctx->tx_lock);
        rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
        ctx->sk_proto = READ_ONCE(sk->sk_prot);
        ctx->sk = sk;
        return ctx;
}

static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
                            const struct proto_ops *base)
{
        ops[TLS_BASE][TLS_BASE] = *base;

        ops[TLS_SW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
        ops[TLS_SW  ][TLS_BASE].sendpage_locked = tls_sw_sendpage_locked;

        ops[TLS_BASE][TLS_SW  ] = ops[TLS_BASE][TLS_BASE];
        ops[TLS_BASE][TLS_SW  ].splice_read     = tls_sw_splice_read;

        ops[TLS_SW  ][TLS_SW  ] = ops[TLS_SW  ][TLS_BASE];
        ops[TLS_SW  ][TLS_SW  ].splice_read     = tls_sw_splice_read;

#ifdef CONFIG_TLS_DEVICE
        ops[TLS_HW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
        ops[TLS_HW  ][TLS_BASE].sendpage_locked = NULL;

        ops[TLS_HW  ][TLS_SW  ] = ops[TLS_BASE][TLS_SW  ];
        ops[TLS_HW  ][TLS_SW  ].sendpage_locked = NULL;

        ops[TLS_BASE][TLS_HW  ] = ops[TLS_BASE][TLS_SW  ];

        ops[TLS_SW  ][TLS_HW  ] = ops[TLS_SW  ][TLS_SW  ];

        ops[TLS_HW  ][TLS_HW  ] = ops[TLS_HW  ][TLS_SW  ];
        ops[TLS_HW  ][TLS_HW  ].sendpage_locked = NULL;
#endif
#ifdef CONFIG_TLS_TOE
        ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
#endif
}

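/* Lazily (re)build the per-family proto and proto_ops tables.  They are
 * regenerated whenever the base TCP proto pointer observed on the socket
 * differs from the one the tables were last built against.
 */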
static void tls_build_proto(struct sock *sk)
{
        int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
        struct proto *prot = READ_ONCE(sk->sk_prot);

        /* Build IPv6 TLS whenever the address of tcpv6_prot changes */
        if (ip_ver == TLSV6 &&
            unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
                mutex_lock(&tcpv6_prot_mutex);
                if (likely(prot != saved_tcpv6_prot)) {
                        build_protos(tls_prots[TLSV6], prot);
                        build_proto_ops(tls_proto_ops[TLSV6],
                                        sk->sk_socket->ops);
                        smp_store_release(&saved_tcpv6_prot, prot);
                }
                mutex_unlock(&tcpv6_prot_mutex);
        }

        if (ip_ver == TLSV4 &&
            unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
                mutex_lock(&tcpv4_prot_mutex);
                if (likely(prot != saved_tcpv4_prot)) {
                        build_protos(tls_prots[TLSV4], prot);
                        build_proto_ops(tls_proto_ops[TLSV4],
                                        sk->sk_socket->ops);
                        smp_store_release(&saved_tcpv4_prot, prot);
                }
                mutex_unlock(&tcpv4_prot_mutex);
        }
}

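/* The proto table is indexed as prot[tx_conf][rx_conf].  TLS_BASE entries
 * only hook setsockopt/getsockopt/close; SW and HW configurations also
 * override the data-path entry points (sendmsg/sendpage/recvmsg).
 */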
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
                         const struct proto *base)
{
        prot[TLS_BASE][TLS_BASE] = *base;
        prot[TLS_BASE][TLS_BASE].setsockopt     = tls_setsockopt;
        prot[TLS_BASE][TLS_BASE].getsockopt     = tls_getsockopt;
        prot[TLS_BASE][TLS_BASE].close          = tls_sk_proto_close;

        prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
        prot[TLS_SW][TLS_BASE].sendmsg          = tls_sw_sendmsg;
        prot[TLS_SW][TLS_BASE].sendpage         = tls_sw_sendpage;

        prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
        prot[TLS_BASE][TLS_SW].recvmsg          = tls_sw_recvmsg;
        prot[TLS_BASE][TLS_SW].sock_is_readable = tls_sw_sock_is_readable;
        prot[TLS_BASE][TLS_SW].close            = tls_sk_proto_close;

        prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
        prot[TLS_SW][TLS_SW].recvmsg            = tls_sw_recvmsg;
        prot[TLS_SW][TLS_SW].sock_is_readable   = tls_sw_sock_is_readable;
        prot[TLS_SW][TLS_SW].close              = tls_sk_proto_close;

#ifdef CONFIG_TLS_DEVICE
        prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
        prot[TLS_HW][TLS_BASE].sendmsg          = tls_device_sendmsg;
        prot[TLS_HW][TLS_BASE].sendpage         = tls_device_sendpage;

        prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
        prot[TLS_HW][TLS_SW].sendmsg            = tls_device_sendmsg;
        prot[TLS_HW][TLS_SW].sendpage           = tls_device_sendpage;

        prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];

        prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];

        prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
#endif
#ifdef CONFIG_TLS_TOE
        prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
        prot[TLS_HW_RECORD][TLS_HW_RECORD].hash         = tls_toe_hash;
        prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash       = tls_toe_unhash;
#endif
}

static int tls_init(struct sock *sk)
{
        struct tls_context *ctx;
        int rc = 0;

        tls_build_proto(sk);

#ifdef CONFIG_TLS_TOE
        if (tls_toe_bypass(sk))
                return 0;
#endif

        /* The TLS ulp is currently supported only for TCP sockets
         * in ESTABLISHED state.
         * Supporting sockets in LISTEN state will require us
         * to modify the accept implementation to clone rather than
         * share the ulp context.
         */
        if (sk->sk_state != TCP_ESTABLISHED)
                return -ENOTCONN;

        /* allocate tls context */
        write_lock_bh(&sk->sk_callback_lock);
        ctx = tls_ctx_create(sk);
        if (!ctx) {
                rc = -ENOMEM;
                goto out;
        }

        ctx->tx_conf = TLS_BASE;
        ctx->rx_conf = TLS_BASE;
        update_sk_prot(sk, ctx);
out:
        write_unlock_bh(&sk->sk_callback_lock);
        return rc;
}

static void tls_update(struct sock *sk, struct proto *p,
                       void (*write_space)(struct sock *sk))
{
        struct tls_context *ctx;

        WARN_ON_ONCE(sk->sk_prot == p);

        ctx = tls_get_ctx(sk);
        if (likely(ctx)) {
                ctx->sk_write_space = write_space;
                ctx->sk_proto = p;
        } else {
                /* Pairs with lockless read in sk_clone_lock(). */
                WRITE_ONCE(sk->sk_prot, p);
                sk->sk_write_space = write_space;
        }
}

static u16 tls_user_config(struct tls_context *ctx, bool tx)
{
        u16 config = tx ? ctx->tx_conf : ctx->rx_conf;

        switch (config) {
        case TLS_BASE:
                return TLS_CONF_BASE;
        case TLS_SW:
                return TLS_CONF_SW;
        case TLS_HW:
                return TLS_CONF_HW;
        case TLS_HW_RECORD:
                return TLS_CONF_HW_RECORD;
        }
        return 0;
}

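/* Fill the INET_ULP_INFO_TLS nest for inet_diag, exposing the TLS version,
 * cipher and per-direction configuration to socket-diagnostic tooling.
 */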
static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
{
        u16 version, cipher_type;
        struct tls_context *ctx;
        struct nlattr *start;
        int err;

        start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
        if (!start)
                return -EMSGSIZE;

        rcu_read_lock();
        ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
        if (!ctx) {
                err = 0;
                goto nla_failure;
        }
        version = ctx->prot_info.version;
        if (version) {
                err = nla_put_u16(skb, TLS_INFO_VERSION, version);
                if (err)
                        goto nla_failure;
        }
        cipher_type = ctx->prot_info.cipher_type;
        if (cipher_type) {
                err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
                if (err)
                        goto nla_failure;
        }
        err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
        if (err)
                goto nla_failure;

        err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
        if (err)
                goto nla_failure;

        if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) {
                err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX);
                if (err)
                        goto nla_failure;
        }
        if (ctx->rx_no_pad) {
                err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD);
                if (err)
                        goto nla_failure;
        }

        rcu_read_unlock();
        nla_nest_end(skb, start);
        return 0;

nla_failure:
        rcu_read_unlock();
        nla_nest_cancel(skb, start);
        return err;
}

static size_t tls_get_info_size(const struct sock *sk)
{
        size_t size = 0;

        size += nla_total_size(0) +             /* INET_ULP_INFO_TLS */
                nla_total_size(sizeof(u16)) +   /* TLS_INFO_VERSION */
                nla_total_size(sizeof(u16)) +   /* TLS_INFO_CIPHER */
                nla_total_size(sizeof(u16)) +   /* TLS_INFO_RXCONF */
                nla_total_size(sizeof(u16)) +   /* TLS_INFO_TXCONF */
                nla_total_size(0) +             /* TLS_INFO_ZC_RO_TX */
                nla_total_size(0) +             /* TLS_INFO_RX_NO_PAD */
                0;

        return size;
}

static int __net_init tls_init_net(struct net *net)
{
        int err;

        net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib);
        if (!net->mib.tls_statistics)
                return -ENOMEM;

        err = tls_proc_init(net);
        if (err)
                goto err_free_stats;

        return 0;
err_free_stats:
        free_percpu(net->mib.tls_statistics);
        return err;
}

static void __net_exit tls_exit_net(struct net *net)
{
        tls_proc_fini(net);
        free_percpu(net->mib.tls_statistics);
}

static struct pernet_operations tls_proc_ops = {
        .init = tls_init_net,
        .exit = tls_exit_net,
};

static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
        .name                   = "tls",
        .owner                  = THIS_MODULE,
        .init                   = tls_init,
        .update                 = tls_update,
        .get_info               = tls_get_info,
        .get_info_size          = tls_get_info_size,
};

static int __init tls_register(void)
{
        int err;

        err = register_pernet_subsys(&tls_proc_ops);
        if (err)
                return err;

        err = tls_strp_dev_init();
        if (err)
                goto err_pernet;

        err = tls_device_init();
        if (err)
                goto err_strp;

        tcp_register_ulp(&tcp_tls_ulp_ops);

        return 0;
err_strp:
        tls_strp_dev_exit();
err_pernet:
        unregister_pernet_subsys(&tls_proc_ops);
        return err;
}

static void __exit tls_unregister(void)
{
        tcp_unregister_ulp(&tcp_tls_ulp_ops);
        tls_strp_dev_exit();
        tls_device_cleanup();
        unregister_pernet_subsys(&tls_proc_ops);
}

module_init(tls_register);
module_exit(tls_unregister);