1 /* 2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. 3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved. 4 * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved. 5 * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved. 6 * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved. 7 * 8 * This software is available to you under a choice of one of two 9 * licenses. You may choose to be licensed under the terms of the GNU 10 * General Public License (GPL) Version 2, available from the file 11 * COPYING in the main directory of this source tree, or the 12 * OpenIB.org BSD license below: 13 * 14 * Redistribution and use in source and binary forms, with or 15 * without modification, are permitted provided that the following 16 * conditions are met: 17 * 18 * - Redistributions of source code must retain the above 19 * copyright notice, this list of conditions and the following 20 * disclaimer. 21 * 22 * - Redistributions in binary form must reproduce the above 23 * copyright notice, this list of conditions and the following 24 * disclaimer in the documentation and/or other materials 25 * provided with the distribution. 26 * 27 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 28 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 29 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 30 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS 31 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 32 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 33 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 34 * SOFTWARE. 35 */ 36 37 #include <linux/sched/signal.h> 38 #include <linux/module.h> 39 #include <crypto/aead.h> 40 41 #include <net/strparser.h> 42 #include <net/tls.h> 43 44 static int tls_do_decryption(struct sock *sk, 45 struct scatterlist *sgin, 46 struct scatterlist *sgout, 47 char *iv_recv, 48 size_t data_len, 49 struct sk_buff *skb, 50 gfp_t flags) 51 { 52 struct tls_context *tls_ctx = tls_get_ctx(sk); 53 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 54 struct strp_msg *rxm = strp_msg(skb); 55 struct aead_request *aead_req; 56 57 int ret; 58 unsigned int req_size = sizeof(struct aead_request) + 59 crypto_aead_reqsize(ctx->aead_recv); 60 61 aead_req = kzalloc(req_size, flags); 62 if (!aead_req) 63 return -ENOMEM; 64 65 aead_request_set_tfm(aead_req, ctx->aead_recv); 66 aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); 67 aead_request_set_crypt(aead_req, sgin, sgout, 68 data_len + tls_ctx->rx.tag_size, 69 (u8 *)iv_recv); 70 aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, 71 crypto_req_done, &ctx->async_wait); 72 73 ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait); 74 75 if (ret < 0) 76 goto out; 77 78 rxm->offset += tls_ctx->rx.prepend_size; 79 rxm->full_len -= tls_ctx->rx.overhead_size; 80 tls_advance_record_sn(sk, &tls_ctx->rx); 81 82 ctx->decrypted = true; 83 84 ctx->saved_data_ready(sk); 85 86 out: 87 kfree(aead_req); 88 return ret; 89 } 90 91 static void trim_sg(struct sock *sk, struct scatterlist *sg, 92 int *sg_num_elem, unsigned int *sg_size, int target_size) 93 { 94 int i = *sg_num_elem - 1; 95 int trim = *sg_size - target_size; 96 97 if (trim <= 0) { 98 WARN_ON(trim < 0); 99 return; 100 } 101 102 *sg_size = target_size; 103 while (trim >= sg[i].length) { 104 trim -= sg[i].length; 105 sk_mem_uncharge(sk, sg[i].length); 106 put_page(sg_page(&sg[i])); 107 i--; 108 109 if (i < 0) 110 goto out; 111 } 112 113 sg[i].length -= trim; 114 sk_mem_uncharge(sk, trim); 115 116 out: 117 *sg_num_elem = i + 1; 118 } 119 120 static void trim_both_sgl(struct sock *sk, int target_size) 121 { 122 struct tls_context *tls_ctx = tls_get_ctx(sk); 123 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 124 125 trim_sg(sk, ctx->sg_plaintext_data, 126 &ctx->sg_plaintext_num_elem, 127 &ctx->sg_plaintext_size, 128 target_size); 129 130 if (target_size > 0) 131 target_size += tls_ctx->tx.overhead_size; 132 133 trim_sg(sk, ctx->sg_encrypted_data, 134 &ctx->sg_encrypted_num_elem, 135 &ctx->sg_encrypted_size, 136 target_size); 137 } 138 139 static int alloc_encrypted_sg(struct sock *sk, int len) 140 { 141 struct tls_context *tls_ctx = tls_get_ctx(sk); 142 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 143 int rc = 0; 144 145 rc = sk_alloc_sg(sk, len, 146 ctx->sg_encrypted_data, 0, 147 &ctx->sg_encrypted_num_elem, 148 &ctx->sg_encrypted_size, 0); 149 150 return rc; 151 } 152 153 static int alloc_plaintext_sg(struct sock *sk, int len) 154 { 155 struct tls_context *tls_ctx = tls_get_ctx(sk); 156 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 157 int rc = 0; 158 159 rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0, 160 &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size, 161 tls_ctx->pending_open_record_frags); 162 163 return rc; 164 } 165 166 static void free_sg(struct sock *sk, struct scatterlist *sg, 167 int *sg_num_elem, unsigned int *sg_size) 168 { 169 int i, n = *sg_num_elem; 170 171 for (i = 0; i < n; ++i) { 172 sk_mem_uncharge(sk, sg[i].length); 173 put_page(sg_page(&sg[i])); 174 } 175 *sg_num_elem = 0; 176 *sg_size = 0; 177 } 178 179 static void tls_free_both_sg(struct sock *sk) 180 { 181 struct tls_context *tls_ctx = tls_get_ctx(sk); 182 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 183 184 free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem, 185 &ctx->sg_encrypted_size); 186 187 free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem, 188 &ctx->sg_plaintext_size); 189 } 190 191 static int tls_do_encryption(struct tls_context *tls_ctx, 192 struct tls_sw_context *ctx, size_t data_len, 193 gfp_t flags) 194 { 195 unsigned int req_size = sizeof(struct aead_request) + 196 crypto_aead_reqsize(ctx->aead_send); 197 struct aead_request *aead_req; 198 int rc; 199 200 aead_req = kzalloc(req_size, flags); 201 if (!aead_req) 202 return -ENOMEM; 203 204 ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size; 205 ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size; 206 207 aead_request_set_tfm(aead_req, ctx->aead_send); 208 aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); 209 aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out, 210 data_len, tls_ctx->tx.iv); 211 212 aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, 213 crypto_req_done, &ctx->async_wait); 214 215 rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait); 216 217 ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size; 218 ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size; 219 220 kfree(aead_req); 221 return rc; 222 } 223 224 static int tls_push_record(struct sock *sk, int flags, 225 unsigned char record_type) 226 { 227 struct tls_context *tls_ctx = tls_get_ctx(sk); 228 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 229 int rc; 230 231 sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1); 232 sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1); 233 234 tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size, 235 tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size, 236 record_type); 237 238 tls_fill_prepend(tls_ctx, 239 page_address(sg_page(&ctx->sg_encrypted_data[0])) + 240 ctx->sg_encrypted_data[0].offset, 241 ctx->sg_plaintext_size, record_type); 242 243 tls_ctx->pending_open_record_frags = 0; 244 set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags); 245 246 rc = tls_do_encryption(tls_ctx, ctx, ctx->sg_plaintext_size, 247 sk->sk_allocation); 248 if (rc < 0) { 249 /* If we are called from write_space and 250 * we fail, we need to set this SOCK_NOSPACE 251 * to trigger another write_space in the future. 252 */ 253 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); 254 return rc; 255 } 256 257 free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem, 258 &ctx->sg_plaintext_size); 259 260 ctx->sg_encrypted_num_elem = 0; 261 ctx->sg_encrypted_size = 0; 262 263 /* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */ 264 rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags); 265 if (rc < 0 && rc != -EAGAIN) 266 tls_err_abort(sk, EBADMSG); 267 268 tls_advance_record_sn(sk, &tls_ctx->tx); 269 return rc; 270 } 271 272 static int tls_sw_push_pending_record(struct sock *sk, int flags) 273 { 274 return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA); 275 } 276 277 static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from, 278 int length, int *pages_used, 279 unsigned int *size_used, 280 struct scatterlist *to, int to_max_pages, 281 bool charge) 282 { 283 struct page *pages[MAX_SKB_FRAGS]; 284 285 size_t offset; 286 ssize_t copied, use; 287 int i = 0; 288 unsigned int size = *size_used; 289 int num_elem = *pages_used; 290 int rc = 0; 291 int maxpages; 292 293 while (length > 0) { 294 i = 0; 295 maxpages = to_max_pages - num_elem; 296 if (maxpages == 0) { 297 rc = -EFAULT; 298 goto out; 299 } 300 copied = iov_iter_get_pages(from, pages, 301 length, 302 maxpages, &offset); 303 if (copied <= 0) { 304 rc = -EFAULT; 305 goto out; 306 } 307 308 iov_iter_advance(from, copied); 309 310 length -= copied; 311 size += copied; 312 while (copied) { 313 use = min_t(int, copied, PAGE_SIZE - offset); 314 315 sg_set_page(&to[num_elem], 316 pages[i], use, offset); 317 sg_unmark_end(&to[num_elem]); 318 if (charge) 319 sk_mem_charge(sk, use); 320 321 offset = 0; 322 copied -= use; 323 324 ++i; 325 ++num_elem; 326 } 327 } 328 329 out: 330 *size_used = size; 331 *pages_used = num_elem; 332 333 return rc; 334 } 335 336 static int memcopy_from_iter(struct sock *sk, struct iov_iter *from, 337 int bytes) 338 { 339 struct tls_context *tls_ctx = tls_get_ctx(sk); 340 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 341 struct scatterlist *sg = ctx->sg_plaintext_data; 342 int copy, i, rc = 0; 343 344 for (i = tls_ctx->pending_open_record_frags; 345 i < ctx->sg_plaintext_num_elem; ++i) { 346 copy = sg[i].length; 347 if (copy_from_iter( 348 page_address(sg_page(&sg[i])) + sg[i].offset, 349 copy, from) != copy) { 350 rc = -EFAULT; 351 goto out; 352 } 353 bytes -= copy; 354 355 ++tls_ctx->pending_open_record_frags; 356 357 if (!bytes) 358 break; 359 } 360 361 out: 362 return rc; 363 } 364 365 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) 366 { 367 struct tls_context *tls_ctx = tls_get_ctx(sk); 368 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 369 int ret = 0; 370 int required_size; 371 long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); 372 bool eor = !(msg->msg_flags & MSG_MORE); 373 size_t try_to_copy, copied = 0; 374 unsigned char record_type = TLS_RECORD_TYPE_DATA; 375 int record_room; 376 bool full_record; 377 int orig_size; 378 379 if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL)) 380 return -ENOTSUPP; 381 382 lock_sock(sk); 383 384 if (tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo)) 385 goto send_end; 386 387 if (unlikely(msg->msg_controllen)) { 388 ret = tls_proccess_cmsg(sk, msg, &record_type); 389 if (ret) 390 goto send_end; 391 } 392 393 while (msg_data_left(msg)) { 394 if (sk->sk_err) { 395 ret = -sk->sk_err; 396 goto send_end; 397 } 398 399 orig_size = ctx->sg_plaintext_size; 400 full_record = false; 401 try_to_copy = msg_data_left(msg); 402 record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size; 403 if (try_to_copy >= record_room) { 404 try_to_copy = record_room; 405 full_record = true; 406 } 407 408 required_size = ctx->sg_plaintext_size + try_to_copy + 409 tls_ctx->tx.overhead_size; 410 411 if (!sk_stream_memory_free(sk)) 412 goto wait_for_sndbuf; 413 alloc_encrypted: 414 ret = alloc_encrypted_sg(sk, required_size); 415 if (ret) { 416 if (ret != -ENOSPC) 417 goto wait_for_memory; 418 419 /* Adjust try_to_copy according to the amount that was 420 * actually allocated. The difference is due 421 * to max sg elements limit 422 */ 423 try_to_copy -= required_size - ctx->sg_encrypted_size; 424 full_record = true; 425 } 426 427 if (full_record || eor) { 428 ret = zerocopy_from_iter(sk, &msg->msg_iter, 429 try_to_copy, &ctx->sg_plaintext_num_elem, 430 &ctx->sg_plaintext_size, 431 ctx->sg_plaintext_data, 432 ARRAY_SIZE(ctx->sg_plaintext_data), 433 true); 434 if (ret) 435 goto fallback_to_reg_send; 436 437 copied += try_to_copy; 438 ret = tls_push_record(sk, msg->msg_flags, record_type); 439 if (!ret) 440 continue; 441 if (ret == -EAGAIN) 442 goto send_end; 443 444 copied -= try_to_copy; 445 fallback_to_reg_send: 446 iov_iter_revert(&msg->msg_iter, 447 ctx->sg_plaintext_size - orig_size); 448 trim_sg(sk, ctx->sg_plaintext_data, 449 &ctx->sg_plaintext_num_elem, 450 &ctx->sg_plaintext_size, 451 orig_size); 452 } 453 454 required_size = ctx->sg_plaintext_size + try_to_copy; 455 alloc_plaintext: 456 ret = alloc_plaintext_sg(sk, required_size); 457 if (ret) { 458 if (ret != -ENOSPC) 459 goto wait_for_memory; 460 461 /* Adjust try_to_copy according to the amount that was 462 * actually allocated. The difference is due 463 * to max sg elements limit 464 */ 465 try_to_copy -= required_size - ctx->sg_plaintext_size; 466 full_record = true; 467 468 trim_sg(sk, ctx->sg_encrypted_data, 469 &ctx->sg_encrypted_num_elem, 470 &ctx->sg_encrypted_size, 471 ctx->sg_plaintext_size + 472 tls_ctx->tx.overhead_size); 473 } 474 475 ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy); 476 if (ret) 477 goto trim_sgl; 478 479 copied += try_to_copy; 480 if (full_record || eor) { 481 push_record: 482 ret = tls_push_record(sk, msg->msg_flags, record_type); 483 if (ret) { 484 if (ret == -ENOMEM) 485 goto wait_for_memory; 486 487 goto send_end; 488 } 489 } 490 491 continue; 492 493 wait_for_sndbuf: 494 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); 495 wait_for_memory: 496 ret = sk_stream_wait_memory(sk, &timeo); 497 if (ret) { 498 trim_sgl: 499 trim_both_sgl(sk, orig_size); 500 goto send_end; 501 } 502 503 if (tls_is_pending_closed_record(tls_ctx)) 504 goto push_record; 505 506 if (ctx->sg_encrypted_size < required_size) 507 goto alloc_encrypted; 508 509 goto alloc_plaintext; 510 } 511 512 send_end: 513 ret = sk_stream_error(sk, msg->msg_flags, ret); 514 515 release_sock(sk); 516 return copied ? copied : ret; 517 } 518 519 int tls_sw_sendpage(struct sock *sk, struct page *page, 520 int offset, size_t size, int flags) 521 { 522 struct tls_context *tls_ctx = tls_get_ctx(sk); 523 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 524 int ret = 0; 525 long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); 526 bool eor; 527 size_t orig_size = size; 528 unsigned char record_type = TLS_RECORD_TYPE_DATA; 529 struct scatterlist *sg; 530 bool full_record; 531 int record_room; 532 533 if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | 534 MSG_SENDPAGE_NOTLAST)) 535 return -ENOTSUPP; 536 537 /* No MSG_EOR from splice, only look at MSG_MORE */ 538 eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST)); 539 540 lock_sock(sk); 541 542 sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); 543 544 if (tls_complete_pending_work(sk, tls_ctx, flags, &timeo)) 545 goto sendpage_end; 546 547 /* Call the sk_stream functions to manage the sndbuf mem. */ 548 while (size > 0) { 549 size_t copy, required_size; 550 551 if (sk->sk_err) { 552 ret = -sk->sk_err; 553 goto sendpage_end; 554 } 555 556 full_record = false; 557 record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size; 558 copy = size; 559 if (copy >= record_room) { 560 copy = record_room; 561 full_record = true; 562 } 563 required_size = ctx->sg_plaintext_size + copy + 564 tls_ctx->tx.overhead_size; 565 566 if (!sk_stream_memory_free(sk)) 567 goto wait_for_sndbuf; 568 alloc_payload: 569 ret = alloc_encrypted_sg(sk, required_size); 570 if (ret) { 571 if (ret != -ENOSPC) 572 goto wait_for_memory; 573 574 /* Adjust copy according to the amount that was 575 * actually allocated. The difference is due 576 * to max sg elements limit 577 */ 578 copy -= required_size - ctx->sg_plaintext_size; 579 full_record = true; 580 } 581 582 get_page(page); 583 sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem; 584 sg_set_page(sg, page, copy, offset); 585 sg_unmark_end(sg); 586 587 ctx->sg_plaintext_num_elem++; 588 589 sk_mem_charge(sk, copy); 590 offset += copy; 591 size -= copy; 592 ctx->sg_plaintext_size += copy; 593 tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem; 594 595 if (full_record || eor || 596 ctx->sg_plaintext_num_elem == 597 ARRAY_SIZE(ctx->sg_plaintext_data)) { 598 push_record: 599 ret = tls_push_record(sk, flags, record_type); 600 if (ret) { 601 if (ret == -ENOMEM) 602 goto wait_for_memory; 603 604 goto sendpage_end; 605 } 606 } 607 continue; 608 wait_for_sndbuf: 609 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); 610 wait_for_memory: 611 ret = sk_stream_wait_memory(sk, &timeo); 612 if (ret) { 613 trim_both_sgl(sk, ctx->sg_plaintext_size); 614 goto sendpage_end; 615 } 616 617 if (tls_is_pending_closed_record(tls_ctx)) 618 goto push_record; 619 620 goto alloc_payload; 621 } 622 623 sendpage_end: 624 if (orig_size > size) 625 ret = orig_size - size; 626 else 627 ret = sk_stream_error(sk, flags, ret); 628 629 release_sock(sk); 630 return ret; 631 } 632 633 static struct sk_buff *tls_wait_data(struct sock *sk, int flags, 634 long timeo, int *err) 635 { 636 struct tls_context *tls_ctx = tls_get_ctx(sk); 637 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 638 struct sk_buff *skb; 639 DEFINE_WAIT_FUNC(wait, woken_wake_function); 640 641 while (!(skb = ctx->recv_pkt)) { 642 if (sk->sk_err) { 643 *err = sock_error(sk); 644 return NULL; 645 } 646 647 if (sock_flag(sk, SOCK_DONE)) 648 return NULL; 649 650 if ((flags & MSG_DONTWAIT) || !timeo) { 651 *err = -EAGAIN; 652 return NULL; 653 } 654 655 add_wait_queue(sk_sleep(sk), &wait); 656 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); 657 sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait); 658 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); 659 remove_wait_queue(sk_sleep(sk), &wait); 660 661 /* Handle signals */ 662 if (signal_pending(current)) { 663 *err = sock_intr_errno(timeo); 664 return NULL; 665 } 666 } 667 668 return skb; 669 } 670 671 static int decrypt_skb(struct sock *sk, struct sk_buff *skb, 672 struct scatterlist *sgout) 673 { 674 struct tls_context *tls_ctx = tls_get_ctx(sk); 675 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 676 char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + tls_ctx->rx.iv_size]; 677 struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2]; 678 struct scatterlist *sgin = &sgin_arr[0]; 679 struct strp_msg *rxm = strp_msg(skb); 680 int ret, nsg = ARRAY_SIZE(sgin_arr); 681 char aad_recv[TLS_AAD_SPACE_SIZE]; 682 struct sk_buff *unused; 683 684 ret = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, 685 iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, 686 tls_ctx->rx.iv_size); 687 if (ret < 0) 688 return ret; 689 690 memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE); 691 if (!sgout) { 692 nsg = skb_cow_data(skb, 0, &unused) + 1; 693 sgin = kmalloc_array(nsg, sizeof(*sgin), sk->sk_allocation); 694 if (!sgout) 695 sgout = sgin; 696 } 697 698 sg_init_table(sgin, nsg); 699 sg_set_buf(&sgin[0], aad_recv, sizeof(aad_recv)); 700 701 nsg = skb_to_sgvec(skb, &sgin[1], 702 rxm->offset + tls_ctx->rx.prepend_size, 703 rxm->full_len - tls_ctx->rx.prepend_size); 704 705 tls_make_aad(aad_recv, 706 rxm->full_len - tls_ctx->rx.overhead_size, 707 tls_ctx->rx.rec_seq, 708 tls_ctx->rx.rec_seq_size, 709 ctx->control); 710 711 ret = tls_do_decryption(sk, sgin, sgout, iv, 712 rxm->full_len - tls_ctx->rx.overhead_size, 713 skb, sk->sk_allocation); 714 715 if (sgin != &sgin_arr[0]) 716 kfree(sgin); 717 718 return ret; 719 } 720 721 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb, 722 unsigned int len) 723 { 724 struct tls_context *tls_ctx = tls_get_ctx(sk); 725 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 726 struct strp_msg *rxm = strp_msg(skb); 727 728 if (len < rxm->full_len) { 729 rxm->offset += len; 730 rxm->full_len -= len; 731 732 return false; 733 } 734 735 /* Finished with message */ 736 ctx->recv_pkt = NULL; 737 kfree_skb(skb); 738 strp_unpause(&ctx->strp); 739 740 return true; 741 } 742 743 int tls_sw_recvmsg(struct sock *sk, 744 struct msghdr *msg, 745 size_t len, 746 int nonblock, 747 int flags, 748 int *addr_len) 749 { 750 struct tls_context *tls_ctx = tls_get_ctx(sk); 751 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 752 unsigned char control; 753 struct strp_msg *rxm; 754 struct sk_buff *skb; 755 ssize_t copied = 0; 756 bool cmsg = false; 757 int err = 0; 758 long timeo; 759 760 flags |= nonblock; 761 762 if (unlikely(flags & MSG_ERRQUEUE)) 763 return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR); 764 765 lock_sock(sk); 766 767 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); 768 do { 769 bool zc = false; 770 int chunk = 0; 771 772 skb = tls_wait_data(sk, flags, timeo, &err); 773 if (!skb) 774 goto recv_end; 775 776 rxm = strp_msg(skb); 777 if (!cmsg) { 778 int cerr; 779 780 cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE, 781 sizeof(ctx->control), &ctx->control); 782 cmsg = true; 783 control = ctx->control; 784 if (ctx->control != TLS_RECORD_TYPE_DATA) { 785 if (cerr || msg->msg_flags & MSG_CTRUNC) { 786 err = -EIO; 787 goto recv_end; 788 } 789 } 790 } else if (control != ctx->control) { 791 goto recv_end; 792 } 793 794 if (!ctx->decrypted) { 795 int page_count; 796 int to_copy; 797 798 page_count = iov_iter_npages(&msg->msg_iter, 799 MAX_SKB_FRAGS); 800 to_copy = rxm->full_len - tls_ctx->rx.overhead_size; 801 if (to_copy <= len && page_count < MAX_SKB_FRAGS && 802 likely(!(flags & MSG_PEEK))) { 803 struct scatterlist sgin[MAX_SKB_FRAGS + 1]; 804 char unused[21]; 805 int pages = 0; 806 807 zc = true; 808 sg_init_table(sgin, MAX_SKB_FRAGS + 1); 809 sg_set_buf(&sgin[0], unused, 13); 810 811 err = zerocopy_from_iter(sk, &msg->msg_iter, 812 to_copy, &pages, 813 &chunk, &sgin[1], 814 MAX_SKB_FRAGS, false); 815 if (err < 0) 816 goto fallback_to_reg_recv; 817 818 err = decrypt_skb(sk, skb, sgin); 819 for (; pages > 0; pages--) 820 put_page(sg_page(&sgin[pages])); 821 if (err < 0) { 822 tls_err_abort(sk, EBADMSG); 823 goto recv_end; 824 } 825 } else { 826 fallback_to_reg_recv: 827 err = decrypt_skb(sk, skb, NULL); 828 if (err < 0) { 829 tls_err_abort(sk, EBADMSG); 830 goto recv_end; 831 } 832 } 833 ctx->decrypted = true; 834 } 835 836 if (!zc) { 837 chunk = min_t(unsigned int, rxm->full_len, len); 838 err = skb_copy_datagram_msg(skb, rxm->offset, msg, 839 chunk); 840 if (err < 0) 841 goto recv_end; 842 } 843 844 copied += chunk; 845 len -= chunk; 846 if (likely(!(flags & MSG_PEEK))) { 847 u8 control = ctx->control; 848 849 if (tls_sw_advance_skb(sk, skb, chunk)) { 850 /* Return full control message to 851 * userspace before trying to parse 852 * another message type 853 */ 854 msg->msg_flags |= MSG_EOR; 855 if (control != TLS_RECORD_TYPE_DATA) 856 goto recv_end; 857 } 858 } 859 } while (len); 860 861 recv_end: 862 release_sock(sk); 863 return copied ? : err; 864 } 865 866 ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, 867 struct pipe_inode_info *pipe, 868 size_t len, unsigned int flags) 869 { 870 struct tls_context *tls_ctx = tls_get_ctx(sock->sk); 871 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 872 struct strp_msg *rxm = NULL; 873 struct sock *sk = sock->sk; 874 struct sk_buff *skb; 875 ssize_t copied = 0; 876 int err = 0; 877 long timeo; 878 int chunk; 879 880 lock_sock(sk); 881 882 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); 883 884 skb = tls_wait_data(sk, flags, timeo, &err); 885 if (!skb) 886 goto splice_read_end; 887 888 /* splice does not support reading control messages */ 889 if (ctx->control != TLS_RECORD_TYPE_DATA) { 890 err = -ENOTSUPP; 891 goto splice_read_end; 892 } 893 894 if (!ctx->decrypted) { 895 err = decrypt_skb(sk, skb, NULL); 896 897 if (err < 0) { 898 tls_err_abort(sk, EBADMSG); 899 goto splice_read_end; 900 } 901 ctx->decrypted = true; 902 } 903 rxm = strp_msg(skb); 904 905 chunk = min_t(unsigned int, rxm->full_len, len); 906 copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags); 907 if (copied < 0) 908 goto splice_read_end; 909 910 if (likely(!(flags & MSG_PEEK))) 911 tls_sw_advance_skb(sk, skb, copied); 912 913 splice_read_end: 914 release_sock(sk); 915 return copied ? : err; 916 } 917 918 unsigned int tls_sw_poll(struct file *file, struct socket *sock, 919 struct poll_table_struct *wait) 920 { 921 unsigned int ret; 922 struct sock *sk = sock->sk; 923 struct tls_context *tls_ctx = tls_get_ctx(sk); 924 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 925 926 /* Grab POLLOUT and POLLHUP from the underlying socket */ 927 ret = ctx->sk_poll(file, sock, wait); 928 929 /* Clear POLLIN bits, and set based on recv_pkt */ 930 ret &= ~(POLLIN | POLLRDNORM); 931 if (ctx->recv_pkt) 932 ret |= POLLIN | POLLRDNORM; 933 934 return ret; 935 } 936 937 static int tls_read_size(struct strparser *strp, struct sk_buff *skb) 938 { 939 struct tls_context *tls_ctx = tls_get_ctx(strp->sk); 940 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 941 char header[tls_ctx->rx.prepend_size]; 942 struct strp_msg *rxm = strp_msg(skb); 943 size_t cipher_overhead; 944 size_t data_len = 0; 945 int ret; 946 947 /* Verify that we have a full TLS header, or wait for more data */ 948 if (rxm->offset + tls_ctx->rx.prepend_size > skb->len) 949 return 0; 950 951 /* Linearize header to local buffer */ 952 ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size); 953 954 if (ret < 0) 955 goto read_failure; 956 957 ctx->control = header[0]; 958 959 data_len = ((header[4] & 0xFF) | (header[3] << 8)); 960 961 cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size; 962 963 if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) { 964 ret = -EMSGSIZE; 965 goto read_failure; 966 } 967 if (data_len < cipher_overhead) { 968 ret = -EBADMSG; 969 goto read_failure; 970 } 971 972 if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.version) || 973 header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.version)) { 974 ret = -EINVAL; 975 goto read_failure; 976 } 977 978 return data_len + TLS_HEADER_SIZE; 979 980 read_failure: 981 tls_err_abort(strp->sk, ret); 982 983 return ret; 984 } 985 986 static void tls_queue(struct strparser *strp, struct sk_buff *skb) 987 { 988 struct tls_context *tls_ctx = tls_get_ctx(strp->sk); 989 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 990 struct strp_msg *rxm; 991 992 rxm = strp_msg(skb); 993 994 ctx->decrypted = false; 995 996 ctx->recv_pkt = skb; 997 strp_pause(strp); 998 999 strp->sk->sk_state_change(strp->sk); 1000 } 1001 1002 static void tls_data_ready(struct sock *sk) 1003 { 1004 struct tls_context *tls_ctx = tls_get_ctx(sk); 1005 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 1006 1007 strp_data_ready(&ctx->strp); 1008 } 1009 1010 void tls_sw_free_resources(struct sock *sk) 1011 { 1012 struct tls_context *tls_ctx = tls_get_ctx(sk); 1013 struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 1014 1015 if (ctx->aead_send) 1016 crypto_free_aead(ctx->aead_send); 1017 if (ctx->aead_recv) { 1018 if (ctx->recv_pkt) { 1019 kfree_skb(ctx->recv_pkt); 1020 ctx->recv_pkt = NULL; 1021 } 1022 crypto_free_aead(ctx->aead_recv); 1023 strp_stop(&ctx->strp); 1024 write_lock_bh(&sk->sk_callback_lock); 1025 sk->sk_data_ready = ctx->saved_data_ready; 1026 write_unlock_bh(&sk->sk_callback_lock); 1027 release_sock(sk); 1028 strp_done(&ctx->strp); 1029 lock_sock(sk); 1030 } 1031 1032 tls_free_both_sg(sk); 1033 1034 kfree(ctx); 1035 kfree(tls_ctx); 1036 } 1037 1038 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) 1039 { 1040 char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE]; 1041 struct tls_crypto_info *crypto_info; 1042 struct tls12_crypto_info_aes_gcm_128 *gcm_128_info; 1043 struct tls_sw_context *sw_ctx; 1044 struct cipher_context *cctx; 1045 struct crypto_aead **aead; 1046 struct strp_callbacks cb; 1047 u16 nonce_size, tag_size, iv_size, rec_seq_size; 1048 char *iv, *rec_seq; 1049 int rc = 0; 1050 1051 if (!ctx) { 1052 rc = -EINVAL; 1053 goto out; 1054 } 1055 1056 if (!ctx->priv_ctx) { 1057 sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL); 1058 if (!sw_ctx) { 1059 rc = -ENOMEM; 1060 goto out; 1061 } 1062 crypto_init_wait(&sw_ctx->async_wait); 1063 } else { 1064 sw_ctx = ctx->priv_ctx; 1065 } 1066 1067 ctx->priv_ctx = (struct tls_offload_context *)sw_ctx; 1068 1069 if (tx) { 1070 crypto_info = &ctx->crypto_send; 1071 cctx = &ctx->tx; 1072 aead = &sw_ctx->aead_send; 1073 } else { 1074 crypto_info = &ctx->crypto_recv; 1075 cctx = &ctx->rx; 1076 aead = &sw_ctx->aead_recv; 1077 } 1078 1079 switch (crypto_info->cipher_type) { 1080 case TLS_CIPHER_AES_GCM_128: { 1081 nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; 1082 tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE; 1083 iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; 1084 iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv; 1085 rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE; 1086 rec_seq = 1087 ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq; 1088 gcm_128_info = 1089 (struct tls12_crypto_info_aes_gcm_128 *)crypto_info; 1090 break; 1091 } 1092 default: 1093 rc = -EINVAL; 1094 goto free_priv; 1095 } 1096 1097 cctx->prepend_size = TLS_HEADER_SIZE + nonce_size; 1098 cctx->tag_size = tag_size; 1099 cctx->overhead_size = cctx->prepend_size + cctx->tag_size; 1100 cctx->iv_size = iv_size; 1101 cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, 1102 GFP_KERNEL); 1103 if (!cctx->iv) { 1104 rc = -ENOMEM; 1105 goto free_priv; 1106 } 1107 memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE); 1108 memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); 1109 cctx->rec_seq_size = rec_seq_size; 1110 cctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL); 1111 if (!cctx->rec_seq) { 1112 rc = -ENOMEM; 1113 goto free_iv; 1114 } 1115 memcpy(cctx->rec_seq, rec_seq, rec_seq_size); 1116 1117 if (tx) { 1118 sg_init_table(sw_ctx->sg_encrypted_data, 1119 ARRAY_SIZE(sw_ctx->sg_encrypted_data)); 1120 sg_init_table(sw_ctx->sg_plaintext_data, 1121 ARRAY_SIZE(sw_ctx->sg_plaintext_data)); 1122 1123 sg_init_table(sw_ctx->sg_aead_in, 2); 1124 sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space, 1125 sizeof(sw_ctx->aad_space)); 1126 sg_unmark_end(&sw_ctx->sg_aead_in[1]); 1127 sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data); 1128 sg_init_table(sw_ctx->sg_aead_out, 2); 1129 sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space, 1130 sizeof(sw_ctx->aad_space)); 1131 sg_unmark_end(&sw_ctx->sg_aead_out[1]); 1132 sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data); 1133 } 1134 1135 if (!*aead) { 1136 *aead = crypto_alloc_aead("gcm(aes)", 0, 0); 1137 if (IS_ERR(*aead)) { 1138 rc = PTR_ERR(*aead); 1139 *aead = NULL; 1140 goto free_rec_seq; 1141 } 1142 } 1143 1144 ctx->push_pending_record = tls_sw_push_pending_record; 1145 1146 memcpy(keyval, gcm_128_info->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE); 1147 1148 rc = crypto_aead_setkey(*aead, keyval, 1149 TLS_CIPHER_AES_GCM_128_KEY_SIZE); 1150 if (rc) 1151 goto free_aead; 1152 1153 rc = crypto_aead_setauthsize(*aead, cctx->tag_size); 1154 if (rc) 1155 goto free_aead; 1156 1157 if (!tx) { 1158 /* Set up strparser */ 1159 memset(&cb, 0, sizeof(cb)); 1160 cb.rcv_msg = tls_queue; 1161 cb.parse_msg = tls_read_size; 1162 1163 strp_init(&sw_ctx->strp, sk, &cb); 1164 1165 write_lock_bh(&sk->sk_callback_lock); 1166 sw_ctx->saved_data_ready = sk->sk_data_ready; 1167 sk->sk_data_ready = tls_data_ready; 1168 write_unlock_bh(&sk->sk_callback_lock); 1169 1170 sw_ctx->sk_poll = sk->sk_socket->ops->poll; 1171 1172 strp_check_rcv(&sw_ctx->strp); 1173 } 1174 1175 goto out; 1176 1177 free_aead: 1178 crypto_free_aead(*aead); 1179 *aead = NULL; 1180 free_rec_seq: 1181 kfree(cctx->rec_seq); 1182 cctx->rec_seq = NULL; 1183 free_iv: 1184 kfree(ctx->tx.iv); 1185 ctx->tx.iv = NULL; 1186 free_priv: 1187 kfree(ctx->priv_ctx); 1188 ctx->priv_ctx = NULL; 1189 out: 1190 return rc; 1191 } 1192