/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses. You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/module.h>

#include <net/tcp.h>
#include <net/inet_common.h>
#include <linux/highmem.h>
#include <linux/netdevice.h>
#include <linux/sched/signal.h>
#include <linux/inetdevice.h>

#include <net/tls.h>

MODULE_AUTHOR("Mellanox Technologies");
MODULE_DESCRIPTION("Transport Layer Security Support");
MODULE_LICENSE("Dual BSD/GPL");

enum {
	TLSV4,
	TLSV6,
	TLS_NUM_PROTS,
};

enum {
	TLS_BASE,
	TLS_SW_TX,
	TLS_SW_RX,
	TLS_SW_RXTX,
	TLS_HW_RECORD,
	TLS_NUM_CONFIG,
};

static struct proto *saved_tcpv6_prot;
static DEFINE_MUTEX(tcpv6_prot_mutex);
static LIST_HEAD(device_list);
static DEFINE_MUTEX(device_mutex);
static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG];
static struct proto_ops tls_sw_proto_ops;

static inline void update_sk_prot(struct sock *sk, struct tls_context *ctx)
{
	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;

	sk->sk_prot = &tls_prots[ip_ver][ctx->conf];
}

int wait_on_pending_writer(struct sock *sk, long *timeo)
{
	int rc = 0;
	DEFINE_WAIT_FUNC(wait, woken_wake_function);

	add_wait_queue(sk_sleep(sk), &wait);
	while (1) {
		if (!*timeo) {
			rc = -EAGAIN;
			break;
		}

		if (signal_pending(current)) {
			rc = sock_intr_errno(*timeo);
			break;
		}

		if (sk_wait_event(sk, timeo, !sk->sk_write_pending, &wait))
			break;
	}
	remove_wait_queue(sk_sleep(sk), &wait);
	return rc;
}

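/* Push a record's scatterlist out through the underlying TCP socket
 * with do_tcp_sendpages(). On a partial send, progress is saved in
 * ctx->partially_sent_record/ctx->partially_sent_offset so that a
 * later call (see tls_push_pending_closed_record()) can resume from
 * where transmission stopped.
 */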
int tls_push_sg(struct sock *sk,
		struct tls_context *ctx,
		struct scatterlist *sg,
		u16 first_offset,
		int flags)
{
	int sendpage_flags = flags | MSG_SENDPAGE_NOTLAST;
	int ret = 0;
	struct page *p;
	size_t size;
	int offset = first_offset;

	size = sg->length - offset;
	offset += sg->offset;

	ctx->in_tcp_sendpages = true;
	while (1) {
		if (sg_is_last(sg))
			sendpage_flags = flags;

		/* is sending application-limited? */
		tcp_rate_check_app_limited(sk);
		p = sg_page(sg);
retry:
		ret = do_tcp_sendpages(sk, p, offset, size, sendpage_flags);

		if (ret != size) {
			if (ret > 0) {
				offset += ret;
				size -= ret;
				goto retry;
			}

			offset -= sg->offset;
			ctx->partially_sent_offset = offset;
			ctx->partially_sent_record = (void *)sg;
			return ret;
		}

		put_page(p);
		sk_mem_uncharge(sk, sg->length);
		sg = sg_next(sg);
		if (!sg)
			break;

		offset = sg->offset;
		size = sg->length;
	}

	clear_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags);
	ctx->in_tcp_sendpages = false;
	ctx->sk_write_space(sk);

	return 0;
}

static int tls_handle_open_record(struct sock *sk, int flags)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (tls_is_pending_open_record(ctx))
		return ctx->push_pending_record(sk, flags);

	return 0;
}

int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
		      unsigned char *record_type)
{
	struct cmsghdr *cmsg;
	int rc = -EINVAL;

	for_each_cmsghdr(cmsg, msg) {
		if (!CMSG_OK(msg, cmsg))
			return -EINVAL;
		if (cmsg->cmsg_level != SOL_TLS)
			continue;

		switch (cmsg->cmsg_type) {
		case TLS_SET_RECORD_TYPE:
			if (cmsg->cmsg_len < CMSG_LEN(sizeof(*record_type)))
				return -EINVAL;

			if (msg->msg_flags & MSG_MORE)
				return -EINVAL;

			rc = tls_handle_open_record(sk, msg->msg_flags);
			if (rc)
				return rc;

			*record_type = *(unsigned char *)CMSG_DATA(cmsg);
			rc = 0;
			break;
		default:
			return -EINVAL;
		}
	}

	return rc;
}

int tls_push_pending_closed_record(struct sock *sk, struct tls_context *ctx,
				   int flags, long *timeo)
{
	struct scatterlist *sg;
	u16 offset;

	if (!tls_is_partially_sent_record(ctx))
		return ctx->push_pending_record(sk, flags);

	sg = ctx->partially_sent_record;
	offset = ctx->partially_sent_offset;

	ctx->partially_sent_record = NULL;
	return tls_push_sg(sk, ctx, sg, offset, flags);
}

static void tls_write_space(struct sock *sk)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	/* We are already sending pages, ignore notification */
	if (ctx->in_tcp_sendpages)
		return;

	if (!sk->sk_write_pending && tls_is_pending_closed_record(ctx)) {
		gfp_t sk_allocation = sk->sk_allocation;
		int rc;
		long timeo = 0;

		sk->sk_allocation = GFP_ATOMIC;
		rc = tls_push_pending_closed_record(sk, ctx,
						    MSG_DONTWAIT |
						    MSG_NOSIGNAL,
						    &timeo);
		sk->sk_allocation = sk_allocation;

		if (rc < 0)
			return;
	}

	ctx->sk_write_space(sk);
}

static void tls_sk_proto_close(struct sock *sk, long timeout)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	long timeo = sock_sndtimeo(sk, 0);
	void (*sk_proto_close)(struct sock *sk, long timeout);

	lock_sock(sk);
	sk_proto_close = ctx->sk_proto_close;

	if (ctx->conf == TLS_HW_RECORD)
		goto skip_tx_cleanup;

	if (ctx->conf == TLS_BASE) {
		kfree(ctx);
		ctx = NULL;
		goto skip_tx_cleanup;
	}

	if (!tls_complete_pending_work(sk, ctx, 0, &timeo))
		tls_handle_open_record(sk, 0);

	if (ctx->partially_sent_record) {
		struct scatterlist *sg = ctx->partially_sent_record;

		while (1) {
			put_page(sg_page(sg));
			sk_mem_uncharge(sk, sg->length);

			if (sg_is_last(sg))
				break;
			sg++;
		}
	}

	kfree(ctx->tx.rec_seq);
	kfree(ctx->tx.iv);
	kfree(ctx->rx.rec_seq);
	kfree(ctx->rx.iv);

	if (ctx->conf == TLS_SW_TX ||
	    ctx->conf == TLS_SW_RX ||
	    ctx->conf == TLS_SW_RXTX) {
		tls_sw_free_resources(sk);
	}

skip_tx_cleanup:
	release_sock(sk);
	sk_proto_close(sk, timeout);
	/* free ctx for TLS_HW_RECORD, used by tcp_set_state
	 * for sk->sk_prot->unhash [tls_hw_unhash]
	 */
	if (ctx && ctx->conf == TLS_HW_RECORD)
		kfree(ctx);
}

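/* getsockopt(SOL_TLS, TLS_TX) handler: copy the TX crypto parameters,
 * including the current IV and record sequence number, back to
 * userspace. Only TLS_CIPHER_AES_GCM_128 is recognized here.
 */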
static int do_tls_getsockopt_tx(struct sock *sk, char __user *optval,
				int __user *optlen)
{
	int rc = 0;
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_crypto_info *crypto_info;
	int len;

	if (get_user(len, optlen))
		return -EFAULT;

	if (!optval || (len < sizeof(*crypto_info))) {
		rc = -EINVAL;
		goto out;
	}

	if (!ctx) {
		rc = -EBUSY;
		goto out;
	}

	/* get user crypto info */
	crypto_info = &ctx->crypto_send;

	if (!TLS_CRYPTO_INFO_READY(crypto_info)) {
		rc = -EBUSY;
		goto out;
	}

	if (len == sizeof(*crypto_info)) {
		if (copy_to_user(optval, crypto_info, sizeof(*crypto_info)))
			rc = -EFAULT;
		goto out;
	}

	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128: {
		struct tls12_crypto_info_aes_gcm_128 *
		  crypto_info_aes_gcm_128 =
		  container_of(crypto_info,
			       struct tls12_crypto_info_aes_gcm_128,
			       info);

		if (len != sizeof(*crypto_info_aes_gcm_128)) {
			rc = -EINVAL;
			goto out;
		}
		lock_sock(sk);
		memcpy(crypto_info_aes_gcm_128->iv,
		       ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
		       TLS_CIPHER_AES_GCM_128_IV_SIZE);
		memcpy(crypto_info_aes_gcm_128->rec_seq, ctx->tx.rec_seq,
		       TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE);
		release_sock(sk);
		if (copy_to_user(optval,
				 crypto_info_aes_gcm_128,
				 sizeof(*crypto_info_aes_gcm_128)))
			rc = -EFAULT;
		break;
	}
	default:
		rc = -EINVAL;
	}

out:
	return rc;
}

static int do_tls_getsockopt(struct sock *sk, int optname,
			     char __user *optval, int __user *optlen)
{
	int rc = 0;

	switch (optname) {
	case TLS_TX:
		rc = do_tls_getsockopt_tx(sk, optval, optlen);
		break;
	default:
		rc = -ENOPROTOOPT;
		break;
	}
	return rc;
}

static int tls_getsockopt(struct sock *sk, int level, int optname,
			  char __user *optval, int __user *optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (level != SOL_TLS)
		return ctx->getsockopt(sk, level, optname, optval, optlen);

	return do_tls_getsockopt(sk, optname, optval, optlen);
}

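/* Typical userspace setup driven by this handler (illustrative sketch
 * only; "fd" and "ci" are hypothetical names and error handling is
 * omitted):
 *
 *	struct tls12_crypto_info_aes_gcm_128 ci = {
 *		.info.version = TLS_1_2_VERSION,
 *		.info.cipher_type = TLS_CIPHER_AES_GCM_128,
 *	};
 *	// fill ci.key, ci.iv, ci.salt and ci.rec_seq from the handshake
 *	setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls"));
 *	setsockopt(fd, SOL_TLS, TLS_TX, &ci, sizeof(ci));
 *
 * TLS_RX follows the same pattern and lands here with tx == 0.
 */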
static int do_tls_setsockopt_conf(struct sock *sk, char __user *optval,
				  unsigned int optlen, int tx)
{
	struct tls_crypto_info *crypto_info;
	struct tls_context *ctx = tls_get_ctx(sk);
	int rc = 0;
	int conf;

	if (!optval || (optlen < sizeof(*crypto_info))) {
		rc = -EINVAL;
		goto out;
	}

	if (tx)
		crypto_info = &ctx->crypto_send;
	else
		crypto_info = &ctx->crypto_recv;

	/* Currently we don't support setting crypto info more than once */
	if (TLS_CRYPTO_INFO_READY(crypto_info)) {
		rc = -EBUSY;
		goto out;
	}

	rc = copy_from_user(crypto_info, optval, sizeof(*crypto_info));
	if (rc) {
		rc = -EFAULT;
		goto err_crypto_info;
	}

	/* check version */
	if (crypto_info->version != TLS_1_2_VERSION) {
		rc = -ENOTSUPP;
		goto err_crypto_info;
	}

	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128: {
		if (optlen != sizeof(struct tls12_crypto_info_aes_gcm_128)) {
			rc = -EINVAL;
			goto err_crypto_info;
		}
		rc = copy_from_user(crypto_info + 1, optval + sizeof(*crypto_info),
				    optlen - sizeof(*crypto_info));
		if (rc) {
			rc = -EFAULT;
			goto err_crypto_info;
		}
		break;
	}
	default:
		rc = -EINVAL;
		goto err_crypto_info;
	}

	/* currently SW is the default; ethtool-based selection of HW
	 * offload will come in the future
	 */
	if (tx) {
		rc = tls_set_sw_offload(sk, ctx, 1);
		if (ctx->conf == TLS_SW_RX)
			conf = TLS_SW_RXTX;
		else
			conf = TLS_SW_TX;
	} else {
		rc = tls_set_sw_offload(sk, ctx, 0);
		if (ctx->conf == TLS_SW_TX)
			conf = TLS_SW_RXTX;
		else
			conf = TLS_SW_RX;
	}

	if (rc)
		goto err_crypto_info;

	ctx->conf = conf;
	update_sk_prot(sk, ctx);
	if (tx) {
		ctx->sk_write_space = sk->sk_write_space;
		sk->sk_write_space = tls_write_space;
	} else {
		sk->sk_socket->ops = &tls_sw_proto_ops;
	}
	goto out;

err_crypto_info:
	memset(crypto_info, 0, sizeof(*crypto_info));
out:
	return rc;
}

static int do_tls_setsockopt(struct sock *sk, int optname,
			     char __user *optval, unsigned int optlen)
{
	int rc = 0;

	switch (optname) {
	case TLS_TX:
	case TLS_RX:
		lock_sock(sk);
		rc = do_tls_setsockopt_conf(sk, optval, optlen,
					    optname == TLS_TX);
		release_sock(sk);
		break;
	default:
		rc = -ENOPROTOOPT;
		break;
	}
	return rc;
}

static int tls_setsockopt(struct sock *sk, int level, int optname,
			  char __user *optval, unsigned int optlen)
{
	struct tls_context *ctx = tls_get_ctx(sk);

	if (level != SOL_TLS)
		return ctx->setsockopt(sk, level, optname, optval, optlen);

	return do_tls_setsockopt(sk, optname, optval, optlen);
}

static struct tls_context *create_ctx(struct sock *sk)
{
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct tls_context *ctx;

	ctx = kzalloc(sizeof(*ctx), GFP_KERNEL);
	if (!ctx)
		return NULL;

	icsk->icsk_ulp_data = ctx;
	return ctx;
}

static int tls_hw_prot(struct sock *sk)
{
	struct tls_context *ctx;
	struct tls_device *dev;
	int rc = 0;

	mutex_lock(&device_mutex);
	list_for_each_entry(dev, &device_list, dev_list) {
		if (dev->feature && dev->feature(dev)) {
			ctx = create_ctx(sk);
			if (!ctx)
				goto out;

			ctx->hash = sk->sk_prot->hash;
			ctx->unhash = sk->sk_prot->unhash;
			ctx->sk_proto_close = sk->sk_prot->close;
			ctx->conf = TLS_HW_RECORD;
			update_sk_prot(sk, ctx);
			rc = 1;
			break;
		}
	}
out:
	mutex_unlock(&device_mutex);
	return rc;
}

static void tls_hw_unhash(struct sock *sk)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_device *dev;

	mutex_lock(&device_mutex);
	list_for_each_entry(dev, &device_list, dev_list) {
		if (dev->unhash)
			dev->unhash(dev, sk);
	}
	mutex_unlock(&device_mutex);
	ctx->unhash(sk);
}

static int tls_hw_hash(struct sock *sk)
{
	struct tls_context *ctx = tls_get_ctx(sk);
	struct tls_device *dev;
	int err;

	err = ctx->hash(sk);
	mutex_lock(&device_mutex);
	list_for_each_entry(dev, &device_list, dev_list) {
		if (dev->hash)
			err |= dev->hash(dev, sk);
	}
	mutex_unlock(&device_mutex);

	if (err)
		tls_hw_unhash(sk);
	return err;
}

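/* Fill one row of the tls_prots[][] matrix from a base proto
 * (tcp_prot at module init for IPv4; the socket's tcpv6 proto, built
 * lazily in tls_init(), for IPv6): each config inherits the base ops
 * and overrides only the handlers TLS must intercept.
 */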
static void build_protos(struct proto *prot, struct proto *base)
{
	prot[TLS_BASE] = *base;
	prot[TLS_BASE].setsockopt = tls_setsockopt;
	prot[TLS_BASE].getsockopt = tls_getsockopt;
	prot[TLS_BASE].close = tls_sk_proto_close;

	prot[TLS_SW_TX] = prot[TLS_BASE];
	prot[TLS_SW_TX].sendmsg = tls_sw_sendmsg;
	prot[TLS_SW_TX].sendpage = tls_sw_sendpage;

	prot[TLS_SW_RX] = prot[TLS_BASE];
	prot[TLS_SW_RX].recvmsg = tls_sw_recvmsg;
	prot[TLS_SW_RX].close = tls_sk_proto_close;

	prot[TLS_SW_RXTX] = prot[TLS_SW_TX];
	prot[TLS_SW_RXTX].recvmsg = tls_sw_recvmsg;
	prot[TLS_SW_RXTX].close = tls_sk_proto_close;

	prot[TLS_HW_RECORD] = *base;
	prot[TLS_HW_RECORD].hash = tls_hw_hash;
	prot[TLS_HW_RECORD].unhash = tls_hw_unhash;
	prot[TLS_HW_RECORD].close = tls_sk_proto_close;
}

static int tls_init(struct sock *sk)
{
	int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4;
	struct tls_context *ctx;
	int rc = 0;

	if (tls_hw_prot(sk))
		goto out;

	/* The TLS ulp is currently supported only for TCP sockets
	 * in ESTABLISHED state.
	 * Supporting sockets in LISTEN state will require us
	 * to modify the accept implementation to clone rather than
	 * share the ulp context.
	 */
	if (sk->sk_state != TCP_ESTABLISHED)
		return -ENOTSUPP;

	/* allocate tls context */
	ctx = create_ctx(sk);
	if (!ctx) {
		rc = -ENOMEM;
		goto out;
	}
	ctx->setsockopt = sk->sk_prot->setsockopt;
	ctx->getsockopt = sk->sk_prot->getsockopt;
	ctx->sk_proto_close = sk->sk_prot->close;

	/* Build IPv6 TLS whenever the address of tcpv6_prot changes */
	if (ip_ver == TLSV6 &&
	    unlikely(sk->sk_prot != smp_load_acquire(&saved_tcpv6_prot))) {
		mutex_lock(&tcpv6_prot_mutex);
		if (likely(sk->sk_prot != saved_tcpv6_prot)) {
			build_protos(tls_prots[TLSV6], sk->sk_prot);
			smp_store_release(&saved_tcpv6_prot, sk->sk_prot);
		}
		mutex_unlock(&tcpv6_prot_mutex);
	}

	ctx->conf = TLS_BASE;
	update_sk_prot(sk, ctx);
out:
	return rc;
}

void tls_register_device(struct tls_device *device)
{
	mutex_lock(&device_mutex);
	list_add_tail(&device->dev_list, &device_list);
	mutex_unlock(&device_mutex);
}
EXPORT_SYMBOL(tls_register_device);

void tls_unregister_device(struct tls_device *device)
{
	mutex_lock(&device_mutex);
	list_del(&device->dev_list);
	mutex_unlock(&device_mutex);
}
EXPORT_SYMBOL(tls_unregister_device);

static struct tcp_ulp_ops tcp_tls_ulp_ops __read_mostly = {
	.name			= "tls",
	.uid			= TCP_ULP_TLS,
	.user_visible		= true,
	.owner			= THIS_MODULE,
	.init			= tls_init,
};

static int __init tls_register(void)
{
	build_protos(tls_prots[TLSV4], &tcp_prot);

	tls_sw_proto_ops = inet_stream_ops;
	tls_sw_proto_ops.poll = tls_sw_poll;
	tls_sw_proto_ops.splice_read = tls_sw_splice_read;

	tcp_register_ulp(&tcp_tls_ulp_ops);

	return 0;
}

static void __exit tls_unregister(void)
{
	tcp_unregister_ulp(&tcp_tls_ulp_ops);
}

module_init(tls_register);
module_exit(tls_unregister);