1 /* 2 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved. 3 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved. 4 * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved. 5 * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved. 6 * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved. 7 * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io 8 * 9 * This software is available to you under a choice of one of two 10 * licenses. You may choose to be licensed under the terms of the GNU 11 * General Public License (GPL) Version 2, available from the file 12 * COPYING in the main directory of this source tree, or the 13 * OpenIB.org BSD license below: 14 * 15 * Redistribution and use in source and binary forms, with or 16 * without modification, are permitted provided that the following 17 * conditions are met: 18 * 19 * - Redistributions of source code must retain the above 20 * copyright notice, this list of conditions and the following 21 * disclaimer. 22 * 23 * - Redistributions in binary form must reproduce the above 24 * copyright notice, this list of conditions and the following 25 * disclaimer in the documentation and/or other materials 26 * provided with the distribution. 27 * 28 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 29 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 30 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 31 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS 32 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN 33 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 34 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 35 * SOFTWARE. 36 */ 37 38 #include <linux/bug.h> 39 #include <linux/sched/signal.h> 40 #include <linux/module.h> 41 #include <linux/splice.h> 42 #include <crypto/aead.h> 43 44 #include <net/strparser.h> 45 #include <net/tls.h> 46 47 struct tls_decrypt_arg { 48 bool zc; 49 bool async; 50 u8 tail; 51 }; 52 53 noinline void tls_err_abort(struct sock *sk, int err) 54 { 55 WARN_ON_ONCE(err >= 0); 56 /* sk->sk_err should contain a positive error code. 
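 * Callers pass a negative errno (e.g. -EBADMSG), hence the negation below.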
*/ 57 sk->sk_err = -err; 58 sk_error_report(sk); 59 } 60 61 static int __skb_nsg(struct sk_buff *skb, int offset, int len, 62 unsigned int recursion_level) 63 { 64 int start = skb_headlen(skb); 65 int i, chunk = start - offset; 66 struct sk_buff *frag_iter; 67 int elt = 0; 68 69 if (unlikely(recursion_level >= 24)) 70 return -EMSGSIZE; 71 72 if (chunk > 0) { 73 if (chunk > len) 74 chunk = len; 75 elt++; 76 len -= chunk; 77 if (len == 0) 78 return elt; 79 offset += chunk; 80 } 81 82 for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { 83 int end; 84 85 WARN_ON(start > offset + len); 86 87 end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]); 88 chunk = end - offset; 89 if (chunk > 0) { 90 if (chunk > len) 91 chunk = len; 92 elt++; 93 len -= chunk; 94 if (len == 0) 95 return elt; 96 offset += chunk; 97 } 98 start = end; 99 } 100 101 if (unlikely(skb_has_frag_list(skb))) { 102 skb_walk_frags(skb, frag_iter) { 103 int end, ret; 104 105 WARN_ON(start > offset + len); 106 107 end = start + frag_iter->len; 108 chunk = end - offset; 109 if (chunk > 0) { 110 if (chunk > len) 111 chunk = len; 112 ret = __skb_nsg(frag_iter, offset - start, chunk, 113 recursion_level + 1); 114 if (unlikely(ret < 0)) 115 return ret; 116 elt += ret; 117 len -= chunk; 118 if (len == 0) 119 return elt; 120 offset += chunk; 121 } 122 start = end; 123 } 124 } 125 BUG_ON(len); 126 return elt; 127 } 128 129 /* Return the number of scatterlist elements required to completely map the 130 * skb, or -EMSGSIZE if the recursion depth is exceeded. 131 */ 132 static int skb_nsg(struct sk_buff *skb, int offset, int len) 133 { 134 return __skb_nsg(skb, offset, len, 0); 135 } 136 137 static int tls_padding_length(struct tls_prot_info *prot, struct sk_buff *skb, 138 struct tls_decrypt_arg *darg) 139 { 140 struct strp_msg *rxm = strp_msg(skb); 141 struct tls_msg *tlm = tls_msg(skb); 142 int sub = 0; 143 144 /* Determine zero-padding length */ 145 if (prot->version == TLS_1_3_VERSION) { 146 int offset = rxm->full_len - TLS_TAG_SIZE - 1; 147 char content_type = darg->zc ? darg->tail : 0; 148 int err; 149 150 while (content_type == 0) { 151 if (offset < prot->prepend_size) 152 return -EBADMSG; 153 err = skb_copy_bits(skb, rxm->offset + offset, 154 &content_type, 1); 155 if (err) 156 return err; 157 if (content_type) 158 break; 159 sub++; 160 offset--; 161 } 162 tlm->control = content_type; 163 } 164 return sub; 165 } 166 167 static void tls_decrypt_done(struct crypto_async_request *req, int err) 168 { 169 struct aead_request *aead_req = (struct aead_request *)req; 170 struct scatterlist *sgout = aead_req->dst; 171 struct scatterlist *sgin = aead_req->src; 172 struct tls_sw_context_rx *ctx; 173 struct tls_context *tls_ctx; 174 struct tls_prot_info *prot; 175 struct scatterlist *sg; 176 struct sk_buff *skb; 177 unsigned int pages; 178 179 skb = (struct sk_buff *)req->data; 180 tls_ctx = tls_get_ctx(skb->sk); 181 ctx = tls_sw_ctx_rx(tls_ctx); 182 prot = &tls_ctx->prot_info; 183 184 /* Propagate if there was an err */ 185 if (err) { 186 if (err == -EBADMSG) 187 TLS_INC_STATS(sock_net(skb->sk), 188 LINUX_MIB_TLSDECRYPTERROR); 189 ctx->async_wait.err = err; 190 tls_err_abort(skb->sk, err); 191 } else { 192 struct strp_msg *rxm = strp_msg(skb); 193 194 /* No TLS 1.3 support with async crypto */ 195 WARN_ON(prot->tail_size); 196 197 rxm->offset += prot->prepend_size; 198 rxm->full_len -= prot->overhead_size; 199 } 200 201 /* After using skb->sk to propagate sk through crypto async callback 202 * we need to NULL it again. 
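 * The skb is a clone owned by the strparser and holds no reference on the
 * socket, so the pointer must not outlive this callback.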
203 */ 204 skb->sk = NULL; 205 206 207 /* Free the destination pages if skb was not decrypted inplace */ 208 if (sgout != sgin) { 209 /* Skip the first S/G entry as it points to AAD */ 210 for_each_sg(sg_next(sgout), sg, UINT_MAX, pages) { 211 if (!sg) 212 break; 213 put_page(sg_page(sg)); 214 } 215 } 216 217 kfree(aead_req); 218 219 spin_lock_bh(&ctx->decrypt_compl_lock); 220 if (!atomic_dec_return(&ctx->decrypt_pending)) 221 complete(&ctx->async_wait.completion); 222 spin_unlock_bh(&ctx->decrypt_compl_lock); 223 } 224 225 static int tls_do_decryption(struct sock *sk, 226 struct sk_buff *skb, 227 struct scatterlist *sgin, 228 struct scatterlist *sgout, 229 char *iv_recv, 230 size_t data_len, 231 struct aead_request *aead_req, 232 struct tls_decrypt_arg *darg) 233 { 234 struct tls_context *tls_ctx = tls_get_ctx(sk); 235 struct tls_prot_info *prot = &tls_ctx->prot_info; 236 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 237 int ret; 238 239 aead_request_set_tfm(aead_req, ctx->aead_recv); 240 aead_request_set_ad(aead_req, prot->aad_size); 241 aead_request_set_crypt(aead_req, sgin, sgout, 242 data_len + prot->tag_size, 243 (u8 *)iv_recv); 244 245 if (darg->async) { 246 /* Using skb->sk to push sk through to crypto async callback 247 * handler. This allows propagating errors up to the socket 248 * if needed. It _must_ be cleared in the async handler 249 * before consume_skb is called. We _know_ skb->sk is NULL 250 * because it is a clone from strparser. 251 */ 252 skb->sk = sk; 253 aead_request_set_callback(aead_req, 254 CRYPTO_TFM_REQ_MAY_BACKLOG, 255 tls_decrypt_done, skb); 256 atomic_inc(&ctx->decrypt_pending); 257 } else { 258 aead_request_set_callback(aead_req, 259 CRYPTO_TFM_REQ_MAY_BACKLOG, 260 crypto_req_done, &ctx->async_wait); 261 } 262 263 ret = crypto_aead_decrypt(aead_req); 264 if (ret == -EINPROGRESS) { 265 if (darg->async) 266 return 0; 267 268 ret = crypto_wait_req(ret, &ctx->async_wait); 269 } 270 darg->async = false; 271 272 if (ret == -EBADMSG) 273 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR); 274 275 return ret; 276 } 277 278 static void tls_trim_both_msgs(struct sock *sk, int target_size) 279 { 280 struct tls_context *tls_ctx = tls_get_ctx(sk); 281 struct tls_prot_info *prot = &tls_ctx->prot_info; 282 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 283 struct tls_rec *rec = ctx->open_rec; 284 285 sk_msg_trim(sk, &rec->msg_plaintext, target_size); 286 if (target_size > 0) 287 target_size += prot->overhead_size; 288 sk_msg_trim(sk, &rec->msg_encrypted, target_size); 289 } 290 291 static int tls_alloc_encrypted_msg(struct sock *sk, int len) 292 { 293 struct tls_context *tls_ctx = tls_get_ctx(sk); 294 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 295 struct tls_rec *rec = ctx->open_rec; 296 struct sk_msg *msg_en = &rec->msg_encrypted; 297 298 return sk_msg_alloc(sk, msg_en, len, 0); 299 } 300 301 static int tls_clone_plaintext_msg(struct sock *sk, int required) 302 { 303 struct tls_context *tls_ctx = tls_get_ctx(sk); 304 struct tls_prot_info *prot = &tls_ctx->prot_info; 305 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 306 struct tls_rec *rec = ctx->open_rec; 307 struct sk_msg *msg_pl = &rec->msg_plaintext; 308 struct sk_msg *msg_en = &rec->msg_encrypted; 309 int skip, len; 310 311 /* We add page references worth len bytes from encrypted sg 312 * at the end of plaintext sg. It is guaranteed that msg_en 313 * has enough required room (ensured by caller). 
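 * Only page references are taken here; no payload data is copied.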
314 */ 315 len = required - msg_pl->sg.size; 316 317 /* Skip initial bytes in msg_en's data to be able to use 318 * same offset of both plain and encrypted data. 319 */ 320 skip = prot->prepend_size + msg_pl->sg.size; 321 322 return sk_msg_clone(sk, msg_pl, msg_en, skip, len); 323 } 324 325 static struct tls_rec *tls_get_rec(struct sock *sk) 326 { 327 struct tls_context *tls_ctx = tls_get_ctx(sk); 328 struct tls_prot_info *prot = &tls_ctx->prot_info; 329 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 330 struct sk_msg *msg_pl, *msg_en; 331 struct tls_rec *rec; 332 int mem_size; 333 334 mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send); 335 336 rec = kzalloc(mem_size, sk->sk_allocation); 337 if (!rec) 338 return NULL; 339 340 msg_pl = &rec->msg_plaintext; 341 msg_en = &rec->msg_encrypted; 342 343 sk_msg_init(msg_pl); 344 sk_msg_init(msg_en); 345 346 sg_init_table(rec->sg_aead_in, 2); 347 sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size); 348 sg_unmark_end(&rec->sg_aead_in[1]); 349 350 sg_init_table(rec->sg_aead_out, 2); 351 sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size); 352 sg_unmark_end(&rec->sg_aead_out[1]); 353 354 return rec; 355 } 356 357 static void tls_free_rec(struct sock *sk, struct tls_rec *rec) 358 { 359 sk_msg_free(sk, &rec->msg_encrypted); 360 sk_msg_free(sk, &rec->msg_plaintext); 361 kfree(rec); 362 } 363 364 static void tls_free_open_rec(struct sock *sk) 365 { 366 struct tls_context *tls_ctx = tls_get_ctx(sk); 367 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 368 struct tls_rec *rec = ctx->open_rec; 369 370 if (rec) { 371 tls_free_rec(sk, rec); 372 ctx->open_rec = NULL; 373 } 374 } 375 376 int tls_tx_records(struct sock *sk, int flags) 377 { 378 struct tls_context *tls_ctx = tls_get_ctx(sk); 379 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 380 struct tls_rec *rec, *tmp; 381 struct sk_msg *msg_en; 382 int tx_flags, rc = 0; 383 384 if (tls_is_partially_sent_record(tls_ctx)) { 385 rec = list_first_entry(&ctx->tx_list, 386 struct tls_rec, list); 387 388 if (flags == -1) 389 tx_flags = rec->tx_flags; 390 else 391 tx_flags = flags; 392 393 rc = tls_push_partial_record(sk, tls_ctx, tx_flags); 394 if (rc) 395 goto tx_err; 396 397 /* Full record has been transmitted. 
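 * Its encrypted data was already consumed by tls_push_partial_record(),
 * so only the plaintext copy and the record itself need to be freed.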
398 * Remove the head of tx_list 399 */ 400 list_del(&rec->list); 401 sk_msg_free(sk, &rec->msg_plaintext); 402 kfree(rec); 403 } 404 405 /* Tx all ready records */ 406 list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) { 407 if (READ_ONCE(rec->tx_ready)) { 408 if (flags == -1) 409 tx_flags = rec->tx_flags; 410 else 411 tx_flags = flags; 412 413 msg_en = &rec->msg_encrypted; 414 rc = tls_push_sg(sk, tls_ctx, 415 &msg_en->sg.data[msg_en->sg.curr], 416 0, tx_flags); 417 if (rc) 418 goto tx_err; 419 420 list_del(&rec->list); 421 sk_msg_free(sk, &rec->msg_plaintext); 422 kfree(rec); 423 } else { 424 break; 425 } 426 } 427 428 tx_err: 429 if (rc < 0 && rc != -EAGAIN) 430 tls_err_abort(sk, -EBADMSG); 431 432 return rc; 433 } 434 435 static void tls_encrypt_done(struct crypto_async_request *req, int err) 436 { 437 struct aead_request *aead_req = (struct aead_request *)req; 438 struct sock *sk = req->data; 439 struct tls_context *tls_ctx = tls_get_ctx(sk); 440 struct tls_prot_info *prot = &tls_ctx->prot_info; 441 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 442 struct scatterlist *sge; 443 struct sk_msg *msg_en; 444 struct tls_rec *rec; 445 bool ready = false; 446 int pending; 447 448 rec = container_of(aead_req, struct tls_rec, aead_req); 449 msg_en = &rec->msg_encrypted; 450 451 sge = sk_msg_elem(msg_en, msg_en->sg.curr); 452 sge->offset -= prot->prepend_size; 453 sge->length += prot->prepend_size; 454 455 /* Check if error is previously set on socket */ 456 if (err || sk->sk_err) { 457 rec = NULL; 458 459 /* If err is already set on socket, return the same code */ 460 if (sk->sk_err) { 461 ctx->async_wait.err = -sk->sk_err; 462 } else { 463 ctx->async_wait.err = err; 464 tls_err_abort(sk, err); 465 } 466 } 467 468 if (rec) { 469 struct tls_rec *first_rec; 470 471 /* Mark the record as ready for transmission */ 472 smp_store_mb(rec->tx_ready, true); 473 474 /* If received record is at head of tx_list, schedule tx */ 475 first_rec = list_first_entry(&ctx->tx_list, 476 struct tls_rec, list); 477 if (rec == first_rec) 478 ready = true; 479 } 480 481 spin_lock_bh(&ctx->encrypt_compl_lock); 482 pending = atomic_dec_return(&ctx->encrypt_pending); 483 484 if (!pending && ctx->async_notify) 485 complete(&ctx->async_wait.completion); 486 spin_unlock_bh(&ctx->encrypt_compl_lock); 487 488 if (!ready) 489 return; 490 491 /* Schedule the transmission */ 492 if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) 493 schedule_delayed_work(&ctx->tx_work.work, 1); 494 } 495 496 static int tls_do_encryption(struct sock *sk, 497 struct tls_context *tls_ctx, 498 struct tls_sw_context_tx *ctx, 499 struct aead_request *aead_req, 500 size_t data_len, u32 start) 501 { 502 struct tls_prot_info *prot = &tls_ctx->prot_info; 503 struct tls_rec *rec = ctx->open_rec; 504 struct sk_msg *msg_en = &rec->msg_encrypted; 505 struct scatterlist *sge = sk_msg_elem(msg_en, start); 506 int rc, iv_offset = 0; 507 508 /* For CCM based ciphers, first byte of IV is a constant */ 509 switch (prot->cipher_type) { 510 case TLS_CIPHER_AES_CCM_128: 511 rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE; 512 iv_offset = 1; 513 break; 514 case TLS_CIPHER_SM4_CCM: 515 rec->iv_data[0] = TLS_SM4_CCM_IV_B0_BYTE; 516 iv_offset = 1; 517 break; 518 } 519 520 memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv, 521 prot->iv_size + prot->salt_size); 522 523 xor_iv_with_seq(prot, rec->iv_data + iv_offset, tls_ctx->tx.rec_seq); 524 525 sge->offset += prot->prepend_size; 526 sge->length -= prot->prepend_size; 527 528 msg_en->sg.curr = start; 529 530 
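	/* Hand the record to the AEAD transform: sg_aead_in supplies the AAD and
	 * plaintext, sg_aead_out receives the ciphertext and tag, and completion
	 * may be signalled asynchronously via tls_encrypt_done().
	 */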
aead_request_set_tfm(aead_req, ctx->aead_send); 531 aead_request_set_ad(aead_req, prot->aad_size); 532 aead_request_set_crypt(aead_req, rec->sg_aead_in, 533 rec->sg_aead_out, 534 data_len, rec->iv_data); 535 536 aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, 537 tls_encrypt_done, sk); 538 539 /* Add the record in tx_list */ 540 list_add_tail((struct list_head *)&rec->list, &ctx->tx_list); 541 atomic_inc(&ctx->encrypt_pending); 542 543 rc = crypto_aead_encrypt(aead_req); 544 if (!rc || rc != -EINPROGRESS) { 545 atomic_dec(&ctx->encrypt_pending); 546 sge->offset -= prot->prepend_size; 547 sge->length += prot->prepend_size; 548 } 549 550 if (!rc) { 551 WRITE_ONCE(rec->tx_ready, true); 552 } else if (rc != -EINPROGRESS) { 553 list_del(&rec->list); 554 return rc; 555 } 556 557 /* Unhook the record from context if encryption is not failure */ 558 ctx->open_rec = NULL; 559 tls_advance_record_sn(sk, prot, &tls_ctx->tx); 560 return rc; 561 } 562 563 static int tls_split_open_record(struct sock *sk, struct tls_rec *from, 564 struct tls_rec **to, struct sk_msg *msg_opl, 565 struct sk_msg *msg_oen, u32 split_point, 566 u32 tx_overhead_size, u32 *orig_end) 567 { 568 u32 i, j, bytes = 0, apply = msg_opl->apply_bytes; 569 struct scatterlist *sge, *osge, *nsge; 570 u32 orig_size = msg_opl->sg.size; 571 struct scatterlist tmp = { }; 572 struct sk_msg *msg_npl; 573 struct tls_rec *new; 574 int ret; 575 576 new = tls_get_rec(sk); 577 if (!new) 578 return -ENOMEM; 579 ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size + 580 tx_overhead_size, 0); 581 if (ret < 0) { 582 tls_free_rec(sk, new); 583 return ret; 584 } 585 586 *orig_end = msg_opl->sg.end; 587 i = msg_opl->sg.start; 588 sge = sk_msg_elem(msg_opl, i); 589 while (apply && sge->length) { 590 if (sge->length > apply) { 591 u32 len = sge->length - apply; 592 593 get_page(sg_page(sge)); 594 sg_set_page(&tmp, sg_page(sge), len, 595 sge->offset + apply); 596 sge->length = apply; 597 bytes += apply; 598 apply = 0; 599 } else { 600 apply -= sge->length; 601 bytes += sge->length; 602 } 603 604 sk_msg_iter_var_next(i); 605 if (i == msg_opl->sg.end) 606 break; 607 sge = sk_msg_elem(msg_opl, i); 608 } 609 610 msg_opl->sg.end = i; 611 msg_opl->sg.curr = i; 612 msg_opl->sg.copybreak = 0; 613 msg_opl->apply_bytes = 0; 614 msg_opl->sg.size = bytes; 615 616 msg_npl = &new->msg_plaintext; 617 msg_npl->apply_bytes = apply; 618 msg_npl->sg.size = orig_size - bytes; 619 620 j = msg_npl->sg.start; 621 nsge = sk_msg_elem(msg_npl, j); 622 if (tmp.length) { 623 memcpy(nsge, &tmp, sizeof(*nsge)); 624 sk_msg_iter_var_next(j); 625 nsge = sk_msg_elem(msg_npl, j); 626 } 627 628 osge = sk_msg_elem(msg_opl, i); 629 while (osge->length) { 630 memcpy(nsge, osge, sizeof(*nsge)); 631 sg_unmark_end(nsge); 632 sk_msg_iter_var_next(i); 633 sk_msg_iter_var_next(j); 634 if (i == *orig_end) 635 break; 636 osge = sk_msg_elem(msg_opl, i); 637 nsge = sk_msg_elem(msg_npl, j); 638 } 639 640 msg_npl->sg.end = j; 641 msg_npl->sg.curr = j; 642 msg_npl->sg.copybreak = 0; 643 644 *to = new; 645 return 0; 646 } 647 648 static void tls_merge_open_record(struct sock *sk, struct tls_rec *to, 649 struct tls_rec *from, u32 orig_end) 650 { 651 struct sk_msg *msg_npl = &from->msg_plaintext; 652 struct sk_msg *msg_opl = &to->msg_plaintext; 653 struct scatterlist *osge, *nsge; 654 u32 i, j; 655 656 i = msg_opl->sg.end; 657 sk_msg_iter_var_prev(i); 658 j = msg_npl->sg.start; 659 660 osge = sk_msg_elem(msg_opl, i); 661 nsge = sk_msg_elem(msg_npl, j); 662 663 if (sg_page(osge) == 
sg_page(nsge) && 664 osge->offset + osge->length == nsge->offset) { 665 osge->length += nsge->length; 666 put_page(sg_page(nsge)); 667 } 668 669 msg_opl->sg.end = orig_end; 670 msg_opl->sg.curr = orig_end; 671 msg_opl->sg.copybreak = 0; 672 msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size; 673 msg_opl->sg.size += msg_npl->sg.size; 674 675 sk_msg_free(sk, &to->msg_encrypted); 676 sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted); 677 678 kfree(from); 679 } 680 681 static int tls_push_record(struct sock *sk, int flags, 682 unsigned char record_type) 683 { 684 struct tls_context *tls_ctx = tls_get_ctx(sk); 685 struct tls_prot_info *prot = &tls_ctx->prot_info; 686 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 687 struct tls_rec *rec = ctx->open_rec, *tmp = NULL; 688 u32 i, split_point, orig_end; 689 struct sk_msg *msg_pl, *msg_en; 690 struct aead_request *req; 691 bool split; 692 int rc; 693 694 if (!rec) 695 return 0; 696 697 msg_pl = &rec->msg_plaintext; 698 msg_en = &rec->msg_encrypted; 699 700 split_point = msg_pl->apply_bytes; 701 split = split_point && split_point < msg_pl->sg.size; 702 if (unlikely((!split && 703 msg_pl->sg.size + 704 prot->overhead_size > msg_en->sg.size) || 705 (split && 706 split_point + 707 prot->overhead_size > msg_en->sg.size))) { 708 split = true; 709 split_point = msg_en->sg.size; 710 } 711 if (split) { 712 rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en, 713 split_point, prot->overhead_size, 714 &orig_end); 715 if (rc < 0) 716 return rc; 717 /* This can happen if above tls_split_open_record allocates 718 * a single large encryption buffer instead of two smaller 719 * ones. In this case adjust pointers and continue without 720 * split. 721 */ 722 if (!msg_pl->sg.size) { 723 tls_merge_open_record(sk, rec, tmp, orig_end); 724 msg_pl = &rec->msg_plaintext; 725 msg_en = &rec->msg_encrypted; 726 split = false; 727 } 728 sk_msg_trim(sk, msg_en, msg_pl->sg.size + 729 prot->overhead_size); 730 } 731 732 rec->tx_flags = flags; 733 req = &rec->aead_req; 734 735 i = msg_pl->sg.end; 736 sk_msg_iter_var_prev(i); 737 738 rec->content_type = record_type; 739 if (prot->version == TLS_1_3_VERSION) { 740 /* Add content type to end of message. 
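 * (TLS 1.3 carries the real record type inside the encrypted payload).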
No padding added */ 741 sg_set_buf(&rec->sg_content_type, &rec->content_type, 1); 742 sg_mark_end(&rec->sg_content_type); 743 sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1, 744 &rec->sg_content_type); 745 } else { 746 sg_mark_end(sk_msg_elem(msg_pl, i)); 747 } 748 749 if (msg_pl->sg.end < msg_pl->sg.start) { 750 sg_chain(&msg_pl->sg.data[msg_pl->sg.start], 751 MAX_SKB_FRAGS - msg_pl->sg.start + 1, 752 msg_pl->sg.data); 753 } 754 755 i = msg_pl->sg.start; 756 sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]); 757 758 i = msg_en->sg.end; 759 sk_msg_iter_var_prev(i); 760 sg_mark_end(sk_msg_elem(msg_en, i)); 761 762 i = msg_en->sg.start; 763 sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]); 764 765 tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size, 766 tls_ctx->tx.rec_seq, record_type, prot); 767 768 tls_fill_prepend(tls_ctx, 769 page_address(sg_page(&msg_en->sg.data[i])) + 770 msg_en->sg.data[i].offset, 771 msg_pl->sg.size + prot->tail_size, 772 record_type); 773 774 tls_ctx->pending_open_record_frags = false; 775 776 rc = tls_do_encryption(sk, tls_ctx, ctx, req, 777 msg_pl->sg.size + prot->tail_size, i); 778 if (rc < 0) { 779 if (rc != -EINPROGRESS) { 780 tls_err_abort(sk, -EBADMSG); 781 if (split) { 782 tls_ctx->pending_open_record_frags = true; 783 tls_merge_open_record(sk, rec, tmp, orig_end); 784 } 785 } 786 ctx->async_capable = 1; 787 return rc; 788 } else if (split) { 789 msg_pl = &tmp->msg_plaintext; 790 msg_en = &tmp->msg_encrypted; 791 sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size); 792 tls_ctx->pending_open_record_frags = true; 793 ctx->open_rec = tmp; 794 } 795 796 return tls_tx_records(sk, flags); 797 } 798 799 static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk, 800 bool full_record, u8 record_type, 801 ssize_t *copied, int flags) 802 { 803 struct tls_context *tls_ctx = tls_get_ctx(sk); 804 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 805 struct sk_msg msg_redir = { }; 806 struct sk_psock *psock; 807 struct sock *sk_redir; 808 struct tls_rec *rec; 809 bool enospc, policy; 810 int err = 0, send; 811 u32 delta = 0; 812 813 policy = !(flags & MSG_SENDPAGE_NOPOLICY); 814 psock = sk_psock_get(sk); 815 if (!psock || !policy) { 816 err = tls_push_record(sk, flags, record_type); 817 if (err && sk->sk_err == EBADMSG) { 818 *copied -= sk_msg_free(sk, msg); 819 tls_free_open_rec(sk); 820 err = -sk->sk_err; 821 } 822 if (psock) 823 sk_psock_put(sk, psock); 824 return err; 825 } 826 more_data: 827 enospc = sk_msg_full(msg); 828 if (psock->eval == __SK_NONE) { 829 delta = msg->sg.size; 830 psock->eval = sk_psock_msg_verdict(sk, psock, msg); 831 delta -= msg->sg.size; 832 } 833 if (msg->cork_bytes && msg->cork_bytes > msg->sg.size && 834 !enospc && !full_record) { 835 err = -ENOSPC; 836 goto out_err; 837 } 838 msg->cork_bytes = 0; 839 send = msg->sg.size; 840 if (msg->apply_bytes && msg->apply_bytes < send) 841 send = msg->apply_bytes; 842 843 switch (psock->eval) { 844 case __SK_PASS: 845 err = tls_push_record(sk, flags, record_type); 846 if (err && sk->sk_err == EBADMSG) { 847 *copied -= sk_msg_free(sk, msg); 848 tls_free_open_rec(sk); 849 err = -sk->sk_err; 850 goto out_err; 851 } 852 break; 853 case __SK_REDIRECT: 854 sk_redir = psock->sk_redir; 855 memcpy(&msg_redir, msg, sizeof(*msg)); 856 if (msg->apply_bytes < send) 857 msg->apply_bytes = 0; 858 else 859 msg->apply_bytes -= send; 860 sk_msg_return_zero(sk, msg, send); 861 msg->sg.size -= send; 862 release_sock(sk); 863 err = tcp_bpf_sendmsg_redir(sk_redir, &msg_redir, send, 
flags); 864 lock_sock(sk); 865 if (err < 0) { 866 *copied -= sk_msg_free_nocharge(sk, &msg_redir); 867 msg->sg.size = 0; 868 } 869 if (msg->sg.size == 0) 870 tls_free_open_rec(sk); 871 break; 872 case __SK_DROP: 873 default: 874 sk_msg_free_partial(sk, msg, send); 875 if (msg->apply_bytes < send) 876 msg->apply_bytes = 0; 877 else 878 msg->apply_bytes -= send; 879 if (msg->sg.size == 0) 880 tls_free_open_rec(sk); 881 *copied -= (send + delta); 882 err = -EACCES; 883 } 884 885 if (likely(!err)) { 886 bool reset_eval = !ctx->open_rec; 887 888 rec = ctx->open_rec; 889 if (rec) { 890 msg = &rec->msg_plaintext; 891 if (!msg->apply_bytes) 892 reset_eval = true; 893 } 894 if (reset_eval) { 895 psock->eval = __SK_NONE; 896 if (psock->sk_redir) { 897 sock_put(psock->sk_redir); 898 psock->sk_redir = NULL; 899 } 900 } 901 if (rec) 902 goto more_data; 903 } 904 out_err: 905 sk_psock_put(sk, psock); 906 return err; 907 } 908 909 static int tls_sw_push_pending_record(struct sock *sk, int flags) 910 { 911 struct tls_context *tls_ctx = tls_get_ctx(sk); 912 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 913 struct tls_rec *rec = ctx->open_rec; 914 struct sk_msg *msg_pl; 915 size_t copied; 916 917 if (!rec) 918 return 0; 919 920 msg_pl = &rec->msg_plaintext; 921 copied = msg_pl->sg.size; 922 if (!copied) 923 return 0; 924 925 return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA, 926 &copied, flags); 927 } 928 929 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) 930 { 931 long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); 932 struct tls_context *tls_ctx = tls_get_ctx(sk); 933 struct tls_prot_info *prot = &tls_ctx->prot_info; 934 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 935 bool async_capable = ctx->async_capable; 936 unsigned char record_type = TLS_RECORD_TYPE_DATA; 937 bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); 938 bool eor = !(msg->msg_flags & MSG_MORE); 939 size_t try_to_copy; 940 ssize_t copied = 0; 941 struct sk_msg *msg_pl, *msg_en; 942 struct tls_rec *rec; 943 int required_size; 944 int num_async = 0; 945 bool full_record; 946 int record_room; 947 int num_zc = 0; 948 int orig_size; 949 int ret = 0; 950 int pending; 951 952 if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | 953 MSG_CMSG_COMPAT)) 954 return -EOPNOTSUPP; 955 956 mutex_lock(&tls_ctx->tx_lock); 957 lock_sock(sk); 958 959 if (unlikely(msg->msg_controllen)) { 960 ret = tls_proccess_cmsg(sk, msg, &record_type); 961 if (ret) { 962 if (ret == -EINPROGRESS) 963 num_async++; 964 else if (ret != -EAGAIN) 965 goto send_end; 966 } 967 } 968 969 while (msg_data_left(msg)) { 970 if (sk->sk_err) { 971 ret = -sk->sk_err; 972 goto send_end; 973 } 974 975 if (ctx->open_rec) 976 rec = ctx->open_rec; 977 else 978 rec = ctx->open_rec = tls_get_rec(sk); 979 if (!rec) { 980 ret = -ENOMEM; 981 goto send_end; 982 } 983 984 msg_pl = &rec->msg_plaintext; 985 msg_en = &rec->msg_encrypted; 986 987 orig_size = msg_pl->sg.size; 988 full_record = false; 989 try_to_copy = msg_data_left(msg); 990 record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size; 991 if (try_to_copy >= record_room) { 992 try_to_copy = record_room; 993 full_record = true; 994 } 995 996 required_size = msg_pl->sg.size + try_to_copy + 997 prot->overhead_size; 998 999 if (!sk_stream_memory_free(sk)) 1000 goto wait_for_sndbuf; 1001 1002 alloc_encrypted: 1003 ret = tls_alloc_encrypted_msg(sk, required_size); 1004 if (ret) { 1005 if (ret != -ENOSPC) 1006 goto wait_for_memory; 1007 1008 /* Adjust try_to_copy according to 
the amount that was 1009 * actually allocated. The difference is due 1010 * to max sg elements limit 1011 */ 1012 try_to_copy -= required_size - msg_en->sg.size; 1013 full_record = true; 1014 } 1015 1016 if (!is_kvec && (full_record || eor) && !async_capable) { 1017 u32 first = msg_pl->sg.end; 1018 1019 ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter, 1020 msg_pl, try_to_copy); 1021 if (ret) 1022 goto fallback_to_reg_send; 1023 1024 num_zc++; 1025 copied += try_to_copy; 1026 1027 sk_msg_sg_copy_set(msg_pl, first); 1028 ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, 1029 record_type, &copied, 1030 msg->msg_flags); 1031 if (ret) { 1032 if (ret == -EINPROGRESS) 1033 num_async++; 1034 else if (ret == -ENOMEM) 1035 goto wait_for_memory; 1036 else if (ctx->open_rec && ret == -ENOSPC) 1037 goto rollback_iter; 1038 else if (ret != -EAGAIN) 1039 goto send_end; 1040 } 1041 continue; 1042 rollback_iter: 1043 copied -= try_to_copy; 1044 sk_msg_sg_copy_clear(msg_pl, first); 1045 iov_iter_revert(&msg->msg_iter, 1046 msg_pl->sg.size - orig_size); 1047 fallback_to_reg_send: 1048 sk_msg_trim(sk, msg_pl, orig_size); 1049 } 1050 1051 required_size = msg_pl->sg.size + try_to_copy; 1052 1053 ret = tls_clone_plaintext_msg(sk, required_size); 1054 if (ret) { 1055 if (ret != -ENOSPC) 1056 goto send_end; 1057 1058 /* Adjust try_to_copy according to the amount that was 1059 * actually allocated. The difference is due 1060 * to max sg elements limit 1061 */ 1062 try_to_copy -= required_size - msg_pl->sg.size; 1063 full_record = true; 1064 sk_msg_trim(sk, msg_en, 1065 msg_pl->sg.size + prot->overhead_size); 1066 } 1067 1068 if (try_to_copy) { 1069 ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, 1070 msg_pl, try_to_copy); 1071 if (ret < 0) 1072 goto trim_sgl; 1073 } 1074 1075 /* Open records defined only if successfully copied, otherwise 1076 * we would trim the sg but not reset the open record frags. 
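 * i.e. pending_open_record_frags is only set once data has actually been
 * queued in msg_pl.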
1077 */ 1078 tls_ctx->pending_open_record_frags = true; 1079 copied += try_to_copy; 1080 if (full_record || eor) { 1081 ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, 1082 record_type, &copied, 1083 msg->msg_flags); 1084 if (ret) { 1085 if (ret == -EINPROGRESS) 1086 num_async++; 1087 else if (ret == -ENOMEM) 1088 goto wait_for_memory; 1089 else if (ret != -EAGAIN) { 1090 if (ret == -ENOSPC) 1091 ret = 0; 1092 goto send_end; 1093 } 1094 } 1095 } 1096 1097 continue; 1098 1099 wait_for_sndbuf: 1100 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); 1101 wait_for_memory: 1102 ret = sk_stream_wait_memory(sk, &timeo); 1103 if (ret) { 1104 trim_sgl: 1105 if (ctx->open_rec) 1106 tls_trim_both_msgs(sk, orig_size); 1107 goto send_end; 1108 } 1109 1110 if (ctx->open_rec && msg_en->sg.size < required_size) 1111 goto alloc_encrypted; 1112 } 1113 1114 if (!num_async) { 1115 goto send_end; 1116 } else if (num_zc) { 1117 /* Wait for pending encryptions to get completed */ 1118 spin_lock_bh(&ctx->encrypt_compl_lock); 1119 ctx->async_notify = true; 1120 1121 pending = atomic_read(&ctx->encrypt_pending); 1122 spin_unlock_bh(&ctx->encrypt_compl_lock); 1123 if (pending) 1124 crypto_wait_req(-EINPROGRESS, &ctx->async_wait); 1125 else 1126 reinit_completion(&ctx->async_wait.completion); 1127 1128 /* There can be no concurrent accesses, since we have no 1129 * pending encrypt operations 1130 */ 1131 WRITE_ONCE(ctx->async_notify, false); 1132 1133 if (ctx->async_wait.err) { 1134 ret = ctx->async_wait.err; 1135 copied = 0; 1136 } 1137 } 1138 1139 /* Transmit if any encryptions have completed */ 1140 if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) { 1141 cancel_delayed_work(&ctx->tx_work.work); 1142 tls_tx_records(sk, msg->msg_flags); 1143 } 1144 1145 send_end: 1146 ret = sk_stream_error(sk, msg->msg_flags, ret); 1147 1148 release_sock(sk); 1149 mutex_unlock(&tls_ctx->tx_lock); 1150 return copied > 0 ? copied : ret; 1151 } 1152 1153 static int tls_sw_do_sendpage(struct sock *sk, struct page *page, 1154 int offset, size_t size, int flags) 1155 { 1156 long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); 1157 struct tls_context *tls_ctx = tls_get_ctx(sk); 1158 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 1159 struct tls_prot_info *prot = &tls_ctx->prot_info; 1160 unsigned char record_type = TLS_RECORD_TYPE_DATA; 1161 struct sk_msg *msg_pl; 1162 struct tls_rec *rec; 1163 int num_async = 0; 1164 ssize_t copied = 0; 1165 bool full_record; 1166 int record_room; 1167 int ret = 0; 1168 bool eor; 1169 1170 eor = !(flags & MSG_SENDPAGE_NOTLAST); 1171 sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); 1172 1173 /* Call the sk_stream functions to manage the sndbuf mem. 
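 * sendpage only attaches page references, so the memory is accounted
 * explicitly via sk_mem_charge() below.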
*/ 1174 while (size > 0) { 1175 size_t copy, required_size; 1176 1177 if (sk->sk_err) { 1178 ret = -sk->sk_err; 1179 goto sendpage_end; 1180 } 1181 1182 if (ctx->open_rec) 1183 rec = ctx->open_rec; 1184 else 1185 rec = ctx->open_rec = tls_get_rec(sk); 1186 if (!rec) { 1187 ret = -ENOMEM; 1188 goto sendpage_end; 1189 } 1190 1191 msg_pl = &rec->msg_plaintext; 1192 1193 full_record = false; 1194 record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size; 1195 copy = size; 1196 if (copy >= record_room) { 1197 copy = record_room; 1198 full_record = true; 1199 } 1200 1201 required_size = msg_pl->sg.size + copy + prot->overhead_size; 1202 1203 if (!sk_stream_memory_free(sk)) 1204 goto wait_for_sndbuf; 1205 alloc_payload: 1206 ret = tls_alloc_encrypted_msg(sk, required_size); 1207 if (ret) { 1208 if (ret != -ENOSPC) 1209 goto wait_for_memory; 1210 1211 /* Adjust copy according to the amount that was 1212 * actually allocated. The difference is due 1213 * to max sg elements limit 1214 */ 1215 copy -= required_size - msg_pl->sg.size; 1216 full_record = true; 1217 } 1218 1219 sk_msg_page_add(msg_pl, page, copy, offset); 1220 sk_mem_charge(sk, copy); 1221 1222 offset += copy; 1223 size -= copy; 1224 copied += copy; 1225 1226 tls_ctx->pending_open_record_frags = true; 1227 if (full_record || eor || sk_msg_full(msg_pl)) { 1228 ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, 1229 record_type, &copied, flags); 1230 if (ret) { 1231 if (ret == -EINPROGRESS) 1232 num_async++; 1233 else if (ret == -ENOMEM) 1234 goto wait_for_memory; 1235 else if (ret != -EAGAIN) { 1236 if (ret == -ENOSPC) 1237 ret = 0; 1238 goto sendpage_end; 1239 } 1240 } 1241 } 1242 continue; 1243 wait_for_sndbuf: 1244 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); 1245 wait_for_memory: 1246 ret = sk_stream_wait_memory(sk, &timeo); 1247 if (ret) { 1248 if (ctx->open_rec) 1249 tls_trim_both_msgs(sk, msg_pl->sg.size); 1250 goto sendpage_end; 1251 } 1252 1253 if (ctx->open_rec) 1254 goto alloc_payload; 1255 } 1256 1257 if (num_async) { 1258 /* Transmit if any encryptions have completed */ 1259 if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) { 1260 cancel_delayed_work(&ctx->tx_work.work); 1261 tls_tx_records(sk, flags); 1262 } 1263 } 1264 sendpage_end: 1265 ret = sk_stream_error(sk, flags, ret); 1266 return copied > 0 ? 
copied : ret; 1267 } 1268 1269 int tls_sw_sendpage_locked(struct sock *sk, struct page *page, 1270 int offset, size_t size, int flags) 1271 { 1272 if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | 1273 MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY | 1274 MSG_NO_SHARED_FRAGS)) 1275 return -EOPNOTSUPP; 1276 1277 return tls_sw_do_sendpage(sk, page, offset, size, flags); 1278 } 1279 1280 int tls_sw_sendpage(struct sock *sk, struct page *page, 1281 int offset, size_t size, int flags) 1282 { 1283 struct tls_context *tls_ctx = tls_get_ctx(sk); 1284 int ret; 1285 1286 if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | 1287 MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY)) 1288 return -EOPNOTSUPP; 1289 1290 mutex_lock(&tls_ctx->tx_lock); 1291 lock_sock(sk); 1292 ret = tls_sw_do_sendpage(sk, page, offset, size, flags); 1293 release_sock(sk); 1294 mutex_unlock(&tls_ctx->tx_lock); 1295 return ret; 1296 } 1297 1298 static struct sk_buff *tls_wait_data(struct sock *sk, struct sk_psock *psock, 1299 bool nonblock, long timeo, int *err) 1300 { 1301 struct tls_context *tls_ctx = tls_get_ctx(sk); 1302 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1303 struct sk_buff *skb; 1304 DEFINE_WAIT_FUNC(wait, woken_wake_function); 1305 1306 while (!(skb = ctx->recv_pkt) && sk_psock_queue_empty(psock)) { 1307 if (sk->sk_err) { 1308 *err = sock_error(sk); 1309 return NULL; 1310 } 1311 1312 if (!skb_queue_empty(&sk->sk_receive_queue)) { 1313 __strp_unpause(&ctx->strp); 1314 if (ctx->recv_pkt) 1315 return ctx->recv_pkt; 1316 } 1317 1318 if (sk->sk_shutdown & RCV_SHUTDOWN) 1319 return NULL; 1320 1321 if (sock_flag(sk, SOCK_DONE)) 1322 return NULL; 1323 1324 if (nonblock || !timeo) { 1325 *err = -EAGAIN; 1326 return NULL; 1327 } 1328 1329 add_wait_queue(sk_sleep(sk), &wait); 1330 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); 1331 sk_wait_event(sk, &timeo, 1332 ctx->recv_pkt != skb || 1333 !sk_psock_queue_empty(psock), 1334 &wait); 1335 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); 1336 remove_wait_queue(sk_sleep(sk), &wait); 1337 1338 /* Handle signals */ 1339 if (signal_pending(current)) { 1340 *err = sock_intr_errno(timeo); 1341 return NULL; 1342 } 1343 } 1344 1345 return skb; 1346 } 1347 1348 static int tls_setup_from_iter(struct iov_iter *from, 1349 int length, int *pages_used, 1350 struct scatterlist *to, 1351 int to_max_pages) 1352 { 1353 int rc = 0, i = 0, num_elem = *pages_used, maxpages; 1354 struct page *pages[MAX_SKB_FRAGS]; 1355 unsigned int size = 0; 1356 ssize_t copied, use; 1357 size_t offset; 1358 1359 while (length > 0) { 1360 i = 0; 1361 maxpages = to_max_pages - num_elem; 1362 if (maxpages == 0) { 1363 rc = -EFAULT; 1364 goto out; 1365 } 1366 copied = iov_iter_get_pages(from, pages, 1367 length, 1368 maxpages, &offset); 1369 if (copied <= 0) { 1370 rc = -EFAULT; 1371 goto out; 1372 } 1373 1374 iov_iter_advance(from, copied); 1375 1376 length -= copied; 1377 size += copied; 1378 while (copied) { 1379 use = min_t(int, copied, PAGE_SIZE - offset); 1380 1381 sg_set_page(&to[num_elem], 1382 pages[i], use, offset); 1383 sg_unmark_end(&to[num_elem]); 1384 /* We do not uncharge memory from this API */ 1385 1386 offset = 0; 1387 copied -= use; 1388 1389 i++; 1390 num_elem++; 1391 } 1392 } 1393 /* Mark the end in the last sg entry if newly added */ 1394 if (num_elem > *pages_used) 1395 sg_mark_end(&to[num_elem - 1]); 1396 out: 1397 if (rc) 1398 iov_iter_revert(from, size); 1399 *pages_used = num_elem; 1400 1401 return rc; 1402 } 1403 1404 /* This function decrypts the input skb into either out_iov or in out_sg 
1405 * or in skb buffers itself. The input parameter 'zc' indicates if 1406 * zero-copy mode needs to be tried or not. With zero-copy mode, either 1407 * out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are 1408 * NULL, then the decryption happens inside skb buffers itself, i.e. 1409 * zero-copy gets disabled and 'zc' is updated. 1410 */ 1411 1412 static int decrypt_internal(struct sock *sk, struct sk_buff *skb, 1413 struct iov_iter *out_iov, 1414 struct scatterlist *out_sg, 1415 struct tls_decrypt_arg *darg) 1416 { 1417 struct tls_context *tls_ctx = tls_get_ctx(sk); 1418 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1419 struct tls_prot_info *prot = &tls_ctx->prot_info; 1420 struct strp_msg *rxm = strp_msg(skb); 1421 struct tls_msg *tlm = tls_msg(skb); 1422 int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0; 1423 u8 *aad, *iv, *tail, *mem = NULL; 1424 struct aead_request *aead_req; 1425 struct sk_buff *unused; 1426 struct scatterlist *sgin = NULL; 1427 struct scatterlist *sgout = NULL; 1428 const int data_len = rxm->full_len - prot->overhead_size; 1429 int tail_pages = !!prot->tail_size; 1430 int iv_offset = 0; 1431 1432 if (darg->zc && (out_iov || out_sg)) { 1433 if (out_iov) 1434 n_sgout = 1 + tail_pages + 1435 iov_iter_npages_cap(out_iov, INT_MAX, data_len); 1436 else 1437 n_sgout = sg_nents(out_sg); 1438 n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size, 1439 rxm->full_len - prot->prepend_size); 1440 } else { 1441 n_sgout = 0; 1442 darg->zc = false; 1443 n_sgin = skb_cow_data(skb, 0, &unused); 1444 } 1445 1446 if (n_sgin < 1) 1447 return -EBADMSG; 1448 1449 /* Increment to accommodate AAD */ 1450 n_sgin = n_sgin + 1; 1451 1452 nsg = n_sgin + n_sgout; 1453 1454 aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv); 1455 mem_size = aead_size + (nsg * sizeof(struct scatterlist)); 1456 mem_size = mem_size + prot->aad_size; 1457 mem_size = mem_size + MAX_IV_SIZE; 1458 mem_size = mem_size + prot->tail_size; 1459 1460 /* Allocate a single block of memory which contains 1461 * aead_req || sgin[] || sgout[] || aad || iv || tail. 1462 * This order achieves correct alignment for aead_req, sgin, sgout. 
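 *
 *   mem: | aead_req + reqsize | sgin[n_sgin] | sgout[n_sgout] | aad | iv | tail |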
1463 */ 1464 mem = kmalloc(mem_size, sk->sk_allocation); 1465 if (!mem) 1466 return -ENOMEM; 1467 1468 /* Segment the allocated memory */ 1469 aead_req = (struct aead_request *)mem; 1470 sgin = (struct scatterlist *)(mem + aead_size); 1471 sgout = sgin + n_sgin; 1472 aad = (u8 *)(sgout + n_sgout); 1473 iv = aad + prot->aad_size; 1474 tail = iv + MAX_IV_SIZE; 1475 1476 /* For CCM based ciphers, first byte of nonce+iv is a constant */ 1477 switch (prot->cipher_type) { 1478 case TLS_CIPHER_AES_CCM_128: 1479 iv[0] = TLS_AES_CCM_IV_B0_BYTE; 1480 iv_offset = 1; 1481 break; 1482 case TLS_CIPHER_SM4_CCM: 1483 iv[0] = TLS_SM4_CCM_IV_B0_BYTE; 1484 iv_offset = 1; 1485 break; 1486 } 1487 1488 /* Prepare IV */ 1489 if (prot->version == TLS_1_3_VERSION || 1490 prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) { 1491 memcpy(iv + iv_offset, tls_ctx->rx.iv, 1492 prot->iv_size + prot->salt_size); 1493 } else { 1494 err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, 1495 iv + iv_offset + prot->salt_size, 1496 prot->iv_size); 1497 if (err < 0) { 1498 kfree(mem); 1499 return err; 1500 } 1501 memcpy(iv + iv_offset, tls_ctx->rx.iv, prot->salt_size); 1502 } 1503 xor_iv_with_seq(prot, iv + iv_offset, tls_ctx->rx.rec_seq); 1504 1505 /* Prepare AAD */ 1506 tls_make_aad(aad, rxm->full_len - prot->overhead_size + 1507 prot->tail_size, 1508 tls_ctx->rx.rec_seq, tlm->control, prot); 1509 1510 /* Prepare sgin */ 1511 sg_init_table(sgin, n_sgin); 1512 sg_set_buf(&sgin[0], aad, prot->aad_size); 1513 err = skb_to_sgvec(skb, &sgin[1], 1514 rxm->offset + prot->prepend_size, 1515 rxm->full_len - prot->prepend_size); 1516 if (err < 0) { 1517 kfree(mem); 1518 return err; 1519 } 1520 1521 if (n_sgout) { 1522 if (out_iov) { 1523 sg_init_table(sgout, n_sgout); 1524 sg_set_buf(&sgout[0], aad, prot->aad_size); 1525 1526 err = tls_setup_from_iter(out_iov, data_len, 1527 &pages, &sgout[1], 1528 (n_sgout - 1 - tail_pages)); 1529 if (err < 0) 1530 goto fallback_to_reg_recv; 1531 1532 if (prot->tail_size) { 1533 sg_unmark_end(&sgout[pages]); 1534 sg_set_buf(&sgout[pages + 1], tail, 1535 prot->tail_size); 1536 sg_mark_end(&sgout[pages + 1]); 1537 } 1538 } else if (out_sg) { 1539 memcpy(sgout, out_sg, n_sgout * sizeof(*sgout)); 1540 } else { 1541 goto fallback_to_reg_recv; 1542 } 1543 } else { 1544 fallback_to_reg_recv: 1545 sgout = sgin; 1546 pages = 0; 1547 darg->zc = false; 1548 } 1549 1550 /* Prepare and submit AEAD request */ 1551 err = tls_do_decryption(sk, skb, sgin, sgout, iv, 1552 data_len + prot->tail_size, aead_req, darg); 1553 if (darg->async) 1554 return 0; 1555 1556 if (prot->tail_size) 1557 darg->tail = *tail; 1558 1559 /* Release the pages in case iov was mapped to pages */ 1560 for (; pages > 0; pages--) 1561 put_page(sg_page(&sgout[pages])); 1562 1563 kfree(mem); 1564 return err; 1565 } 1566 1567 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, 1568 struct iov_iter *dest, 1569 struct tls_decrypt_arg *darg) 1570 { 1571 struct tls_context *tls_ctx = tls_get_ctx(sk); 1572 struct tls_prot_info *prot = &tls_ctx->prot_info; 1573 struct strp_msg *rxm = strp_msg(skb); 1574 struct tls_msg *tlm = tls_msg(skb); 1575 int pad, err; 1576 1577 if (tlm->decrypted) { 1578 darg->zc = false; 1579 darg->async = false; 1580 return 0; 1581 } 1582 1583 if (tls_ctx->rx_conf == TLS_HW) { 1584 err = tls_device_decrypted(sk, tls_ctx, skb, rxm); 1585 if (err < 0) 1586 return err; 1587 if (err > 0) { 1588 tlm->decrypted = 1; 1589 darg->zc = false; 1590 darg->async = false; 1591 goto decrypt_done; 1592 } 1593 } 1594 1595 err = 
decrypt_internal(sk, skb, dest, NULL, darg);
	if (err < 0)
		return err;
	if (darg->async)
		goto decrypt_next;
	/* If opportunistic TLS 1.3 ZC failed retry without ZC */
	if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION &&
		     darg->tail != TLS_RECORD_TYPE_DATA)) {
		darg->zc = false;
		TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY);
		return decrypt_skb_update(sk, skb, dest, darg);
	}

decrypt_done:
	pad = tls_padding_length(prot, skb, darg);
	if (pad < 0)
		return pad;

	rxm->full_len -= pad;
	rxm->offset += prot->prepend_size;
	rxm->full_len -= prot->overhead_size;
	tlm->decrypted = 1;
decrypt_next:
	tls_advance_record_sn(sk, prot, &tls_ctx->rx);

	return 0;
}

int decrypt_skb(struct sock *sk, struct sk_buff *skb,
		struct scatterlist *sgout)
{
	struct tls_decrypt_arg darg = { .zc = true, };

	return decrypt_internal(sk, skb, NULL, sgout, &darg);
}

static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm,
				   u8 *control)
{
	int err;

	if (!*control) {
		*control = tlm->control;
		if (!*control)
			return -EBADMSG;

		err = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
			       sizeof(*control), control);
		if (*control != TLS_RECORD_TYPE_DATA) {
			if (err || msg->msg_flags & MSG_CTRUNC)
				return -EIO;
		}
	} else if (*control != tlm->control) {
		return 0;
	}

	return 1;
}

/* This function traverses the rx_list in the TLS receive context and copies
 * the decrypted records into the buffer provided by the caller when zero-copy
 * is not true. Further, records are removed from the rx_list if this is not a
 * peek case and the record has been consumed completely.
 */
static int process_rx_list(struct tls_sw_context_rx *ctx,
			   struct msghdr *msg,
			   u8 *control,
			   size_t skip,
			   size_t len,
			   bool zc,
			   bool is_peek)
{
	struct sk_buff *skb = skb_peek(&ctx->rx_list);
	struct tls_msg *tlm;
	ssize_t copied = 0;
	int err;

	while (skip && skb) {
		struct strp_msg *rxm = strp_msg(skb);
		tlm = tls_msg(skb);

		err = tls_record_content_type(msg, tlm, control);
		if (err <= 0)
			goto out;

		if (skip < rxm->full_len)
			break;

		skip = skip - rxm->full_len;
		skb = skb_peek_next(skb, &ctx->rx_list);
	}

	while (len && skb) {
		struct sk_buff *next_skb;
		struct strp_msg *rxm = strp_msg(skb);
		int chunk = min_t(unsigned int, rxm->full_len - skip, len);

		tlm = tls_msg(skb);

		err = tls_record_content_type(msg, tlm, control);
		if (err <= 0)
			goto out;

		if (!zc || (rxm->full_len - skip) > len) {
			err = skb_copy_datagram_msg(skb, rxm->offset + skip,
						    msg, chunk);
			if (err < 0)
				goto out;
		}

		len = len - chunk;
		copied = copied + chunk;

		/* Consume the data from the record if this is not a peek case */
		if (!is_peek) {
			rxm->offset = rxm->offset + chunk;
			rxm->full_len = rxm->full_len - chunk;

			/* Return if there is unconsumed data in the record */
			if (rxm->full_len - skip)
				break;
		}

		/* The remaining skip-bytes must lie in the first record in rx_list.
		 * So from the 2nd record onwards, 'skip' should be 0.
1720 */ 1721 skip = 0; 1722 1723 if (msg) 1724 msg->msg_flags |= MSG_EOR; 1725 1726 next_skb = skb_peek_next(skb, &ctx->rx_list); 1727 1728 if (!is_peek) { 1729 __skb_unlink(skb, &ctx->rx_list); 1730 consume_skb(skb); 1731 } 1732 1733 skb = next_skb; 1734 } 1735 err = 0; 1736 1737 out: 1738 return copied ? : err; 1739 } 1740 1741 static void 1742 tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot, 1743 size_t len_left, size_t decrypted, ssize_t done, 1744 size_t *flushed_at) 1745 { 1746 size_t max_rec; 1747 1748 if (len_left <= decrypted) 1749 return; 1750 1751 max_rec = prot->overhead_size - prot->tail_size + TLS_MAX_PAYLOAD_SIZE; 1752 if (done - *flushed_at < SZ_128K && tcp_inq(sk) > max_rec) 1753 return; 1754 1755 *flushed_at = done; 1756 sk_flush_backlog(sk); 1757 } 1758 1759 int tls_sw_recvmsg(struct sock *sk, 1760 struct msghdr *msg, 1761 size_t len, 1762 int flags, 1763 int *addr_len) 1764 { 1765 struct tls_context *tls_ctx = tls_get_ctx(sk); 1766 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1767 struct tls_prot_info *prot = &tls_ctx->prot_info; 1768 struct sk_psock *psock; 1769 unsigned char control = 0; 1770 ssize_t decrypted = 0; 1771 size_t flushed_at = 0; 1772 struct strp_msg *rxm; 1773 struct tls_msg *tlm; 1774 struct sk_buff *skb; 1775 ssize_t copied = 0; 1776 bool async = false; 1777 int target, err = 0; 1778 long timeo; 1779 bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); 1780 bool is_peek = flags & MSG_PEEK; 1781 bool bpf_strp_enabled; 1782 bool zc_capable; 1783 1784 if (unlikely(flags & MSG_ERRQUEUE)) 1785 return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR); 1786 1787 psock = sk_psock_get(sk); 1788 lock_sock(sk); 1789 bpf_strp_enabled = sk_psock_strp_enabled(psock); 1790 1791 /* If crypto failed the connection is broken */ 1792 err = ctx->async_wait.err; 1793 if (err) 1794 goto end; 1795 1796 /* Process pending decrypted records. It must be non-zero-copy */ 1797 err = process_rx_list(ctx, msg, &control, 0, len, false, is_peek); 1798 if (err < 0) 1799 goto end; 1800 1801 copied = err; 1802 if (len <= copied) 1803 goto end; 1804 1805 target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); 1806 len = len - copied; 1807 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); 1808 1809 zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek && 1810 ctx->zc_capable; 1811 decrypted = 0; 1812 while (len && (decrypted + copied < target || ctx->recv_pkt)) { 1813 struct tls_decrypt_arg darg = {}; 1814 int to_decrypt, chunk; 1815 1816 skb = tls_wait_data(sk, psock, flags & MSG_DONTWAIT, timeo, &err); 1817 if (!skb) { 1818 if (psock) { 1819 chunk = sk_msg_recvmsg(sk, psock, msg, len, 1820 flags); 1821 if (chunk > 0) 1822 goto leave_on_list; 1823 } 1824 goto recv_end; 1825 } 1826 1827 rxm = strp_msg(skb); 1828 tlm = tls_msg(skb); 1829 1830 to_decrypt = rxm->full_len - prot->overhead_size; 1831 1832 if (zc_capable && to_decrypt <= len && 1833 tlm->control == TLS_RECORD_TYPE_DATA) 1834 darg.zc = true; 1835 1836 /* Do not use async mode if record is non-data */ 1837 if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled) 1838 darg.async = ctx->async_capable; 1839 else 1840 darg.async = false; 1841 1842 err = decrypt_skb_update(sk, skb, &msg->msg_iter, &darg); 1843 if (err < 0) { 1844 tls_err_abort(sk, -EBADMSG); 1845 goto recv_end; 1846 } 1847 1848 async |= darg.async; 1849 1850 /* If the type of records being processed is not known yet, 1851 * set it to record type just dequeued. 
If it is already known, 1852 * but does not match the record type just dequeued, go to end. 1853 * We always get record type here since for tls1.2, record type 1854 * is known just after record is dequeued from stream parser. 1855 * For tls1.3, we disable async. 1856 */ 1857 err = tls_record_content_type(msg, tlm, &control); 1858 if (err <= 0) 1859 goto recv_end; 1860 1861 /* periodically flush backlog, and feed strparser */ 1862 tls_read_flush_backlog(sk, prot, len, to_decrypt, 1863 decrypted + copied, &flushed_at); 1864 1865 ctx->recv_pkt = NULL; 1866 __strp_unpause(&ctx->strp); 1867 __skb_queue_tail(&ctx->rx_list, skb); 1868 1869 if (async) { 1870 /* TLS 1.2-only, to_decrypt must be text length */ 1871 chunk = min_t(int, to_decrypt, len); 1872 leave_on_list: 1873 decrypted += chunk; 1874 len -= chunk; 1875 continue; 1876 } 1877 /* TLS 1.3 may have updated the length by more than overhead */ 1878 chunk = rxm->full_len; 1879 1880 if (!darg.zc) { 1881 bool partially_consumed = chunk > len; 1882 1883 if (bpf_strp_enabled) { 1884 /* BPF may try to queue the skb */ 1885 __skb_unlink(skb, &ctx->rx_list); 1886 err = sk_psock_tls_strp_read(psock, skb); 1887 if (err != __SK_PASS) { 1888 rxm->offset = rxm->offset + rxm->full_len; 1889 rxm->full_len = 0; 1890 if (err == __SK_DROP) 1891 consume_skb(skb); 1892 continue; 1893 } 1894 __skb_queue_tail(&ctx->rx_list, skb); 1895 } 1896 1897 if (partially_consumed) 1898 chunk = len; 1899 1900 err = skb_copy_datagram_msg(skb, rxm->offset, 1901 msg, chunk); 1902 if (err < 0) 1903 goto recv_end; 1904 1905 if (is_peek) 1906 goto leave_on_list; 1907 1908 if (partially_consumed) { 1909 rxm->offset += chunk; 1910 rxm->full_len -= chunk; 1911 goto leave_on_list; 1912 } 1913 } 1914 1915 decrypted += chunk; 1916 len -= chunk; 1917 1918 __skb_unlink(skb, &ctx->rx_list); 1919 consume_skb(skb); 1920 1921 /* Return full control message to userspace before trying 1922 * to parse another message type 1923 */ 1924 msg->msg_flags |= MSG_EOR; 1925 if (control != TLS_RECORD_TYPE_DATA) 1926 break; 1927 } 1928 1929 recv_end: 1930 if (async) { 1931 int ret, pending; 1932 1933 /* Wait for all previously submitted records to be decrypted */ 1934 spin_lock_bh(&ctx->decrypt_compl_lock); 1935 reinit_completion(&ctx->async_wait.completion); 1936 pending = atomic_read(&ctx->decrypt_pending); 1937 spin_unlock_bh(&ctx->decrypt_compl_lock); 1938 if (pending) { 1939 ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait); 1940 if (ret) { 1941 if (err >= 0 || err == -EINPROGRESS) 1942 err = ret; 1943 decrypted = 0; 1944 goto end; 1945 } 1946 } 1947 1948 /* Drain records from the rx_list & copy if required */ 1949 if (is_peek || is_kvec) 1950 err = process_rx_list(ctx, msg, &control, copied, 1951 decrypted, false, is_peek); 1952 else 1953 err = process_rx_list(ctx, msg, &control, 0, 1954 decrypted, true, is_peek); 1955 decrypted = max(err, 0); 1956 } 1957 1958 copied += decrypted; 1959 1960 end: 1961 release_sock(sk); 1962 if (psock) 1963 sk_psock_put(sk, psock); 1964 return copied ? 
: err; 1965 } 1966 1967 ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, 1968 struct pipe_inode_info *pipe, 1969 size_t len, unsigned int flags) 1970 { 1971 struct tls_context *tls_ctx = tls_get_ctx(sock->sk); 1972 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1973 struct strp_msg *rxm = NULL; 1974 struct sock *sk = sock->sk; 1975 struct tls_msg *tlm; 1976 struct sk_buff *skb; 1977 ssize_t copied = 0; 1978 bool from_queue; 1979 int err = 0; 1980 long timeo; 1981 int chunk; 1982 1983 lock_sock(sk); 1984 1985 timeo = sock_rcvtimeo(sk, flags & SPLICE_F_NONBLOCK); 1986 1987 from_queue = !skb_queue_empty(&ctx->rx_list); 1988 if (from_queue) { 1989 skb = __skb_dequeue(&ctx->rx_list); 1990 } else { 1991 struct tls_decrypt_arg darg = {}; 1992 1993 skb = tls_wait_data(sk, NULL, flags & SPLICE_F_NONBLOCK, timeo, 1994 &err); 1995 if (!skb) 1996 goto splice_read_end; 1997 1998 err = decrypt_skb_update(sk, skb, NULL, &darg); 1999 if (err < 0) { 2000 tls_err_abort(sk, -EBADMSG); 2001 goto splice_read_end; 2002 } 2003 } 2004 2005 rxm = strp_msg(skb); 2006 tlm = tls_msg(skb); 2007 2008 /* splice does not support reading control messages */ 2009 if (tlm->control != TLS_RECORD_TYPE_DATA) { 2010 err = -EINVAL; 2011 goto splice_read_end; 2012 } 2013 2014 chunk = min_t(unsigned int, rxm->full_len, len); 2015 copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags); 2016 if (copied < 0) 2017 goto splice_read_end; 2018 2019 if (!from_queue) { 2020 ctx->recv_pkt = NULL; 2021 __strp_unpause(&ctx->strp); 2022 } 2023 if (chunk < rxm->full_len) { 2024 __skb_queue_head(&ctx->rx_list, skb); 2025 rxm->offset += len; 2026 rxm->full_len -= len; 2027 } else { 2028 consume_skb(skb); 2029 } 2030 2031 splice_read_end: 2032 release_sock(sk); 2033 return copied ? : err; 2034 } 2035 2036 bool tls_sw_sock_is_readable(struct sock *sk) 2037 { 2038 struct tls_context *tls_ctx = tls_get_ctx(sk); 2039 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 2040 bool ingress_empty = true; 2041 struct sk_psock *psock; 2042 2043 rcu_read_lock(); 2044 psock = sk_psock(sk); 2045 if (psock) 2046 ingress_empty = list_empty(&psock->ingress_msg); 2047 rcu_read_unlock(); 2048 2049 return !ingress_empty || ctx->recv_pkt || 2050 !skb_queue_empty(&ctx->rx_list); 2051 } 2052 2053 static int tls_read_size(struct strparser *strp, struct sk_buff *skb) 2054 { 2055 struct tls_context *tls_ctx = tls_get_ctx(strp->sk); 2056 struct tls_prot_info *prot = &tls_ctx->prot_info; 2057 char header[TLS_HEADER_SIZE + MAX_IV_SIZE]; 2058 struct strp_msg *rxm = strp_msg(skb); 2059 struct tls_msg *tlm = tls_msg(skb); 2060 size_t cipher_overhead; 2061 size_t data_len = 0; 2062 int ret; 2063 2064 /* Verify that we have a full TLS header, or wait for more data */ 2065 if (rxm->offset + prot->prepend_size > skb->len) 2066 return 0; 2067 2068 /* Sanity-check size of on-stack buffer. 
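 * header[] must hold the full record prepend: the 5-byte TLS header plus,
 * for TLS 1.2 AES-GCM/CCM, the explicit IV.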

static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
{
	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
	struct strp_msg *rxm = strp_msg(skb);
	struct tls_msg *tlm = tls_msg(skb);
	size_t cipher_overhead;
	size_t data_len = 0;
	int ret;

	/* Verify that we have a full TLS header, or wait for more data */
	if (rxm->offset + prot->prepend_size > skb->len)
		return 0;

	/* Sanity-check size of on-stack buffer. */
	if (WARN_ON(prot->prepend_size > sizeof(header))) {
		ret = -EINVAL;
		goto read_failure;
	}

	/* Linearize header to local buffer */
	ret = skb_copy_bits(skb, rxm->offset, header, prot->prepend_size);
	if (ret < 0)
		goto read_failure;

	tlm->decrypted = 0;
	tlm->control = header[0];

	/* 16-bit record length field: header[3] is the high byte,
	 * header[4] the low byte (network order).
	 */
	data_len = ((header[4] & 0xFF) | (header[3] << 8));

	cipher_overhead = prot->tag_size;
	if (prot->version != TLS_1_3_VERSION &&
	    prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305)
		cipher_overhead += prot->iv_size;

	if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead +
	    prot->tail_size) {
		ret = -EMSGSIZE;
		goto read_failure;
	}
	if (data_len < cipher_overhead) {
		ret = -EBADMSG;
		goto read_failure;
	}

	/* Note that both TLS1.3 and TLS1.2 use TLS_1_2 version here */
	if (header[1] != TLS_1_2_VERSION_MINOR ||
	    header[2] != TLS_1_2_VERSION_MAJOR) {
		ret = -EINVAL;
		goto read_failure;
	}

	tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE,
				     TCP_SKB_CB(skb)->seq + rxm->offset);
	return data_len + TLS_HEADER_SIZE;

read_failure:
	tls_err_abort(strp->sk, ret);

	return ret;
}

static void tls_queue(struct strparser *strp, struct sk_buff *skb)
{
	struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

	ctx->recv_pkt = skb;
	strp_pause(strp);

	ctx->saved_data_ready(strp->sk);
}

static void tls_data_ready(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct sk_psock *psock;

	strp_data_ready(&ctx->strp);

	psock = sk_psock_get(sk);
	if (psock) {
		if (!list_empty(&psock->ingress_msg))
			ctx->saved_data_ready(sk);
		sk_psock_put(sk, psock);
	}
}
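
/* Illustrative note (not part of the original source): receive-side flow as
 * wired up by tls_sw_strparser_arm() further down. Incoming TCP data wakes
 * tls_data_ready(), which feeds the stream parser; strparser calls
 * tls_read_size() to find the record boundary and then tls_queue(), which
 * parks the record in ctx->recv_pkt, pauses the parser and invokes the saved
 * sk_data_ready so that tls_sw_recvmsg()/tls_sw_splice_read() can decrypt
 * and deliver it.
 */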

void tls_sw_cancel_work_tx(struct tls_context *tls_ctx)
{
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);

	set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask);
	set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask);
	cancel_delayed_work_sync(&ctx->tx_work.work);
}

void tls_sw_release_resources_tx(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
	struct tls_rec *rec, *tmp;
	int pending;

	/* Wait for any pending async encryptions to complete */
	spin_lock_bh(&ctx->encrypt_compl_lock);
	ctx->async_notify = true;
	pending = atomic_read(&ctx->encrypt_pending);
	spin_unlock_bh(&ctx->encrypt_compl_lock);

	if (pending)
		crypto_wait_req(-EINPROGRESS, &ctx->async_wait);

	tls_tx_records(sk, -1);

	/* Free up un-sent records in tx_list. First, free
	 * the partially sent record if any at head of tx_list.
	 */
	if (tls_ctx->partially_sent_record) {
		tls_free_partial_record(sk, tls_ctx);
		rec = list_first_entry(&ctx->tx_list,
				       struct tls_rec, list);
		list_del(&rec->list);
		sk_msg_free(sk, &rec->msg_plaintext);
		kfree(rec);
	}

	list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) {
		list_del(&rec->list);
		sk_msg_free(sk, &rec->msg_encrypted);
		sk_msg_free(sk, &rec->msg_plaintext);
		kfree(rec);
	}

	crypto_free_aead(ctx->aead_send);
	tls_free_open_rec(sk);
}

void tls_sw_free_ctx_tx(struct tls_context *tls_ctx)
{
	struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);

	kfree(ctx);
}

void tls_sw_release_resources_rx(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

	kfree(tls_ctx->rx.rec_seq);
	kfree(tls_ctx->rx.iv);

	if (ctx->aead_recv) {
		kfree_skb(ctx->recv_pkt);
		ctx->recv_pkt = NULL;
		__skb_queue_purge(&ctx->rx_list);
		crypto_free_aead(ctx->aead_recv);
		strp_stop(&ctx->strp);
		/* If tls_sw_strparser_arm() was not called (cleanup paths)
		 * we still want to strp_stop(), but sk->sk_data_ready was
		 * never swapped.
		 */
		if (ctx->saved_data_ready) {
			write_lock_bh(&sk->sk_callback_lock);
			sk->sk_data_ready = ctx->saved_data_ready;
			write_unlock_bh(&sk->sk_callback_lock);
		}
	}
}

void tls_sw_strparser_done(struct tls_context *tls_ctx)
{
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

	strp_done(&ctx->strp);
}

void tls_sw_free_ctx_rx(struct tls_context *tls_ctx)
{
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);

	kfree(ctx);
}

void tls_sw_free_resources_rx(struct sock *sk)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);

	tls_sw_release_resources_rx(sk);
	tls_sw_free_ctx_rx(tls_ctx);
}

/* The work handler to transmit the encrypted records in tx_list */
static void tx_work_handler(struct work_struct *work)
{
	struct delayed_work *delayed_work = to_delayed_work(work);
	struct tx_work *tx_work = container_of(delayed_work,
					       struct tx_work, work);
	struct sock *sk = tx_work->sk;
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_tx *ctx;

	if (unlikely(!tls_ctx))
		return;

	ctx = tls_sw_ctx_tx(tls_ctx);
	if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask))
		return;

	if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask))
		return;
	mutex_lock(&tls_ctx->tx_lock);
	lock_sock(sk);
	tls_tx_records(sk, -1);
	release_sock(sk);
	mutex_unlock(&tls_ctx->tx_lock);
}

void tls_sw_write_space(struct sock *sk, struct tls_context *ctx)
{
	struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx);

	/* Schedule the transmission if tx list is ready */
	if (is_tx_ready(tx_ctx) &&
	    !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask))
		schedule_delayed_work(&tx_ctx->tx_work.work, 0);
}
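
/* Illustrative note (not part of the original source): the transmit side is
 * driven by the BIT_TX_SCHEDULED bit. tls_sw_write_space() above schedules
 * tx_work_handler() at most once at a time; the handler takes tx_lock and
 * the socket lock and calls tls_tx_records(sk, -1) to push out every record
 * whose (possibly asynchronous) encryption has completed.
 */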

void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx)
{
	struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);

	write_lock_bh(&sk->sk_callback_lock);
	rx_ctx->saved_data_ready = sk->sk_data_ready;
	sk->sk_data_ready = tls_data_ready;
	write_unlock_bh(&sk->sk_callback_lock);

	strp_check_rcv(&rx_ctx->strp);
}

void tls_update_rx_zc_capable(struct tls_context *tls_ctx)
{
	struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx);

	rx_ctx->zc_capable = tls_ctx->rx_no_pad ||
		tls_ctx->prot_info.version != TLS_1_3_VERSION;
}

int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_prot_info *prot = &tls_ctx->prot_info;
	struct tls_crypto_info *crypto_info;
	struct tls_sw_context_tx *sw_ctx_tx = NULL;
	struct tls_sw_context_rx *sw_ctx_rx = NULL;
	struct cipher_context *cctx;
	struct crypto_aead **aead;
	struct strp_callbacks cb;
	u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size;
	struct crypto_tfm *tfm;
	char *iv, *rec_seq, *key, *salt, *cipher_name;
	size_t keysize;
	int rc = 0;

	if (!ctx) {
		rc = -EINVAL;
		goto out;
	}

	if (tx) {
		if (!ctx->priv_ctx_tx) {
			sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
			if (!sw_ctx_tx) {
				rc = -ENOMEM;
				goto out;
			}
			ctx->priv_ctx_tx = sw_ctx_tx;
		} else {
			sw_ctx_tx =
				(struct tls_sw_context_tx *)ctx->priv_ctx_tx;
		}
	} else {
		if (!ctx->priv_ctx_rx) {
			sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
			if (!sw_ctx_rx) {
				rc = -ENOMEM;
				goto out;
			}
			ctx->priv_ctx_rx = sw_ctx_rx;
		} else {
			sw_ctx_rx =
				(struct tls_sw_context_rx *)ctx->priv_ctx_rx;
		}
	}

	if (tx) {
		crypto_init_wait(&sw_ctx_tx->async_wait);
		spin_lock_init(&sw_ctx_tx->encrypt_compl_lock);
		crypto_info = &ctx->crypto_send.info;
		cctx = &ctx->tx;
		aead = &sw_ctx_tx->aead_send;
		INIT_LIST_HEAD(&sw_ctx_tx->tx_list);
		INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler);
		sw_ctx_tx->tx_work.sk = sk;
	} else {
		crypto_init_wait(&sw_ctx_rx->async_wait);
		spin_lock_init(&sw_ctx_rx->decrypt_compl_lock);
		crypto_info = &ctx->crypto_recv.info;
		cctx = &ctx->rx;
		skb_queue_head_init(&sw_ctx_rx->rx_list);
		aead = &sw_ctx_rx->aead_recv;
	}

	switch (crypto_info->cipher_type) {
	case TLS_CIPHER_AES_GCM_128: {
		struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;

		gcm_128_info = (void *)crypto_info;
		nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
		tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
		iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
		iv = gcm_128_info->iv;
		rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
		rec_seq = gcm_128_info->rec_seq;
		keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE;
		key = gcm_128_info->key;
		salt = gcm_128_info->salt;
		salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE;
		cipher_name = "gcm(aes)";
		break;
	}
	case TLS_CIPHER_AES_GCM_256: {
		struct tls12_crypto_info_aes_gcm_256 *gcm_256_info;

		gcm_256_info = (void *)crypto_info;
		nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
		tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE;
		iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE;
		iv = gcm_256_info->iv;
		rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE;
		rec_seq = gcm_256_info->rec_seq;
		keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE;
		key = gcm_256_info->key;
		salt = gcm_256_info->salt;
		salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE;
		cipher_name = "gcm(aes)";
		break;
	}
	case TLS_CIPHER_AES_CCM_128: {
		struct tls12_crypto_info_aes_ccm_128 *ccm_128_info;

		ccm_128_info = (void *)crypto_info;
		nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
		tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE;
		iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE;
		iv = ccm_128_info->iv;
		rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE;
		rec_seq = ccm_128_info->rec_seq;
		keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE;
		key = ccm_128_info->key;
		salt = ccm_128_info->salt;
		salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE;
		cipher_name = "ccm(aes)";
		break;
	}
	case TLS_CIPHER_CHACHA20_POLY1305: {
		struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305_info;

		chacha20_poly1305_info = (void *)crypto_info;
		nonce_size = 0;
		tag_size = TLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE;
		iv_size = TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE;
		iv = chacha20_poly1305_info->iv;
		rec_seq_size = TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE;
		rec_seq = chacha20_poly1305_info->rec_seq;
		keysize = TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE;
		key = chacha20_poly1305_info->key;
		salt = chacha20_poly1305_info->salt;
		salt_size = TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE;
		cipher_name = "rfc7539(chacha20,poly1305)";
		break;
	}
	case TLS_CIPHER_SM4_GCM: {
		struct tls12_crypto_info_sm4_gcm *sm4_gcm_info;

		sm4_gcm_info = (void *)crypto_info;
		nonce_size = TLS_CIPHER_SM4_GCM_IV_SIZE;
		tag_size = TLS_CIPHER_SM4_GCM_TAG_SIZE;
		iv_size = TLS_CIPHER_SM4_GCM_IV_SIZE;
		iv = sm4_gcm_info->iv;
		rec_seq_size = TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE;
		rec_seq = sm4_gcm_info->rec_seq;
		keysize = TLS_CIPHER_SM4_GCM_KEY_SIZE;
		key = sm4_gcm_info->key;
		salt = sm4_gcm_info->salt;
		salt_size = TLS_CIPHER_SM4_GCM_SALT_SIZE;
		cipher_name = "gcm(sm4)";
		break;
	}
	case TLS_CIPHER_SM4_CCM: {
		struct tls12_crypto_info_sm4_ccm *sm4_ccm_info;

		sm4_ccm_info = (void *)crypto_info;
		nonce_size = TLS_CIPHER_SM4_CCM_IV_SIZE;
		tag_size = TLS_CIPHER_SM4_CCM_TAG_SIZE;
		iv_size = TLS_CIPHER_SM4_CCM_IV_SIZE;
		iv = sm4_ccm_info->iv;
		rec_seq_size = TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE;
		rec_seq = sm4_ccm_info->rec_seq;
		keysize = TLS_CIPHER_SM4_CCM_KEY_SIZE;
		key = sm4_ccm_info->key;
		salt = sm4_ccm_info->salt;
		salt_size = TLS_CIPHER_SM4_CCM_SALT_SIZE;
		cipher_name = "ccm(sm4)";
		break;
	}
	default:
		rc = -EINVAL;
		goto free_priv;
	}
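
	/* Illustrative note (not part of the original source): for TLS 1.2 with
	 * AES-GCM-128 the sizes chosen above are a 4-byte salt plus an 8-byte
	 * per-record IV (a 12-byte AEAD nonce), a 16-byte tag and an 8-byte
	 * record sequence number, so prot->prepend_size below becomes
	 * 5 + 8 = 13 and prot->overhead_size 13 + 16 = 29. Under TLS 1.3 the
	 * explicit nonce is dropped (nonce_size forced to 0) and one tail byte
	 * carries the inner content type, giving 5 + 16 + 1 = 22 bytes of
	 * overhead.
	 */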

	/* Sanity-check the sizes for stack allocations. */
	if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE ||
	    rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE) {
		rc = -EINVAL;
		goto free_priv;
	}

	if (crypto_info->version == TLS_1_3_VERSION) {
		nonce_size = 0;
		prot->aad_size = TLS_HEADER_SIZE;
		prot->tail_size = 1;
	} else {
		prot->aad_size = TLS_AAD_SPACE_SIZE;
		prot->tail_size = 0;
	}

	prot->version = crypto_info->version;
	prot->cipher_type = crypto_info->cipher_type;
	prot->prepend_size = TLS_HEADER_SIZE + nonce_size;
	prot->tag_size = tag_size;
	prot->overhead_size = prot->prepend_size +
			      prot->tag_size + prot->tail_size;
	prot->iv_size = iv_size;
	prot->salt_size = salt_size;
	cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL);
	if (!cctx->iv) {
		rc = -ENOMEM;
		goto free_priv;
	}
	/* Note: 128 & 256 bit salt are the same size */
	prot->rec_seq_size = rec_seq_size;
	memcpy(cctx->iv, salt, salt_size);
	memcpy(cctx->iv + salt_size, iv, iv_size);
	cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
	if (!cctx->rec_seq) {
		rc = -ENOMEM;
		goto free_iv;
	}

	if (!*aead) {
		*aead = crypto_alloc_aead(cipher_name, 0, 0);
		if (IS_ERR(*aead)) {
			rc = PTR_ERR(*aead);
			*aead = NULL;
			goto free_rec_seq;
		}
	}

	ctx->push_pending_record = tls_sw_push_pending_record;

	rc = crypto_aead_setkey(*aead, key, keysize);

	if (rc)
		goto free_aead;

	rc = crypto_aead_setauthsize(*aead, prot->tag_size);
	if (rc)
		goto free_aead;

	if (sw_ctx_rx) {
		tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv);

		tls_update_rx_zc_capable(ctx);
		sw_ctx_rx->async_capable =
			crypto_info->version != TLS_1_3_VERSION &&
			!!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC);

		/* Set up strparser */
		memset(&cb, 0, sizeof(cb));
		cb.rcv_msg = tls_queue;
		cb.parse_msg = tls_read_size;

		strp_init(&sw_ctx_rx->strp, sk, &cb);
	}

	goto out;

free_aead:
	crypto_free_aead(*aead);
	*aead = NULL;
free_rec_seq:
	kfree(cctx->rec_seq);
	cctx->rec_seq = NULL;
free_iv:
	kfree(cctx->iv);
	cctx->iv = NULL;
free_priv:
	if (tx) {
		kfree(ctx->priv_ctx_tx);
		ctx->priv_ctx_tx = NULL;
	} else {
		kfree(ctx->priv_ctx_rx);
		ctx->priv_ctx_rx = NULL;
	}
out:
	return rc;
}
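
/* Illustrative usage sketch (not part of the original source): how userspace
 * typically reaches tls_set_sw_offload(). After completing the TLS handshake
 * on a regular TCP socket, the application attaches the "tls" ULP and hands
 * the negotiated secrets to the kernel; the field values below are
 * placeholders, and in practice each direction gets its own keys.
 *
 *	struct tls12_crypto_info_aes_gcm_128 ci = {
 *		.info.version = TLS_1_2_VERSION,
 *		.info.cipher_type = TLS_CIPHER_AES_GCM_128,
 *	};
 *
 *	// copy key, iv, salt and rec_seq from the handshake into ci here
 *	setsockopt(fd, SOL_TCP, TCP_ULP, "tls", sizeof("tls"));
 *	setsockopt(fd, SOL_TLS, TLS_TX, &ci, sizeof(ci));
 *	setsockopt(fd, SOL_TLS, TLS_RX, &ci, sizeof(ci));
 *
 * From then on plain send()/recv() on fd move TLS records through the
 * routines in this file.
 */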