1 /* 2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. 3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved. 4 * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved. 5 * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved. 6 * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved. 7 * 8 * This software is available to you under a choice of one of two 9 * licenses. You may choose to be licensed under the terms of the GNU 10 * General Public License (GPL) Version 2, available from the file 11 * COPYING in the main directory of this source tree, or the 12 * OpenIB.org BSD license below: 13 * 14 * Redistribution and use in source and binary forms, with or 15 * without modification, are permitted provided that the following 16 * conditions are met: 17 * 18 * - Redistributions of source code must retain the above 19 * copyright notice, this list of conditions and the following 20 * disclaimer. 21 * 22 * - Redistributions in binary form must reproduce the above 23 * copyright notice, this list of conditions and the following 24 * disclaimer in the documentation and/or other materials 25 * provided with the distribution. 26 * 27 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 28 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 29 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 30 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS 31 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 32 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 33 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 34 * SOFTWARE. 35 */ 36 37 #include <linux/sched/signal.h> 38 #include <linux/module.h> 39 #include <crypto/aead.h> 40 41 #include <net/strparser.h> 42 #include <net/tls.h> 43 44 #define MAX_IV_SIZE TLS_CIPHER_AES_GCM_128_IV_SIZE 45 46 static int tls_do_decryption(struct sock *sk, 47 struct scatterlist *sgin, 48 struct scatterlist *sgout, 49 char *iv_recv, 50 size_t data_len, 51 struct sk_buff *skb, 52 gfp_t flags) 53 { 54 struct tls_context *tls_ctx = tls_get_ctx(sk); 55 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 56 struct aead_request *aead_req; 57 58 int ret; 59 60 aead_req = aead_request_alloc(ctx->aead_recv, flags); 61 if (!aead_req) 62 return -ENOMEM; 63 64 aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); 65 aead_request_set_crypt(aead_req, sgin, sgout, 66 data_len + tls_ctx->rx.tag_size, 67 (u8 *)iv_recv); 68 aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, 69 crypto_req_done, &ctx->async_wait); 70 71 ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait); 72 73 aead_request_free(aead_req); 74 return ret; 75 } 76 77 static void trim_sg(struct sock *sk, struct scatterlist *sg, 78 int *sg_num_elem, unsigned int *sg_size, int target_size) 79 { 80 int i = *sg_num_elem - 1; 81 int trim = *sg_size - target_size; 82 83 if (trim <= 0) { 84 WARN_ON(trim < 0); 85 return; 86 } 87 88 *sg_size = target_size; 89 while (trim >= sg[i].length) { 90 trim -= sg[i].length; 91 sk_mem_uncharge(sk, sg[i].length); 92 put_page(sg_page(&sg[i])); 93 i--; 94 95 if (i < 0) 96 goto out; 97 } 98 99 sg[i].length -= trim; 100 sk_mem_uncharge(sk, trim); 101 102 out: 103 *sg_num_elem = i + 1; 104 } 105 106 static void trim_both_sgl(struct sock *sk, int target_size) 107 { 108 struct tls_context *tls_ctx = tls_get_ctx(sk); 109 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 110 111 trim_sg(sk, ctx->sg_plaintext_data, 112 &ctx->sg_plaintext_num_elem, 113 &ctx->sg_plaintext_size, 114 target_size); 115 116 if (target_size > 0) 117 target_size += tls_ctx->tx.overhead_size; 118 119 trim_sg(sk, ctx->sg_encrypted_data, 120 &ctx->sg_encrypted_num_elem, 121 &ctx->sg_encrypted_size, 122 target_size); 123 } 124 125 static int alloc_encrypted_sg(struct sock *sk, int len) 126 { 127 struct tls_context *tls_ctx = tls_get_ctx(sk); 128 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 129 int rc = 0; 130 131 rc = sk_alloc_sg(sk, len, 132 ctx->sg_encrypted_data, 0, 133 &ctx->sg_encrypted_num_elem, 134 &ctx->sg_encrypted_size, 0); 135 136 return rc; 137 } 138 139 static int alloc_plaintext_sg(struct sock *sk, int len) 140 { 141 struct tls_context *tls_ctx = tls_get_ctx(sk); 142 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 143 int rc = 0; 144 145 rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0, 146 &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size, 147 tls_ctx->pending_open_record_frags); 148 149 return rc; 150 } 151 152 static void free_sg(struct sock *sk, struct scatterlist *sg, 153 int *sg_num_elem, unsigned int *sg_size) 154 { 155 int i, n = *sg_num_elem; 156 157 for (i = 0; i < n; ++i) { 158 sk_mem_uncharge(sk, sg[i].length); 159 put_page(sg_page(&sg[i])); 160 } 161 *sg_num_elem = 0; 162 *sg_size = 0; 163 } 164 165 static void tls_free_both_sg(struct sock *sk) 166 { 167 struct tls_context *tls_ctx = tls_get_ctx(sk); 168 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 169 170 free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem, 171 &ctx->sg_encrypted_size); 172 173 free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem, 174 &ctx->sg_plaintext_size); 175 } 176 177 static int tls_do_encryption(struct tls_context *tls_ctx, 178 struct tls_sw_context_tx *ctx, 179 struct aead_request *aead_req, 180 size_t data_len) 181 { 182 int rc; 183 184 ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size; 185 ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size; 186 187 aead_request_set_tfm(aead_req, ctx->aead_send); 188 aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); 189 aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out, 190 data_len, tls_ctx->tx.iv); 191 192 aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, 193 crypto_req_done, &ctx->async_wait); 194 195 rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait); 196 197 ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size; 198 ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size; 199 200 return rc; 201 } 202 203 static int tls_push_record(struct sock *sk, int flags, 204 unsigned char record_type) 205 { 206 struct tls_context *tls_ctx = tls_get_ctx(sk); 207 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 208 struct aead_request *req; 209 int rc; 210 211 req = aead_request_alloc(ctx->aead_send, sk->sk_allocation); 212 if (!req) 213 return -ENOMEM; 214 215 sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1); 216 sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1); 217 218 tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size, 219 tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size, 220 record_type); 221 222 tls_fill_prepend(tls_ctx, 223 page_address(sg_page(&ctx->sg_encrypted_data[0])) + 224 ctx->sg_encrypted_data[0].offset, 225 ctx->sg_plaintext_size, record_type); 226 227 tls_ctx->pending_open_record_frags = 0; 228 set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags); 229 230 rc = tls_do_encryption(tls_ctx, ctx, req, ctx->sg_plaintext_size); 231 if (rc < 0) { 232 /* If we are called from write_space and 233 * we fail, we need to set this SOCK_NOSPACE 234 * to trigger another write_space in the future. 235 */ 236 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); 237 goto out_req; 238 } 239 240 free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem, 241 &ctx->sg_plaintext_size); 242 243 ctx->sg_encrypted_num_elem = 0; 244 ctx->sg_encrypted_size = 0; 245 246 /* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */ 247 rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags); 248 if (rc < 0 && rc != -EAGAIN) 249 tls_err_abort(sk, EBADMSG); 250 251 tls_advance_record_sn(sk, &tls_ctx->tx); 252 out_req: 253 aead_request_free(req); 254 return rc; 255 } 256 257 static int tls_sw_push_pending_record(struct sock *sk, int flags) 258 { 259 return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA); 260 } 261 262 static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from, 263 int length, int *pages_used, 264 unsigned int *size_used, 265 struct scatterlist *to, int to_max_pages, 266 bool charge, bool revert) 267 { 268 struct page *pages[MAX_SKB_FRAGS]; 269 270 size_t offset; 271 ssize_t copied, use; 272 int i = 0; 273 unsigned int size = *size_used; 274 int num_elem = *pages_used; 275 int rc = 0; 276 int maxpages; 277 278 while (length > 0) { 279 i = 0; 280 maxpages = to_max_pages - num_elem; 281 if (maxpages == 0) { 282 rc = -EFAULT; 283 goto out; 284 } 285 copied = iov_iter_get_pages(from, pages, 286 length, 287 maxpages, &offset); 288 if (copied <= 0) { 289 rc = -EFAULT; 290 goto out; 291 } 292 293 iov_iter_advance(from, copied); 294 295 length -= copied; 296 size += copied; 297 while (copied) { 298 use = min_t(int, copied, PAGE_SIZE - offset); 299 300 sg_set_page(&to[num_elem], 301 pages[i], use, offset); 302 sg_unmark_end(&to[num_elem]); 303 if (charge) 304 sk_mem_charge(sk, use); 305 306 offset = 0; 307 copied -= use; 308 309 ++i; 310 ++num_elem; 311 } 312 } 313 314 out: 315 *size_used = size; 316 *pages_used = num_elem; 317 if (revert) 318 iov_iter_revert(from, size); 319 320 return rc; 321 } 322 323 static int memcopy_from_iter(struct sock *sk, struct iov_iter *from, 324 int bytes) 325 { 326 struct tls_context *tls_ctx = tls_get_ctx(sk); 327 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 328 struct scatterlist *sg = ctx->sg_plaintext_data; 329 int copy, i, rc = 0; 330 331 for (i = tls_ctx->pending_open_record_frags; 332 i < ctx->sg_plaintext_num_elem; ++i) { 333 copy = sg[i].length; 334 if (copy_from_iter( 335 page_address(sg_page(&sg[i])) + sg[i].offset, 336 copy, from) != copy) { 337 rc = -EFAULT; 338 goto out; 339 } 340 bytes -= copy; 341 342 ++tls_ctx->pending_open_record_frags; 343 344 if (!bytes) 345 break; 346 } 347 348 out: 349 return rc; 350 } 351 352 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) 353 { 354 struct tls_context *tls_ctx = tls_get_ctx(sk); 355 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 356 int ret = 0; 357 int required_size; 358 long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); 359 bool eor = !(msg->msg_flags & MSG_MORE); 360 size_t try_to_copy, copied = 0; 361 unsigned char record_type = TLS_RECORD_TYPE_DATA; 362 int record_room; 363 bool full_record; 364 int orig_size; 365 366 if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL)) 367 return -ENOTSUPP; 368 369 lock_sock(sk); 370 371 if (tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo)) 372 goto send_end; 373 374 if (unlikely(msg->msg_controllen)) { 375 ret = tls_proccess_cmsg(sk, msg, &record_type); 376 if (ret) 377 goto send_end; 378 } 379 380 while (msg_data_left(msg)) { 381 if (sk->sk_err) { 382 ret = -sk->sk_err; 383 goto send_end; 384 } 385 386 orig_size = ctx->sg_plaintext_size; 387 full_record = false; 388 try_to_copy = msg_data_left(msg); 389 record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size; 390 if (try_to_copy >= record_room) { 391 try_to_copy = record_room; 392 full_record = true; 393 } 394 395 required_size = ctx->sg_plaintext_size + try_to_copy + 396 tls_ctx->tx.overhead_size; 397 398 if (!sk_stream_memory_free(sk)) 399 goto wait_for_sndbuf; 400 alloc_encrypted: 401 ret = alloc_encrypted_sg(sk, required_size); 402 if (ret) { 403 if (ret != -ENOSPC) 404 goto wait_for_memory; 405 406 /* Adjust try_to_copy according to the amount that was 407 * actually allocated. The difference is due 408 * to max sg elements limit 409 */ 410 try_to_copy -= required_size - ctx->sg_encrypted_size; 411 full_record = true; 412 } 413 414 if (full_record || eor) { 415 ret = zerocopy_from_iter(sk, &msg->msg_iter, 416 try_to_copy, &ctx->sg_plaintext_num_elem, 417 &ctx->sg_plaintext_size, 418 ctx->sg_plaintext_data, 419 ARRAY_SIZE(ctx->sg_plaintext_data), 420 true, false); 421 if (ret) 422 goto fallback_to_reg_send; 423 424 copied += try_to_copy; 425 ret = tls_push_record(sk, msg->msg_flags, record_type); 426 if (!ret) 427 continue; 428 if (ret < 0) 429 goto send_end; 430 431 copied -= try_to_copy; 432 fallback_to_reg_send: 433 iov_iter_revert(&msg->msg_iter, 434 ctx->sg_plaintext_size - orig_size); 435 trim_sg(sk, ctx->sg_plaintext_data, 436 &ctx->sg_plaintext_num_elem, 437 &ctx->sg_plaintext_size, 438 orig_size); 439 } 440 441 required_size = ctx->sg_plaintext_size + try_to_copy; 442 alloc_plaintext: 443 ret = alloc_plaintext_sg(sk, required_size); 444 if (ret) { 445 if (ret != -ENOSPC) 446 goto wait_for_memory; 447 448 /* Adjust try_to_copy according to the amount that was 449 * actually allocated. The difference is due 450 * to max sg elements limit 451 */ 452 try_to_copy -= required_size - ctx->sg_plaintext_size; 453 full_record = true; 454 455 trim_sg(sk, ctx->sg_encrypted_data, 456 &ctx->sg_encrypted_num_elem, 457 &ctx->sg_encrypted_size, 458 ctx->sg_plaintext_size + 459 tls_ctx->tx.overhead_size); 460 } 461 462 ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy); 463 if (ret) 464 goto trim_sgl; 465 466 copied += try_to_copy; 467 if (full_record || eor) { 468 push_record: 469 ret = tls_push_record(sk, msg->msg_flags, record_type); 470 if (ret) { 471 if (ret == -ENOMEM) 472 goto wait_for_memory; 473 474 goto send_end; 475 } 476 } 477 478 continue; 479 480 wait_for_sndbuf: 481 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); 482 wait_for_memory: 483 ret = sk_stream_wait_memory(sk, &timeo); 484 if (ret) { 485 trim_sgl: 486 trim_both_sgl(sk, orig_size); 487 goto send_end; 488 } 489 490 if (tls_is_pending_closed_record(tls_ctx)) 491 goto push_record; 492 493 if (ctx->sg_encrypted_size < required_size) 494 goto alloc_encrypted; 495 496 goto alloc_plaintext; 497 } 498 499 send_end: 500 ret = sk_stream_error(sk, msg->msg_flags, ret); 501 502 release_sock(sk); 503 return copied ? copied : ret; 504 } 505 506 int tls_sw_sendpage(struct sock *sk, struct page *page, 507 int offset, size_t size, int flags) 508 { 509 struct tls_context *tls_ctx = tls_get_ctx(sk); 510 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 511 int ret = 0; 512 long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); 513 bool eor; 514 size_t orig_size = size; 515 unsigned char record_type = TLS_RECORD_TYPE_DATA; 516 struct scatterlist *sg; 517 bool full_record; 518 int record_room; 519 520 if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | 521 MSG_SENDPAGE_NOTLAST)) 522 return -ENOTSUPP; 523 524 /* No MSG_EOR from splice, only look at MSG_MORE */ 525 eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST)); 526 527 lock_sock(sk); 528 529 sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); 530 531 if (tls_complete_pending_work(sk, tls_ctx, flags, &timeo)) 532 goto sendpage_end; 533 534 /* Call the sk_stream functions to manage the sndbuf mem. */ 535 while (size > 0) { 536 size_t copy, required_size; 537 538 if (sk->sk_err) { 539 ret = -sk->sk_err; 540 goto sendpage_end; 541 } 542 543 full_record = false; 544 record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size; 545 copy = size; 546 if (copy >= record_room) { 547 copy = record_room; 548 full_record = true; 549 } 550 required_size = ctx->sg_plaintext_size + copy + 551 tls_ctx->tx.overhead_size; 552 553 if (!sk_stream_memory_free(sk)) 554 goto wait_for_sndbuf; 555 alloc_payload: 556 ret = alloc_encrypted_sg(sk, required_size); 557 if (ret) { 558 if (ret != -ENOSPC) 559 goto wait_for_memory; 560 561 /* Adjust copy according to the amount that was 562 * actually allocated. The difference is due 563 * to max sg elements limit 564 */ 565 copy -= required_size - ctx->sg_plaintext_size; 566 full_record = true; 567 } 568 569 get_page(page); 570 sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem; 571 sg_set_page(sg, page, copy, offset); 572 sg_unmark_end(sg); 573 574 ctx->sg_plaintext_num_elem++; 575 576 sk_mem_charge(sk, copy); 577 offset += copy; 578 size -= copy; 579 ctx->sg_plaintext_size += copy; 580 tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem; 581 582 if (full_record || eor || 583 ctx->sg_plaintext_num_elem == 584 ARRAY_SIZE(ctx->sg_plaintext_data)) { 585 push_record: 586 ret = tls_push_record(sk, flags, record_type); 587 if (ret) { 588 if (ret == -ENOMEM) 589 goto wait_for_memory; 590 591 goto sendpage_end; 592 } 593 } 594 continue; 595 wait_for_sndbuf: 596 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); 597 wait_for_memory: 598 ret = sk_stream_wait_memory(sk, &timeo); 599 if (ret) { 600 trim_both_sgl(sk, ctx->sg_plaintext_size); 601 goto sendpage_end; 602 } 603 604 if (tls_is_pending_closed_record(tls_ctx)) 605 goto push_record; 606 607 goto alloc_payload; 608 } 609 610 sendpage_end: 611 if (orig_size > size) 612 ret = orig_size - size; 613 else 614 ret = sk_stream_error(sk, flags, ret); 615 616 release_sock(sk); 617 return ret; 618 } 619 620 static struct sk_buff *tls_wait_data(struct sock *sk, int flags, 621 long timeo, int *err) 622 { 623 struct tls_context *tls_ctx = tls_get_ctx(sk); 624 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 625 struct sk_buff *skb; 626 DEFINE_WAIT_FUNC(wait, woken_wake_function); 627 628 while (!(skb = ctx->recv_pkt)) { 629 if (sk->sk_err) { 630 *err = sock_error(sk); 631 return NULL; 632 } 633 634 if (sock_flag(sk, SOCK_DONE)) 635 return NULL; 636 637 if ((flags & MSG_DONTWAIT) || !timeo) { 638 *err = -EAGAIN; 639 return NULL; 640 } 641 642 add_wait_queue(sk_sleep(sk), &wait); 643 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); 644 sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait); 645 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); 646 remove_wait_queue(sk_sleep(sk), &wait); 647 648 /* Handle signals */ 649 if (signal_pending(current)) { 650 *err = sock_intr_errno(timeo); 651 return NULL; 652 } 653 } 654 655 return skb; 656 } 657 658 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, 659 struct scatterlist *sgout, bool *zc) 660 { 661 struct tls_context *tls_ctx = tls_get_ctx(sk); 662 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 663 struct strp_msg *rxm = strp_msg(skb); 664 int err = 0; 665 666 #ifdef CONFIG_TLS_DEVICE 667 err = tls_device_decrypted(sk, skb); 668 if (err < 0) 669 return err; 670 #endif 671 if (!ctx->decrypted) { 672 err = decrypt_skb(sk, skb, sgout); 673 if (err < 0) 674 return err; 675 } else { 676 *zc = false; 677 } 678 679 rxm->offset += tls_ctx->rx.prepend_size; 680 rxm->full_len -= tls_ctx->rx.overhead_size; 681 tls_advance_record_sn(sk, &tls_ctx->rx); 682 ctx->decrypted = true; 683 ctx->saved_data_ready(sk); 684 685 return err; 686 } 687 688 int decrypt_skb(struct sock *sk, struct sk_buff *skb, 689 struct scatterlist *sgout) 690 { 691 struct tls_context *tls_ctx = tls_get_ctx(sk); 692 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 693 char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + MAX_IV_SIZE]; 694 struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2]; 695 struct scatterlist *sgin = &sgin_arr[0]; 696 struct strp_msg *rxm = strp_msg(skb); 697 int ret, nsg = ARRAY_SIZE(sgin_arr); 698 struct sk_buff *unused; 699 700 ret = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, 701 iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, 702 tls_ctx->rx.iv_size); 703 if (ret < 0) 704 return ret; 705 706 memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE); 707 if (!sgout) { 708 nsg = skb_cow_data(skb, 0, &unused) + 1; 709 sgin = kmalloc_array(nsg, sizeof(*sgin), sk->sk_allocation); 710 sgout = sgin; 711 } 712 713 sg_init_table(sgin, nsg); 714 sg_set_buf(&sgin[0], ctx->rx_aad_ciphertext, TLS_AAD_SPACE_SIZE); 715 716 nsg = skb_to_sgvec(skb, &sgin[1], 717 rxm->offset + tls_ctx->rx.prepend_size, 718 rxm->full_len - tls_ctx->rx.prepend_size); 719 if (nsg < 0) { 720 ret = nsg; 721 goto out; 722 } 723 724 tls_make_aad(ctx->rx_aad_ciphertext, 725 rxm->full_len - tls_ctx->rx.overhead_size, 726 tls_ctx->rx.rec_seq, 727 tls_ctx->rx.rec_seq_size, 728 ctx->control); 729 730 ret = tls_do_decryption(sk, sgin, sgout, iv, 731 rxm->full_len - tls_ctx->rx.overhead_size, 732 skb, sk->sk_allocation); 733 734 out: 735 if (sgin != &sgin_arr[0]) 736 kfree(sgin); 737 738 return ret; 739 } 740 741 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, 742 unsigned int len) 743 { 744 struct tls_context *tls_ctx = tls_get_ctx(sk); 745 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 746 struct strp_msg *rxm = strp_msg(skb); 747 748 if (len < rxm->full_len) { 749 rxm->offset += len; 750 rxm->full_len -= len; 751 752 return false; 753 } 754 755 /* Finished with message */ 756 ctx->recv_pkt = NULL; 757 kfree_skb(skb); 758 __strp_unpause(&ctx->strp); 759 760 return true; 761 } 762 763 int tls_sw_recvmsg(struct sock *sk, 764 struct msghdr *msg, 765 size_t len, 766 int nonblock, 767 int flags, 768 int *addr_len) 769 { 770 struct tls_context *tls_ctx = tls_get_ctx(sk); 771 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 772 unsigned char control; 773 struct strp_msg *rxm; 774 struct sk_buff *skb; 775 ssize_t copied = 0; 776 bool cmsg = false; 777 int target, err = 0; 778 long timeo; 779 780 flags |= nonblock; 781 782 if (unlikely(flags & MSG_ERRQUEUE)) 783 return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR); 784 785 lock_sock(sk); 786 787 target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); 788 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); 789 do { 790 bool zc = false; 791 int chunk = 0; 792 793 skb = tls_wait_data(sk, flags, timeo, &err); 794 if (!skb) 795 goto recv_end; 796 797 rxm = strp_msg(skb); 798 if (!cmsg) { 799 int cerr; 800 801 cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE, 802 sizeof(ctx->control), &ctx->control); 803 cmsg = true; 804 control = ctx->control; 805 if (ctx->control != TLS_RECORD_TYPE_DATA) { 806 if (cerr || msg->msg_flags & MSG_CTRUNC) { 807 err = -EIO; 808 goto recv_end; 809 } 810 } 811 } else if (control != ctx->control) { 812 goto recv_end; 813 } 814 815 if (!ctx->decrypted) { 816 int page_count; 817 int to_copy; 818 819 page_count = iov_iter_npages(&msg->msg_iter, 820 MAX_SKB_FRAGS); 821 to_copy = rxm->full_len - tls_ctx->rx.overhead_size; 822 if (to_copy <= len && page_count < MAX_SKB_FRAGS && 823 likely(!(flags & MSG_PEEK))) { 824 struct scatterlist sgin[MAX_SKB_FRAGS + 1]; 825 int pages = 0; 826 827 zc = true; 828 sg_init_table(sgin, MAX_SKB_FRAGS + 1); 829 sg_set_buf(&sgin[0], ctx->rx_aad_plaintext, 830 TLS_AAD_SPACE_SIZE); 831 832 err = zerocopy_from_iter(sk, &msg->msg_iter, 833 to_copy, &pages, 834 &chunk, &sgin[1], 835 MAX_SKB_FRAGS, false, true); 836 if (err < 0) 837 goto fallback_to_reg_recv; 838 839 err = decrypt_skb_update(sk, skb, sgin, &zc); 840 for (; pages > 0; pages--) 841 put_page(sg_page(&sgin[pages])); 842 if (err < 0) { 843 tls_err_abort(sk, EBADMSG); 844 goto recv_end; 845 } 846 } else { 847 fallback_to_reg_recv: 848 err = decrypt_skb_update(sk, skb, NULL, &zc); 849 if (err < 0) { 850 tls_err_abort(sk, EBADMSG); 851 goto recv_end; 852 } 853 } 854 ctx->decrypted = true; 855 } 856 857 if (!zc) { 858 chunk = min_t(unsigned int, rxm->full_len, len); 859 err = skb_copy_datagram_msg(skb, rxm->offset, msg, 860 chunk); 861 if (err < 0) 862 goto recv_end; 863 } 864 865 copied += chunk; 866 len -= chunk; 867 if (likely(!(flags & MSG_PEEK))) { 868 u8 control = ctx->control; 869 870 if (tls_sw_advance_skb(sk, skb, chunk)) { 871 /* Return full control message to 872 * userspace before trying to parse 873 * another message type 874 */ 875 msg->msg_flags |= MSG_EOR; 876 if (control != TLS_RECORD_TYPE_DATA) 877 goto recv_end; 878 } 879 } 880 /* If we have a new message from strparser, continue now. */ 881 if (copied >= target && !ctx->recv_pkt) 882 break; 883 } while (len); 884 885 recv_end: 886 release_sock(sk); 887 return copied ? : err; 888 } 889 890 ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, 891 struct pipe_inode_info *pipe, 892 size_t len, unsigned int flags) 893 { 894 struct tls_context *tls_ctx = tls_get_ctx(sock->sk); 895 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 896 struct strp_msg *rxm = NULL; 897 struct sock *sk = sock->sk; 898 struct sk_buff *skb; 899 ssize_t copied = 0; 900 int err = 0; 901 long timeo; 902 int chunk; 903 bool zc; 904 905 lock_sock(sk); 906 907 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); 908 909 skb = tls_wait_data(sk, flags, timeo, &err); 910 if (!skb) 911 goto splice_read_end; 912 913 /* splice does not support reading control messages */ 914 if (ctx->control != TLS_RECORD_TYPE_DATA) { 915 err = -ENOTSUPP; 916 goto splice_read_end; 917 } 918 919 if (!ctx->decrypted) { 920 err = decrypt_skb_update(sk, skb, NULL, &zc); 921 922 if (err < 0) { 923 tls_err_abort(sk, EBADMSG); 924 goto splice_read_end; 925 } 926 ctx->decrypted = true; 927 } 928 rxm = strp_msg(skb); 929 930 chunk = min_t(unsigned int, rxm->full_len, len); 931 copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags); 932 if (copied < 0) 933 goto splice_read_end; 934 935 if (likely(!(flags & MSG_PEEK))) 936 tls_sw_advance_skb(sk, skb, copied); 937 938 splice_read_end: 939 release_sock(sk); 940 return copied ? : err; 941 } 942 943 unsigned int tls_sw_poll(struct file *file, struct socket *sock, 944 struct poll_table_struct *wait) 945 { 946 unsigned int ret; 947 struct sock *sk = sock->sk; 948 struct tls_context *tls_ctx = tls_get_ctx(sk); 949 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 950 951 /* Grab POLLOUT and POLLHUP from the underlying socket */ 952 ret = ctx->sk_poll(file, sock, wait); 953 954 /* Clear POLLIN bits, and set based on recv_pkt */ 955 ret &= ~(POLLIN | POLLRDNORM); 956 if (ctx->recv_pkt) 957 ret |= POLLIN | POLLRDNORM; 958 959 return ret; 960 } 961 962 static int tls_read_size(struct strparser *strp, struct sk_buff *skb) 963 { 964 struct tls_context *tls_ctx = tls_get_ctx(strp->sk); 965 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 966 char header[TLS_HEADER_SIZE + MAX_IV_SIZE]; 967 struct strp_msg *rxm = strp_msg(skb); 968 size_t cipher_overhead; 969 size_t data_len = 0; 970 int ret; 971 972 /* Verify that we have a full TLS header, or wait for more data */ 973 if (rxm->offset + tls_ctx->rx.prepend_size > skb->len) 974 return 0; 975 976 /* Sanity-check size of on-stack buffer. */ 977 if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) { 978 ret = -EINVAL; 979 goto read_failure; 980 } 981 982 /* Linearize header to local buffer */ 983 ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size); 984 985 if (ret < 0) 986 goto read_failure; 987 988 ctx->control = header[0]; 989 990 data_len = ((header[4] & 0xFF) | (header[3] << 8)); 991 992 cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size; 993 994 if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) { 995 ret = -EMSGSIZE; 996 goto read_failure; 997 } 998 if (data_len < cipher_overhead) { 999 ret = -EBADMSG; 1000 goto read_failure; 1001 } 1002 1003 if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.version) || 1004 header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.version)) { 1005 ret = -EINVAL; 1006 goto read_failure; 1007 } 1008 1009 #ifdef CONFIG_TLS_DEVICE 1010 handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset, 1011 *(u64*)tls_ctx->rx.rec_seq); 1012 #endif 1013 return data_len + TLS_HEADER_SIZE; 1014 1015 read_failure: 1016 tls_err_abort(strp->sk, ret); 1017 1018 return ret; 1019 } 1020 1021 static void tls_queue(struct strparser *strp, struct sk_buff *skb) 1022 { 1023 struct tls_context *tls_ctx = tls_get_ctx(strp->sk); 1024 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1025 1026 ctx->decrypted = false; 1027 1028 ctx->recv_pkt = skb; 1029 strp_pause(strp); 1030 1031 strp->sk->sk_state_change(strp->sk); 1032 } 1033 1034 static void tls_data_ready(struct sock *sk) 1035 { 1036 struct tls_context *tls_ctx = tls_get_ctx(sk); 1037 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1038 1039 strp_data_ready(&ctx->strp); 1040 } 1041 1042 void tls_sw_free_resources_tx(struct sock *sk) 1043 { 1044 struct tls_context *tls_ctx = tls_get_ctx(sk); 1045 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 1046 1047 if (ctx->aead_send) 1048 crypto_free_aead(ctx->aead_send); 1049 tls_free_both_sg(sk); 1050 1051 kfree(ctx); 1052 } 1053 1054 void tls_sw_release_resources_rx(struct sock *sk) 1055 { 1056 struct tls_context *tls_ctx = tls_get_ctx(sk); 1057 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1058 1059 if (ctx->aead_recv) { 1060 if (ctx->recv_pkt) { 1061 kfree_skb(ctx->recv_pkt); 1062 ctx->recv_pkt = NULL; 1063 } 1064 crypto_free_aead(ctx->aead_recv); 1065 strp_stop(&ctx->strp); 1066 write_lock_bh(&sk->sk_callback_lock); 1067 sk->sk_data_ready = ctx->saved_data_ready; 1068 write_unlock_bh(&sk->sk_callback_lock); 1069 release_sock(sk); 1070 strp_done(&ctx->strp); 1071 lock_sock(sk); 1072 } 1073 } 1074 1075 void tls_sw_free_resources_rx(struct sock *sk) 1076 { 1077 struct tls_context *tls_ctx = tls_get_ctx(sk); 1078 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1079 1080 tls_sw_release_resources_rx(sk); 1081 1082 kfree(ctx); 1083 } 1084 1085 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) 1086 { 1087 char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE]; 1088 struct tls_crypto_info *crypto_info; 1089 struct tls12_crypto_info_aes_gcm_128 *gcm_128_info; 1090 struct tls_sw_context_tx *sw_ctx_tx = NULL; 1091 struct tls_sw_context_rx *sw_ctx_rx = NULL; 1092 struct cipher_context *cctx; 1093 struct crypto_aead **aead; 1094 struct strp_callbacks cb; 1095 u16 nonce_size, tag_size, iv_size, rec_seq_size; 1096 char *iv, *rec_seq; 1097 int rc = 0; 1098 1099 if (!ctx) { 1100 rc = -EINVAL; 1101 goto out; 1102 } 1103 1104 if (tx) { 1105 if (!ctx->priv_ctx_tx) { 1106 sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL); 1107 if (!sw_ctx_tx) { 1108 rc = -ENOMEM; 1109 goto out; 1110 } 1111 ctx->priv_ctx_tx = sw_ctx_tx; 1112 } else { 1113 sw_ctx_tx = 1114 (struct tls_sw_context_tx *)ctx->priv_ctx_tx; 1115 } 1116 } else { 1117 if (!ctx->priv_ctx_rx) { 1118 sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL); 1119 if (!sw_ctx_rx) { 1120 rc = -ENOMEM; 1121 goto out; 1122 } 1123 ctx->priv_ctx_rx = sw_ctx_rx; 1124 } else { 1125 sw_ctx_rx = 1126 (struct tls_sw_context_rx *)ctx->priv_ctx_rx; 1127 } 1128 } 1129 1130 if (tx) { 1131 crypto_init_wait(&sw_ctx_tx->async_wait); 1132 crypto_info = &ctx->crypto_send; 1133 cctx = &ctx->tx; 1134 aead = &sw_ctx_tx->aead_send; 1135 } else { 1136 crypto_init_wait(&sw_ctx_rx->async_wait); 1137 crypto_info = &ctx->crypto_recv; 1138 cctx = &ctx->rx; 1139 aead = &sw_ctx_rx->aead_recv; 1140 } 1141 1142 switch (crypto_info->cipher_type) { 1143 case TLS_CIPHER_AES_GCM_128: { 1144 nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; 1145 tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE; 1146 iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; 1147 iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv; 1148 rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE; 1149 rec_seq = 1150 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq; 1151 gcm_128_info = 1152 (struct tls12_crypto_info_aes_gcm_128 *)crypto_info; 1153 break; 1154 } 1155 default: 1156 rc = -EINVAL; 1157 goto free_priv; 1158 } 1159 1160 /* Sanity-check the IV size for stack allocations. */ 1161 if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) { 1162 rc = -EINVAL; 1163 goto free_priv; 1164 } 1165 1166 cctx->prepend_size = TLS_HEADER_SIZE + nonce_size; 1167 cctx->tag_size = tag_size; 1168 cctx->overhead_size = cctx->prepend_size + cctx->tag_size; 1169 cctx->iv_size = iv_size; 1170 cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, 1171 GFP_KERNEL); 1172 if (!cctx->iv) { 1173 rc = -ENOMEM; 1174 goto free_priv; 1175 } 1176 memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE); 1177 memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); 1178 cctx->rec_seq_size = rec_seq_size; 1179 cctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL); 1180 if (!cctx->rec_seq) { 1181 rc = -ENOMEM; 1182 goto free_iv; 1183 } 1184 memcpy(cctx->rec_seq, rec_seq, rec_seq_size); 1185 1186 if (sw_ctx_tx) { 1187 sg_init_table(sw_ctx_tx->sg_encrypted_data, 1188 ARRAY_SIZE(sw_ctx_tx->sg_encrypted_data)); 1189 sg_init_table(sw_ctx_tx->sg_plaintext_data, 1190 ARRAY_SIZE(sw_ctx_tx->sg_plaintext_data)); 1191 1192 sg_init_table(sw_ctx_tx->sg_aead_in, 2); 1193 sg_set_buf(&sw_ctx_tx->sg_aead_in[0], sw_ctx_tx->aad_space, 1194 sizeof(sw_ctx_tx->aad_space)); 1195 sg_unmark_end(&sw_ctx_tx->sg_aead_in[1]); 1196 sg_chain(sw_ctx_tx->sg_aead_in, 2, 1197 sw_ctx_tx->sg_plaintext_data); 1198 sg_init_table(sw_ctx_tx->sg_aead_out, 2); 1199 sg_set_buf(&sw_ctx_tx->sg_aead_out[0], sw_ctx_tx->aad_space, 1200 sizeof(sw_ctx_tx->aad_space)); 1201 sg_unmark_end(&sw_ctx_tx->sg_aead_out[1]); 1202 sg_chain(sw_ctx_tx->sg_aead_out, 2, 1203 sw_ctx_tx->sg_encrypted_data); 1204 } 1205 1206 if (!*aead) { 1207 *aead = crypto_alloc_aead("gcm(aes)", 0, 0); 1208 if (IS_ERR(*aead)) { 1209 rc = PTR_ERR(*aead); 1210 *aead = NULL; 1211 goto free_rec_seq; 1212 } 1213 } 1214 1215 ctx->push_pending_record = tls_sw_push_pending_record; 1216 1217 memcpy(keyval, gcm_128_info->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE); 1218 1219 rc = crypto_aead_setkey(*aead, keyval, 1220 TLS_CIPHER_AES_GCM_128_KEY_SIZE); 1221 if (rc) 1222 goto free_aead; 1223 1224 rc = crypto_aead_setauthsize(*aead, cctx->tag_size); 1225 if (rc) 1226 goto free_aead; 1227 1228 if (sw_ctx_rx) { 1229 /* Set up strparser */ 1230 memset(&cb, 0, sizeof(cb)); 1231 cb.rcv_msg = tls_queue; 1232 cb.parse_msg = tls_read_size; 1233 1234 strp_init(&sw_ctx_rx->strp, sk, &cb); 1235 1236 write_lock_bh(&sk->sk_callback_lock); 1237 sw_ctx_rx->saved_data_ready = sk->sk_data_ready; 1238 sk->sk_data_ready = tls_data_ready; 1239 write_unlock_bh(&sk->sk_callback_lock); 1240 1241 sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll; 1242 1243 strp_check_rcv(&sw_ctx_rx->strp); 1244 } 1245 1246 goto out; 1247 1248 free_aead: 1249 crypto_free_aead(*aead); 1250 *aead = NULL; 1251 free_rec_seq: 1252 kfree(cctx->rec_seq); 1253 cctx->rec_seq = NULL; 1254 free_iv: 1255 kfree(cctx->iv); 1256 cctx->iv = NULL; 1257 free_priv: 1258 if (tx) { 1259 kfree(ctx->priv_ctx_tx); 1260 ctx->priv_ctx_tx = NULL; 1261 } else { 1262 kfree(ctx->priv_ctx_rx); 1263 ctx->priv_ctx_rx = NULL; 1264 } 1265 out: 1266 return rc; 1267 } 1268