/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses. You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/module.h>

#include <net/tcp.h>
#include <net/inet_common.h>
#include <linux/highmem.h>
#include <linux/netdevice.h>
#include <linux/sched/signal.h>
#include <linux/inetdevice.h>
#include <linux/inet_diag.h>

#include <net/snmp.h>
#include <net/tls.h>
#include <net/tls_toe.h>

#include "tls.h"

MODULE_AUTHOR("Mellanox Technologies");
MODULE_DESCRIPTION("Transport Layer Security Support");
MODULE_LICENSE("Dual BSD/GPL");
MODULE_ALIAS_TCP_ULP("tls");

enum {
	TLSV4,
	TLSV6,
	TLS_NUM_PROTS,
};

#define CIPHER_SIZE_DESC(cipher) [cipher] = { \
	.iv = cipher ## _IV_SIZE, \
	.key = cipher ## _KEY_SIZE, \
	.salt = cipher ## _SALT_SIZE, \
	.tag = cipher ## _TAG_SIZE, \
	.rec_seq = cipher ## _REC_SEQ_SIZE, \
}

const struct tls_cipher_size_desc tls_cipher_size_desc[] = {
	CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_128),
	CIPHER_SIZE_DESC(TLS_CIPHER_AES_GCM_256),
	CIPHER_SIZE_DESC(TLS_CIPHER_AES_CCM_128),
	CIPHER_SIZE_DESC(TLS_CIPHER_CHACHA20_POLY1305),
	CIPHER_SIZE_DESC(TLS_CIPHER_SM4_GCM),
	CIPHER_SIZE_DESC(TLS_CIPHER_SM4_CCM),
};

static const struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static const struct proto *saved_tcpv4_prot;
static DEFINE_MUTEX(tcpv4_prot_mutex);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static struct proto_ops tls_proto_ops[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG];
static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 const struct proto *base);

void update_sk_prot(struct sock *sk, struct tls_context *ctx)
{
	int ip_ver = sk->sk_family == AF_INET6 ?
		     TLSV6 : TLSV4;

	WRITE_ONCE(sk->sk_prot,
		   &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]);
	WRITE_ONCE(sk->sk_socket->ops,
		   &tls_proto_ops[ip_ver][ctx->tx_conf][ctx->rx_conf]);
}

int wait_on_pending_writer(struct sock *sk, long *timeo)
{
	int rc = 0;
	DEFINE_WAIT_FUNC(wait, woken_wake_function);

	add_wait_queue(sk_sleep(sk), &wait);
	while (1) {
		if (!*timeo) {
			rc = -EAGAIN;
			break;
		}

		if (signal_pending(current)) {
			rc = sock_intr_errno(*timeo);
			break;
		}

		if (sk_wait_event(sk, timeo,
				  !READ_ONCE(sk->sk_write_pending), &wait))
			break;
	}
	remove_wait_queue(sk_sleep(sk), &wait);
	return rc;
}

int tls_push_sg(struct sock *sk,
		struct tls_context *ctx,
		struct scatterlist *sg,
		u16 first_offset,
		int flags)
{
	struct bio_vec bvec;
	struct msghdr msg = {
		.msg_flags = MSG_SPLICE_PAGES | flags,
	};
	int ret = 0;
	struct page *p;
	size_t size;
	int offset = first_offset;

	size = sg->length - offset;
	offset += sg->offset;

	ctx->splicing_pages = true;
	while (1) {
		/* is sending application-limited? */
		tcp_rate_check_app_limited(sk);
		p = sg_page(sg);
retry:
		bvec_set_page(&bvec, p, size, offset);
		iov_iter_bvec(&msg.msg_iter, ITER_SOURCE, &bvec, 1, size);

		ret = tcp_sendmsg_locked(sk, &msg, size);

		if (ret != size) {
			if (ret > 0) {
				offset += ret;
				size -= ret;
				goto retry;
			}

			offset -= sg->offset;
			ctx->partially_sent_offset = offset;
			ctx->partially_sent_record = (void *)sg;
			ctx->splicing_pages = false;
			return ret;
		}

		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
		if (!sg)
			break;

		offset = sg->offset;
		size = sg->length;
	}

	ctx->splicing_pages = false;

	return 0;
}

static int tls_handle_open_record(struct sock *sk, int flags)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (tls_is_pending_open_record(ctx))
		return ctx->push_pending_record(sk, flags);

	return 0;
}

int tls_process_cmsg(struct sock *sk, struct msghdr *msg,
		     unsigned char *record_type)
{
	struct cmsghdr *cmsg;
	int rc = -EINVAL;

	for_each_cmsghdr(cmsg, msg) {
		if (!CMSG_OK(msg, cmsg))
			return -EINVAL;
		if (cmsg->cmsg_level != SOL_TLS)
			continue;

		switch (cmsg->cmsg_type) {
		case TLS_SET_RECORD_TYPE:
			if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
				return -EINVAL;

			if (msg->msg_flags & MSG_MORE)
				return -EINVAL;

			rc = tls_handle_open_record(sk, msg->msg_flags);
			if (rc)
				return rc;

			*record_type = *(unsigned char *)CMSG_DATA(cmsg);
			rc = 0;
			break;
		default:
			return -EINVAL;
		}
	}

	return rc;
}

int tls_push_partial_record(struct sock *sk, struct tls_context *ctx,
			    int flags)
{
	struct scatterlist *sg;
	u16 offset;

	sg = ctx->partially_sent_record;
	offset = ctx->partially_sent_offset;

	ctx->partially_sent_record = NULL;
	return tls_push_sg(sk, ctx, sg, offset, flags);
}

void tls_free_partial_record(struct sock *sk, struct tls_context *ctx)
{
	struct scatterlist *sg;

	for (sg = ctx->partially_sent_record; sg; sg = sg_next(sg)) {
		put_page(sg_page(sg));
		sk_mem_uncharge(sk, sg->length);
	}
	ctx->partially_sent_record = NULL;
}

static void tls_write_space(struct sock *sk)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	/* If splicing_pages, call the lower protocol write space handler
	 * to ensure we wake up any waiting operations there. For example,
	 * if splicing pages were to call sk_wait_event.
	 */
	if (ctx->splicing_pages) {
		ctx->sk_write_space(sk);
		return;
	}

#ifdef CONFIG_TLS_DEVICE
	if (ctx->tx_conf == TLS_HW)
		tls_device_write_space(sk, ctx);
	else
#endif
		tls_sw_write_space(sk, ctx);

	ctx->sk_write_space(sk);
}

/**
 * tls_ctx_free() - free TLS ULP context
 * @sk:  socket to which @ctx is attached
 * @ctx: TLS context structure
 *
 * Free TLS context. If @sk is %NULL caller guarantees that the socket
 * to which @ctx was attached has no outstanding references.
 */
void tls_ctx_free(struct sock *sk, struct tls_context *ctx)
{
	if (!ctx)
		return;

	memzero_explicit(&ctx->crypto_send, sizeof(ctx->crypto_send));
	memzero_explicit(&ctx->crypto_recv, sizeof(ctx->crypto_recv));
	mutex_destroy(&ctx->tx_lock);

	if (sk)
		kfree_rcu(ctx, rcu);
	else
		kfree(ctx);
}

static void tls_sk_proto_cleanup(struct sock *sk,
				 struct tls_context *ctx, long timeo)
{
	if (unlikely(sk->sk_write_pending) &&
	    !wait_on_pending_writer(sk, &timeo))
		tls_handle_open_record(sk, 0);

	/* We need these for tls_sw_fallback handling of other packets */
	if (ctx->tx_conf == TLS_SW) {
		kfree(ctx->tx.rec_seq);
		kfree(ctx->tx.iv);
		tls_sw_release_resources_tx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
	} else if (ctx->tx_conf == TLS_HW) {
		tls_device_free_resources_tx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
	}

	if (ctx->rx_conf == TLS_SW) {
		tls_sw_release_resources_rx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
	} else if (ctx->rx_conf == TLS_HW) {
		tls_device_offload_cleanup_rx(sk);
		TLS_DEC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
	}
}

static void tls_sk_proto_close(struct sock *sk, long timeout)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx = tls_get_ctx(sk);
	long timeo = sock_sndtimeo(sk, 0);
	bool free_ctx;

	if (ctx->tx_conf == TLS_SW)
		tls_sw_cancel_work_tx(ctx);

	lock_sock(sk);
	free_ctx = ctx->tx_conf != TLS_HW && ctx->rx_conf != TLS_HW;

	if (ctx->tx_conf != TLS_BASE || ctx->rx_conf != TLS_BASE)
		tls_sk_proto_cleanup(sk, ctx, timeo);

	write_lock_bh(&sk->sk_callback_lock);
	if (free_ctx)
		rcu_assign_pointer(icsk->icsk_ulp_data, NULL);
	WRITE_ONCE(sk->sk_prot, ctx->sk_proto);
	if (sk->sk_write_space == tls_write_space)
		sk->sk_write_space = ctx->sk_write_space;
	write_unlock_bh(&sk->sk_callback_lock);
	release_sock(sk);
	if (ctx->tx_conf == TLS_SW)
		tls_sw_free_ctx_tx(ctx);
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
		tls_sw_strparser_done(ctx);
	if (ctx->rx_conf == TLS_SW)
		tls_sw_free_ctx_rx(ctx);
	ctx->sk_proto->close(sk, timeout);

	if (free_ctx)
		tls_ctx_free(sk, ctx);
}

static __poll_t tls_sk_poll(struct file *file, struct socket *sock,
			    struct poll_table_struct *wait)
{
	struct tls_sw_context_rx *ctx;
	struct tls_context *tls_ctx;
	struct sock *sk = sock->sk;
	struct sk_psock *psock;
	__poll_t mask = 0;
	u8 shutdown;
	int state;

	mask = tcp_poll(file, sock, wait);

	state = inet_sk_state_load(sk);
	shutdown = READ_ONCE(sk->sk_shutdown);
	if (unlikely(state != TCP_ESTABLISHED || shutdown & RCV_SHUTDOWN))
		return mask;

	tls_ctx = tls_get_ctx(sk);
	ctx = tls_sw_ctx_rx(tls_ctx);
	psock = sk_psock_get(sk);

	if (skb_queue_empty_lockless(&ctx->rx_list) &&
	    !tls_strp_msg_ready(ctx) &&
	    sk_psock_queue_empty(psock))
		mask &= ~(EPOLLIN | EPOLLRDNORM);

	if (psock)
		sk_psock_put(sk, psock);

	return mask;
}

static int do_tls_getsockopt_conf(struct sock *sk, char __user *optval,
				  int __user *optlen, int tx)
{
	int rc = 0;
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_crypto_info *crypto_info;
	struct cipher_context *cctx;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;

	if (!optval || (len < sizeof(*crypto_info))) {
		rc = -EINVAL;
		goto out;
	}

	if (!ctx) {
		rc = -EBUSY;
		goto out;
	}

	/* get user crypto info */
	if (tx) {
		crypto_info = &ctx->crypto_send.info;
		cctx = &ctx->tx;
	} else {
		crypto_info = &ctx->crypto_recv.info;
		cctx = &ctx->rx;
	}

	if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
		rc = -EBUSY;
		goto out;
	}

	if (len == sizeof(*crypto_info)) {
		if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
			rc = -EFAULT;
		goto out;
	}

	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128: {
		struct tls12_crypto_info_aes_gcm_128 *
		  crypto_info_aes_gcm_128 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aes_gcm_128,
			       info);

		if (len != sizeof(*crypto_info_aes_gcm_128)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(crypto_info_aes_gcm_128->iv,
		       cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
		       TLS_CIPHER_AES_GCM_128_IV_SIZE);
		memcpy(crypto_info_aes_gcm_128->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
		if (copy_to_user(optval,
				 crypto_info_aes_gcm_128,
				 sizeof(*crypto_info_aes_gcm_128)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_AES_GCM_256: {
		struct tls12_crypto_info_aes_gcm_256 *
		  crypto_info_aes_gcm_256 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aes_gcm_256,
			       info);

		if (len != sizeof(*crypto_info_aes_gcm_256)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(crypto_info_aes_gcm_256->iv,
		       cctx->iv + TLS_CIPHER_AES_GCM_256_SALT_SIZE,
		       TLS_CIPHER_AES_GCM_256_IV_SIZE);
		memcpy(crypto_info_aes_gcm_256->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE);
		if (copy_to_user(optval,
				 crypto_info_aes_gcm_256,
				 sizeof(*crypto_info_aes_gcm_256)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_AES_CCM_128: {
		struct tls12_crypto_info_aes_ccm_128 *aes_ccm_128 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aes_ccm_128, info);

		if (len != sizeof(*aes_ccm_128)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(aes_ccm_128->iv,
		       cctx->iv + TLS_CIPHER_AES_CCM_128_SALT_SIZE,
		       TLS_CIPHER_AES_CCM_128_IV_SIZE);
		memcpy(aes_ccm_128->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE);
		if (copy_to_user(optval, aes_ccm_128, sizeof(*aes_ccm_128)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_CHACHA20_POLY1305: {
		struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305 =
			container_of(crypto_info,
				     struct tls12_crypto_info_chacha20_poly1305,
				     info);

		if (len != sizeof(*chacha20_poly1305)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(chacha20_poly1305->iv,
		       cctx->iv + TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE,
		       TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE);
		memcpy(chacha20_poly1305->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE);
		if (copy_to_user(optval, chacha20_poly1305,
				 sizeof(*chacha20_poly1305)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_SM4_GCM: {
		struct tls12_crypto_info_sm4_gcm *sm4_gcm_info =
			container_of(crypto_info,
				     struct tls12_crypto_info_sm4_gcm, info);

		if (len != sizeof(*sm4_gcm_info)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(sm4_gcm_info->iv,
		       cctx->iv + TLS_CIPHER_SM4_GCM_SALT_SIZE,
		       TLS_CIPHER_SM4_GCM_IV_SIZE);
		memcpy(sm4_gcm_info->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE);
		if (copy_to_user(optval, sm4_gcm_info, sizeof(*sm4_gcm_info)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_SM4_CCM: {
		struct tls12_crypto_info_sm4_ccm *sm4_ccm_info =
			container_of(crypto_info,
				     struct tls12_crypto_info_sm4_ccm, info);

		if (len != sizeof(*sm4_ccm_info)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(sm4_ccm_info->iv,
		       cctx->iv + TLS_CIPHER_SM4_CCM_SALT_SIZE,
		       TLS_CIPHER_SM4_CCM_IV_SIZE);
		memcpy(sm4_ccm_info->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE);
		if (copy_to_user(optval, sm4_ccm_info, sizeof(*sm4_ccm_info)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_ARIA_GCM_128: {
		struct tls12_crypto_info_aria_gcm_128 *
		  crypto_info_aria_gcm_128 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aria_gcm_128,
			       info);

		if (len != sizeof(*crypto_info_aria_gcm_128)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(crypto_info_aria_gcm_128->iv,
		       cctx->iv + TLS_CIPHER_ARIA_GCM_128_SALT_SIZE,
		       TLS_CIPHER_ARIA_GCM_128_IV_SIZE);
		memcpy(crypto_info_aria_gcm_128->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE);
		if (copy_to_user(optval,
				 crypto_info_aria_gcm_128,
				 sizeof(*crypto_info_aria_gcm_128)))
			rc = -EFAULT;
		break;
	}
	case TLS_CIPHER_ARIA_GCM_256: {
		struct tls12_crypto_info_aria_gcm_256 *
		  crypto_info_aria_gcm_256 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aria_gcm_256,
			       info);

		if (len != sizeof(*crypto_info_aria_gcm_256)) {
			rc = -EINVAL;
			goto out;
		}
		memcpy(crypto_info_aria_gcm_256->iv,
		       cctx->iv + TLS_CIPHER_ARIA_GCM_256_SALT_SIZE,
		       TLS_CIPHER_ARIA_GCM_256_IV_SIZE);
		memcpy(crypto_info_aria_gcm_256->rec_seq, cctx->rec_seq,
		       TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE);
		if (copy_to_user(optval,
				 crypto_info_aria_gcm_256,
				 sizeof(*crypto_info_aria_gcm_256)))
			rc = -EFAULT;
		break;
	}
	default:
		rc = -EINVAL;
	}

out:
	return rc;
}

static int do_tls_getsockopt_tx_zc(struct sock *sk, char __user *optval,
				   int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	unsigned int value;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;

	if (len != sizeof(value))
		return -EINVAL;

	value = ctx->zerocopy_sendfile;
	if (copy_to_user(optval, &value, sizeof(value)))
		return -EFAULT;

	return 0;
}

static int do_tls_getsockopt_no_pad(struct sock *sk, char __user *optval,
				    int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	int value, len;

	if (ctx->prot_info.version != TLS_1_3_VERSION)
		return -EINVAL;

	if (get_user(len, optlen))
		return -EFAULT;
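	/* The reported option value is a single int; reject shorter buffers. */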
	if (len < sizeof(value))
		return -EINVAL;

	value = -EINVAL;
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW)
		value = ctx->rx_no_pad;
	if (value < 0)
		return value;

	if (put_user(sizeof(value), optlen))
		return -EFAULT;
	if (copy_to_user(optval, &value, sizeof(value)))
		return -EFAULT;

	return 0;
}

static int do_tls_getsockopt(struct sock *sk, int optname,
			     char __user *optval, int __user *optlen)
{
	int rc = 0;

	lock_sock(sk);

	switch (optname) {
	case TLS_TX:
	case TLS_RX:
		rc = do_tls_getsockopt_conf(sk, optval, optlen,
					    optname == TLS_TX);
		break;
	case TLS_TX_ZEROCOPY_RO:
		rc = do_tls_getsockopt_tx_zc(sk, optval, optlen);
		break;
	case TLS_RX_EXPECT_NO_PAD:
		rc = do_tls_getsockopt_no_pad(sk, optval, optlen);
		break;
	default:
		rc = -ENOPROTOOPT;
		break;
	}

	release_sock(sk);

	return rc;
}

static int tls_getsockopt(struct sock *sk, int level, int optname,
			  char __user *optval, int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (level != SOL_TLS)
		return ctx->sk_proto->getsockopt(sk, level,
						 optname, optval, optlen);

	return do_tls_getsockopt(sk, optname, optval, optlen);
}

static int do_tls_setsockopt_conf(struct sock *sk, sockptr_t optval,
				  unsigned int optlen, int tx)
{
	struct tls_crypto_info *crypto_info;
	struct tls_crypto_info *alt_crypto_info;
	struct tls_context *ctx = tls_get_ctx(sk);
	size_t optsize;
	int rc = 0;
	int conf;

	if (sockptr_is_null(optval) || (optlen < sizeof(*crypto_info)))
		return -EINVAL;

	if (tx) {
		crypto_info = &ctx->crypto_send.info;
		alt_crypto_info = &ctx->crypto_recv.info;
	} else {
		crypto_info = &ctx->crypto_recv.info;
		alt_crypto_info = &ctx->crypto_send.info;
	}

	/* Currently we don't support setting crypto info more than once */
	if (TLS_CRYPTO_INFO_READY(crypto_info))
		return -EBUSY;

	rc = copy_from_sockptr(crypto_info, optval, sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	/* check version */
	if (crypto_info->version != TLS_1_2_VERSION &&
	    crypto_info->version != TLS_1_3_VERSION) {
		rc = -EINVAL;
		goto err_crypto_info;
	}

	/* Ensure that TLS version and ciphers are same in both directions */
	if (TLS_CRYPTO_INFO_READY(alt_crypto_info)) {
		if (alt_crypto_info->version != crypto_info->version ||
		    alt_crypto_info->cipher_type != crypto_info->cipher_type) {
			rc = -EINVAL;
			goto err_crypto_info;
		}
	}

	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128:
		optsize = sizeof(struct tls12_crypto_info_aes_gcm_128);
		break;
	case TLS_CIPHER_AES_GCM_256:
		optsize = sizeof(struct tls12_crypto_info_aes_gcm_256);
		break;
	case TLS_CIPHER_AES_CCM_128:
		optsize = sizeof(struct tls12_crypto_info_aes_ccm_128);
		break;
	case TLS_CIPHER_CHACHA20_POLY1305:
		optsize = sizeof(struct tls12_crypto_info_chacha20_poly1305);
		break;
	case TLS_CIPHER_SM4_GCM:
		optsize = sizeof(struct tls12_crypto_info_sm4_gcm);
		break;
	case TLS_CIPHER_SM4_CCM:
		optsize = sizeof(struct tls12_crypto_info_sm4_ccm);
		break;
	case TLS_CIPHER_ARIA_GCM_128:
		if (crypto_info->version != TLS_1_2_VERSION) {
			rc = -EINVAL;
			goto err_crypto_info;
		}
		optsize = sizeof(struct tls12_crypto_info_aria_gcm_128);
		break;
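	/* Like ARIA-GCM-128 above, ARIA-GCM-256 is only accepted with TLS 1.2. */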
	case TLS_CIPHER_ARIA_GCM_256:
		if (crypto_info->version != TLS_1_2_VERSION) {
			rc = -EINVAL;
			goto err_crypto_info;
		}
		optsize = sizeof(struct tls12_crypto_info_aria_gcm_256);
		break;
	default:
		rc = -EINVAL;
		goto err_crypto_info;
	}

	if (optlen != optsize) {
		rc = -EINVAL;
		goto err_crypto_info;
	}

	rc = copy_from_sockptr_offset(crypto_info + 1, optval,
				      sizeof(*crypto_info),
				      optlen - sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	if (tx) {
		rc = tls_set_device_offload(sk, ctx);
		conf = TLS_HW;
		if (!rc) {
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXDEVICE);
		} else {
			rc = tls_set_sw_offload(sk, ctx, 1);
			if (rc)
				goto err_crypto_info;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSTXSW);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRTXSW);
			conf = TLS_SW;
		}
	} else {
		rc = tls_set_device_offload_rx(sk, ctx);
		conf = TLS_HW;
		if (!rc) {
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXDEVICE);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXDEVICE);
		} else {
			rc = tls_set_sw_offload(sk, ctx, 0);
			if (rc)
				goto err_crypto_info;
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXSW);
			TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSCURRRXSW);
			conf = TLS_SW;
		}
		tls_sw_strparser_arm(sk, ctx);
	}

	if (tx)
		ctx->tx_conf = conf;
	else
		ctx->rx_conf = conf;
	update_sk_prot(sk, ctx);
	if (tx) {
		ctx->sk_write_space = sk->sk_write_space;
		sk->sk_write_space = tls_write_space;
	} else {
		struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(ctx);

		tls_strp_check_rcv(&rx_ctx->strp);
	}
	return 0;

err_crypto_info:
	memzero_explicit(crypto_info, sizeof(union tls_crypto_context));
	return rc;
}

static int do_tls_setsockopt_tx_zc(struct sock *sk, sockptr_t optval,
				   unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	unsigned int value;

	if (sockptr_is_null(optval) || optlen != sizeof(value))
		return -EINVAL;

	if (copy_from_sockptr(&value, optval, sizeof(value)))
		return -EFAULT;

	if (value > 1)
		return -EINVAL;

	ctx->zerocopy_sendfile = value;

	return 0;
}

static int do_tls_setsockopt_no_pad(struct sock *sk, sockptr_t optval,
				    unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	u32 val;
	int rc;

	if (ctx->prot_info.version != TLS_1_3_VERSION ||
	    sockptr_is_null(optval) || optlen < sizeof(val))
		return -EINVAL;

	rc = copy_from_sockptr(&val, optval, sizeof(val));
	if (rc)
		return -EFAULT;
	if (val > 1)
		return -EINVAL;
	rc = check_zeroed_sockptr(optval, sizeof(val), optlen - sizeof(val));
	if (rc < 1)
		return rc == 0 ?
		       -EINVAL : rc;

	lock_sock(sk);
	rc = -EINVAL;
	if (ctx->rx_conf == TLS_SW || ctx->rx_conf == TLS_HW) {
		ctx->rx_no_pad = val;
		tls_update_rx_zc_capable(ctx);
		rc = 0;
	}
	release_sock(sk);

	return rc;
}

static int do_tls_setsockopt(struct sock *sk, int optname, sockptr_t optval,
			     unsigned int optlen)
{
	int rc = 0;

	switch (optname) {
	case TLS_TX:
	case TLS_RX:
		lock_sock(sk);
		rc = do_tls_setsockopt_conf(sk, optval, optlen,
					    optname == TLS_TX);
		release_sock(sk);
		break;
	case TLS_TX_ZEROCOPY_RO:
		lock_sock(sk);
		rc = do_tls_setsockopt_tx_zc(sk, optval, optlen);
		release_sock(sk);
		break;
	case TLS_RX_EXPECT_NO_PAD:
		rc = do_tls_setsockopt_no_pad(sk, optval, optlen);
		break;
	default:
		rc = -ENOPROTOOPT;
		break;
	}
	return rc;
}

static int tls_setsockopt(struct sock *sk, int level, int optname,
			  sockptr_t optval, unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (level != SOL_TLS)
		return ctx->sk_proto->setsockopt(sk, level, optname, optval,
						 optlen);

	return do_tls_setsockopt(sk, optname, optval, optlen);
}

struct tls_context *tls_ctx_create(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx;

	ctx = kzalloc(sizeof(*ctx), GFP_ATOMIC);
	if (!ctx)
		return NULL;

	mutex_init(&ctx->tx_lock);
	rcu_assign_pointer(icsk->icsk_ulp_data, ctx);
	ctx->sk_proto = READ_ONCE(sk->sk_prot);
	ctx->sk = sk;
	return ctx;
}

static void build_proto_ops(struct proto_ops ops[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			    const struct proto_ops *base)
{
	ops[TLS_BASE][TLS_BASE] = *base;

	ops[TLS_SW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_SW  ][TLS_BASE].splice_eof = tls_sw_splice_eof;

	ops[TLS_BASE][TLS_SW  ] = ops[TLS_BASE][TLS_BASE];
	ops[TLS_BASE][TLS_SW  ].splice_read = tls_sw_splice_read;
	ops[TLS_BASE][TLS_SW  ].poll = tls_sk_poll;
	ops[TLS_BASE][TLS_SW  ].read_sock = tls_sw_read_sock;

	ops[TLS_SW  ][TLS_SW  ] = ops[TLS_SW  ][TLS_BASE];
	ops[TLS_SW  ][TLS_SW  ].splice_read = tls_sw_splice_read;
	ops[TLS_SW  ][TLS_SW  ].poll = tls_sk_poll;
	ops[TLS_SW  ][TLS_SW  ].read_sock = tls_sw_read_sock;

#ifdef CONFIG_TLS_DEVICE
	ops[TLS_HW  ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];

	ops[TLS_HW  ][TLS_SW  ] = ops[TLS_BASE][TLS_SW  ];

	ops[TLS_BASE][TLS_HW  ] = ops[TLS_BASE][TLS_SW  ];

	ops[TLS_SW  ][TLS_HW  ] = ops[TLS_SW  ][TLS_SW  ];

	ops[TLS_HW  ][TLS_HW  ] = ops[TLS_HW  ][TLS_SW  ];
#endif
#ifdef CONFIG_TLS_TOE
	ops[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
#endif
}

static void tls_build_proto(struct sock *sk)
{
	int ip_ver = sk->sk_family == AF_INET6 ?
		     TLSV6 : TLSV4;
	struct proto *prot = READ_ONCE(sk->sk_prot);

	/* Build IPv6 TLS whenever the address of tcpv6_prot changes */
	if (ip_ver == TLSV6 &&
	    unlikely(prot != smp_load_acquire(&saved_tcpv6_prot))) {
		mutex_lock(&tcpv6_prot_mutex);
		if (likely(prot != saved_tcpv6_prot)) {
			build_protos(tls_prots[TLSV6], prot);
			build_proto_ops(tls_proto_ops[TLSV6],
					sk->sk_socket->ops);
			smp_store_release(&saved_tcpv6_prot, prot);
		}
		mutex_unlock(&tcpv6_prot_mutex);
	}

	if (ip_ver == TLSV4 &&
	    unlikely(prot != smp_load_acquire(&saved_tcpv4_prot))) {
		mutex_lock(&tcpv4_prot_mutex);
		if (likely(prot != saved_tcpv4_prot)) {
			build_protos(tls_prots[TLSV4], prot);
			build_proto_ops(tls_proto_ops[TLSV4],
					sk->sk_socket->ops);
			smp_store_release(&saved_tcpv4_prot, prot);
		}
		mutex_unlock(&tcpv4_prot_mutex);
	}
}

static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG],
			 const struct proto *base)
{
	prot[TLS_BASE][TLS_BASE] = *base;
	prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt;
	prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt;
	prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close;

	prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg;
	prot[TLS_SW][TLS_BASE].splice_eof = tls_sw_splice_eof;

	prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg;
	prot[TLS_BASE][TLS_SW].sock_is_readable = tls_sw_sock_is_readable;
	prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close;

	prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE];
	prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg;
	prot[TLS_SW][TLS_SW].sock_is_readable = tls_sw_sock_is_readable;
	prot[TLS_SW][TLS_SW].close = tls_sk_proto_close;

#ifdef CONFIG_TLS_DEVICE
	prot[TLS_HW][TLS_BASE] = prot[TLS_BASE][TLS_BASE];
	prot[TLS_HW][TLS_BASE].sendmsg = tls_device_sendmsg;
	prot[TLS_HW][TLS_BASE].splice_eof = tls_device_splice_eof;

	prot[TLS_HW][TLS_SW] = prot[TLS_BASE][TLS_SW];
	prot[TLS_HW][TLS_SW].sendmsg = tls_device_sendmsg;
	prot[TLS_HW][TLS_SW].splice_eof = tls_device_splice_eof;

	prot[TLS_BASE][TLS_HW] = prot[TLS_BASE][TLS_SW];

	prot[TLS_SW][TLS_HW] = prot[TLS_SW][TLS_SW];

	prot[TLS_HW][TLS_HW] = prot[TLS_HW][TLS_SW];
#endif
#ifdef CONFIG_TLS_TOE
	prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_toe_hash;
	prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_toe_unhash;
#endif
}

static int tls_init(struct sock *sk)
{
	struct tls_context *ctx;
	int rc = 0;

	tls_build_proto(sk);

#ifdef CONFIG_TLS_TOE
	if (tls_toe_bypass(sk))
		return 0;
#endif

	/* The TLS ulp is currently supported only for TCP sockets
	 * in ESTABLISHED state.
	 * Supporting sockets in LISTEN state will require us
	 * to modify the accept implementation to clone rather than
	 * share the ulp context.
	 */
	if (sk->sk_state != TCP_ESTABLISHED)
		return -ENOTCONN;

	/* allocate tls context */
	write_lock_bh(&sk->sk_callback_lock);
	ctx = tls_ctx_create(sk);
	if (!ctx) {
		rc = -ENOMEM;
		goto out;
	}

	ctx->tx_conf = TLS_BASE;
	ctx->rx_conf = TLS_BASE;
	update_sk_prot(sk, ctx);
out:
	write_unlock_bh(&sk->sk_callback_lock);
	return rc;
}

static void tls_update(struct sock *sk, struct proto *p,
		       void (*write_space)(struct sock *sk))
{
	struct tls_context *ctx;

	WARN_ON_ONCE(sk->sk_prot == p);

	ctx = tls_get_ctx(sk);
	if (likely(ctx)) {
		ctx->sk_write_space = write_space;
		ctx->sk_proto = p;
	} else {
		/* Pairs with lockless read in sk_clone_lock(). */
		WRITE_ONCE(sk->sk_prot, p);
		sk->sk_write_space = write_space;
	}
}

static u16 tls_user_config(struct tls_context *ctx, bool tx)
{
	u16 config = tx ? ctx->tx_conf : ctx->rx_conf;

	switch (config) {
	case TLS_BASE:
		return TLS_CONF_BASE;
	case TLS_SW:
		return TLS_CONF_SW;
	case TLS_HW:
		return TLS_CONF_HW;
	case TLS_HW_RECORD:
		return TLS_CONF_HW_RECORD;
	}
	return 0;
}

static int tls_get_info(const struct sock *sk, struct sk_buff *skb)
{
	u16 version, cipher_type;
	struct tls_context *ctx;
	struct nlattr *start;
	int err;

	start = nla_nest_start_noflag(skb, INET_ULP_INFO_TLS);
	if (!start)
		return -EMSGSIZE;

	rcu_read_lock();
	ctx = rcu_dereference(inet_csk(sk)->icsk_ulp_data);
	if (!ctx) {
		err = 0;
		goto nla_failure;
	}
	version = ctx->prot_info.version;
	if (version) {
		err = nla_put_u16(skb, TLS_INFO_VERSION, version);
		if (err)
			goto nla_failure;
	}
	cipher_type = ctx->prot_info.cipher_type;
	if (cipher_type) {
		err = nla_put_u16(skb, TLS_INFO_CIPHER, cipher_type);
		if (err)
			goto nla_failure;
	}
	err = nla_put_u16(skb, TLS_INFO_TXCONF, tls_user_config(ctx, true));
	if (err)
		goto nla_failure;

	err = nla_put_u16(skb, TLS_INFO_RXCONF, tls_user_config(ctx, false));
	if (err)
		goto nla_failure;

	if (ctx->tx_conf == TLS_HW && ctx->zerocopy_sendfile) {
		err = nla_put_flag(skb, TLS_INFO_ZC_RO_TX);
		if (err)
			goto nla_failure;
	}
	if (ctx->rx_no_pad) {
		err = nla_put_flag(skb, TLS_INFO_RX_NO_PAD);
		if (err)
			goto nla_failure;
	}

	rcu_read_unlock();
	nla_nest_end(skb, start);
	return 0;

nla_failure:
	rcu_read_unlock();
	nla_nest_cancel(skb, start);
	return err;
}

static size_t tls_get_info_size(const struct sock *sk)
{
	size_t size = 0;

	size += nla_total_size(0) +		/* INET_ULP_INFO_TLS */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_VERSION */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_CIPHER */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_RXCONF */
		nla_total_size(sizeof(u16)) +	/* TLS_INFO_TXCONF */
		nla_total_size(0) +		/* TLS_INFO_ZC_RO_TX */
		nla_total_size(0) +		/* TLS_INFO_RX_NO_PAD */
		0;

	return size;
}

static int __net_init tls_init_net(struct net *net)
{
	int err;

	net->mib.tls_statistics = alloc_percpu(struct linux_tls_mib);
	if (!net->mib.tls_statistics)
		return -ENOMEM;

	err = tls_proc_init(net);
	if (err)
		goto err_free_stats;

	return 0;

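	/* tls_proc_init() failed; release the per-netns MIB counters. */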
err_free_stats:
	free_percpu(net->mib.tls_statistics);
	return err;
}

static void __net_exit tls_exit_net(struct net *net)
{
	tls_proc_fini(net);
	free_percpu(net->mib.tls_statistics);
}

static struct pernet_operations tls_proc_ops = {
	.init = tls_init_net,
	.exit = tls_exit_net,
};

static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
	.name			= "tls",
	.owner			= THIS_MODULE,
	.init			= tls_init,
	.update			= tls_update,
	.get_info		= tls_get_info,
	.get_info_size		= tls_get_info_size,
};

static int __init tls_register(void)
{
	int err;

	err = register_pernet_subsys(&tls_proc_ops);
	if (err)
		return err;

	err = tls_strp_dev_init();
	if (err)
		goto err_pernet;

	err = tls_device_init();
	if (err)
		goto err_strp;

	tcp_register_ulp(&tcp_tls_ulp_ops);

	return 0;
err_strp:
	tls_strp_dev_exit();
err_pernet:
	unregister_pernet_subsys(&tls_proc_ops);
	return err;
}

static void __exit tls_unregister(void)
{
	tcp_unregister_ulp(&tcp_tls_ulp_ops);
	tls_strp_dev_exit();
	tls_device_cleanup();
	unregister_pernet_subsys(&tls_proc_ops);
}

module_init(tls_register);
module_exit(tls_unregister);
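
/*
 * Usage sketch (not part of this module): how userspace typically attaches
 * the "tls" ULP registered above and then programs TX crypto state through
 * the SOL_TLS setsockopt path handled by tls_setsockopt() and
 * do_tls_setsockopt_conf(). This is a minimal illustration, assuming a
 * connected TCP socket and AES-128-GCM key material (key, iv, salt, rec_seq)
 * obtained from an already-completed TLS 1.2 handshake; the helper name
 * ktls_enable_tx is illustrative only and error handling is abbreviated.
 *
 *	#include <string.h>
 *	#include <sys/socket.h>
 *	#include <netinet/tcp.h>
 *	#include <linux/tls.h>
 *
 *	#ifndef SOL_TLS
 *	#define SOL_TLS 282	// value from the kernel's socket level table
 *	#endif
 *	#ifndef TCP_ULP
 *	#define TCP_ULP 31	// attach an upper layer protocol to a TCP socket
 *	#endif
 *
 *	int ktls_enable_tx(int sock, const unsigned char *key,
 *			   const unsigned char *iv, const unsigned char *salt,
 *			   const unsigned char *rec_seq)
 *	{
 *		struct tls12_crypto_info_aes_gcm_128 ci = { 0 };
 *
 *		// Handled by tls_init(): socket must be ESTABLISHED.
 *		if (setsockopt(sock, SOL_TCP, TCP_ULP, "tls", sizeof("tls")))
 *			return -1;
 *
 *		ci.info.version = TLS_1_2_VERSION;
 *		ci.info.cipher_type = TLS_CIPHER_AES_GCM_128;
 *		memcpy(ci.key, key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
 *		memcpy(ci.iv, iv, TLS_CIPHER_AES_GCM_128_IV_SIZE);
 *		memcpy(ci.salt, salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
 *		memcpy(ci.rec_seq, rec_seq, TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
 *
 *		// Handled by tls_setsockopt() -> do_tls_setsockopt_conf(tx=1).
 *		return setsockopt(sock, SOL_TLS, TLS_TX, &ci, sizeof(ci));
 *	}
 */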