/*
 * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
 * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
 * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
 * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
 * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
 * Copyright (c) 2018, Covalent IO, Inc. http://covalent.io
 *
 * This software is available to you under a choice of one of two
 * licenses. You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
36 */ 37 38 #include <linux/bug.h> 39 #include <linux/sched/signal.h> 40 #include <linux/module.h> 41 #include <linux/splice.h> 42 #include <crypto/aead.h> 43 44 #include <net/strparser.h> 45 #include <net/tls.h> 46 #include <trace/events/sock.h> 47 48 #include "tls.h" 49 50 struct tls_decrypt_arg { 51 struct_group(inargs, 52 bool zc; 53 bool async; 54 u8 tail; 55 ); 56 57 struct sk_buff *skb; 58 }; 59 60 struct tls_decrypt_ctx { 61 u8 iv[MAX_IV_SIZE]; 62 u8 aad[TLS_MAX_AAD_SIZE]; 63 u8 tail; 64 struct scatterlist sg[]; 65 }; 66 67 noinline void tls_err_abort(struct sock *sk, int err) 68 { 69 WARN_ON_ONCE(err >= 0); 70 /* sk->sk_err should contain a positive error code. */ 71 sk->sk_err = -err; 72 sk_error_report(sk); 73 } 74 75 static int __skb_nsg(struct sk_buff *skb, int offset, int len, 76 unsigned int recursion_level) 77 { 78 int start = skb_headlen(skb); 79 int i, chunk = start - offset; 80 struct sk_buff *frag_iter; 81 int elt = 0; 82 83 if (unlikely(recursion_level >= 24)) 84 return -EMSGSIZE; 85 86 if (chunk > 0) { 87 if (chunk > len) 88 chunk = len; 89 elt++; 90 len -= chunk; 91 if (len == 0) 92 return elt; 93 offset += chunk; 94 } 95 96 for (i = 0; i < skb_shinfo(skb)->nr_frags; i++) { 97 int end; 98 99 WARN_ON(start > offset + len); 100 101 end = start + skb_frag_size(&skb_shinfo(skb)->frags[i]); 102 chunk = end - offset; 103 if (chunk > 0) { 104 if (chunk > len) 105 chunk = len; 106 elt++; 107 len -= chunk; 108 if (len == 0) 109 return elt; 110 offset += chunk; 111 } 112 start = end; 113 } 114 115 if (unlikely(skb_has_frag_list(skb))) { 116 skb_walk_frags(skb, frag_iter) { 117 int end, ret; 118 119 WARN_ON(start > offset + len); 120 121 end = start + frag_iter->len; 122 chunk = end - offset; 123 if (chunk > 0) { 124 if (chunk > len) 125 chunk = len; 126 ret = __skb_nsg(frag_iter, offset - start, chunk, 127 recursion_level + 1); 128 if (unlikely(ret < 0)) 129 return ret; 130 elt += ret; 131 len -= chunk; 132 if (len == 0) 133 return elt; 134 offset += 
chunk; 135 } 136 start = end; 137 } 138 } 139 BUG_ON(len); 140 return elt; 141 } 142 143 /* Return the number of scatterlist elements required to completely map the 144 * skb, or -EMSGSIZE if the recursion depth is exceeded. 145 */ 146 static int skb_nsg(struct sk_buff *skb, int offset, int len) 147 { 148 return __skb_nsg(skb, offset, len, 0); 149 } 150 151 static int tls_padding_length(struct tls_prot_info *prot, struct sk_buff *skb, 152 struct tls_decrypt_arg *darg) 153 { 154 struct strp_msg *rxm = strp_msg(skb); 155 struct tls_msg *tlm = tls_msg(skb); 156 int sub = 0; 157 158 /* Determine zero-padding length */ 159 if (prot->version == TLS_1_3_VERSION) { 160 int offset = rxm->full_len - TLS_TAG_SIZE - 1; 161 char content_type = darg->zc ? darg->tail : 0; 162 int err; 163 164 while (content_type == 0) { 165 if (offset < prot->prepend_size) 166 return -EBADMSG; 167 err = skb_copy_bits(skb, rxm->offset + offset, 168 &content_type, 1); 169 if (err) 170 return err; 171 if (content_type) 172 break; 173 sub++; 174 offset--; 175 } 176 tlm->control = content_type; 177 } 178 return sub; 179 } 180 181 static void tls_decrypt_done(struct crypto_async_request *req, int err) 182 { 183 struct aead_request *aead_req = (struct aead_request *)req; 184 struct scatterlist *sgout = aead_req->dst; 185 struct scatterlist *sgin = aead_req->src; 186 struct tls_sw_context_rx *ctx; 187 struct tls_context *tls_ctx; 188 struct scatterlist *sg; 189 unsigned int pages; 190 struct sock *sk; 191 192 sk = (struct sock *)req->data; 193 tls_ctx = tls_get_ctx(sk); 194 ctx = tls_sw_ctx_rx(tls_ctx); 195 196 /* Propagate if there was an err */ 197 if (err) { 198 if (err == -EBADMSG) 199 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR); 200 ctx->async_wait.err = err; 201 tls_err_abort(sk, err); 202 } 203 204 /* Free the destination pages if skb was not decrypted inplace */ 205 if (sgout != sgin) { 206 /* Skip the first S/G entry as it points to AAD */ 207 for_each_sg(sg_next(sgout), sg, UINT_MAX, 
pages) { 208 if (!sg) 209 break; 210 put_page(sg_page(sg)); 211 } 212 } 213 214 kfree(aead_req); 215 216 spin_lock_bh(&ctx->decrypt_compl_lock); 217 if (!atomic_dec_return(&ctx->decrypt_pending)) 218 complete(&ctx->async_wait.completion); 219 spin_unlock_bh(&ctx->decrypt_compl_lock); 220 } 221 222 static int tls_do_decryption(struct sock *sk, 223 struct scatterlist *sgin, 224 struct scatterlist *sgout, 225 char *iv_recv, 226 size_t data_len, 227 struct aead_request *aead_req, 228 struct tls_decrypt_arg *darg) 229 { 230 struct tls_context *tls_ctx = tls_get_ctx(sk); 231 struct tls_prot_info *prot = &tls_ctx->prot_info; 232 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 233 int ret; 234 235 aead_request_set_tfm(aead_req, ctx->aead_recv); 236 aead_request_set_ad(aead_req, prot->aad_size); 237 aead_request_set_crypt(aead_req, sgin, sgout, 238 data_len + prot->tag_size, 239 (u8 *)iv_recv); 240 241 if (darg->async) { 242 aead_request_set_callback(aead_req, 243 CRYPTO_TFM_REQ_MAY_BACKLOG, 244 tls_decrypt_done, sk); 245 atomic_inc(&ctx->decrypt_pending); 246 } else { 247 aead_request_set_callback(aead_req, 248 CRYPTO_TFM_REQ_MAY_BACKLOG, 249 crypto_req_done, &ctx->async_wait); 250 } 251 252 ret = crypto_aead_decrypt(aead_req); 253 if (ret == -EINPROGRESS) { 254 if (darg->async) 255 return 0; 256 257 ret = crypto_wait_req(ret, &ctx->async_wait); 258 } 259 darg->async = false; 260 261 return ret; 262 } 263 264 static void tls_trim_both_msgs(struct sock *sk, int target_size) 265 { 266 struct tls_context *tls_ctx = tls_get_ctx(sk); 267 struct tls_prot_info *prot = &tls_ctx->prot_info; 268 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 269 struct tls_rec *rec = ctx->open_rec; 270 271 sk_msg_trim(sk, &rec->msg_plaintext, target_size); 272 if (target_size > 0) 273 target_size += prot->overhead_size; 274 sk_msg_trim(sk, &rec->msg_encrypted, target_size); 275 } 276 277 static int tls_alloc_encrypted_msg(struct sock *sk, int len) 278 { 279 struct tls_context 
*tls_ctx = tls_get_ctx(sk); 280 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 281 struct tls_rec *rec = ctx->open_rec; 282 struct sk_msg *msg_en = &rec->msg_encrypted; 283 284 return sk_msg_alloc(sk, msg_en, len, 0); 285 } 286 287 static int tls_clone_plaintext_msg(struct sock *sk, int required) 288 { 289 struct tls_context *tls_ctx = tls_get_ctx(sk); 290 struct tls_prot_info *prot = &tls_ctx->prot_info; 291 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 292 struct tls_rec *rec = ctx->open_rec; 293 struct sk_msg *msg_pl = &rec->msg_plaintext; 294 struct sk_msg *msg_en = &rec->msg_encrypted; 295 int skip, len; 296 297 /* We add page references worth len bytes from encrypted sg 298 * at the end of plaintext sg. It is guaranteed that msg_en 299 * has enough required room (ensured by caller). 300 */ 301 len = required - msg_pl->sg.size; 302 303 /* Skip initial bytes in msg_en's data to be able to use 304 * same offset of both plain and encrypted data. 305 */ 306 skip = prot->prepend_size + msg_pl->sg.size; 307 308 return sk_msg_clone(sk, msg_pl, msg_en, skip, len); 309 } 310 311 static struct tls_rec *tls_get_rec(struct sock *sk) 312 { 313 struct tls_context *tls_ctx = tls_get_ctx(sk); 314 struct tls_prot_info *prot = &tls_ctx->prot_info; 315 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 316 struct sk_msg *msg_pl, *msg_en; 317 struct tls_rec *rec; 318 int mem_size; 319 320 mem_size = sizeof(struct tls_rec) + crypto_aead_reqsize(ctx->aead_send); 321 322 rec = kzalloc(mem_size, sk->sk_allocation); 323 if (!rec) 324 return NULL; 325 326 msg_pl = &rec->msg_plaintext; 327 msg_en = &rec->msg_encrypted; 328 329 sk_msg_init(msg_pl); 330 sk_msg_init(msg_en); 331 332 sg_init_table(rec->sg_aead_in, 2); 333 sg_set_buf(&rec->sg_aead_in[0], rec->aad_space, prot->aad_size); 334 sg_unmark_end(&rec->sg_aead_in[1]); 335 336 sg_init_table(rec->sg_aead_out, 2); 337 sg_set_buf(&rec->sg_aead_out[0], rec->aad_space, prot->aad_size); 338 
sg_unmark_end(&rec->sg_aead_out[1]); 339 340 return rec; 341 } 342 343 static void tls_free_rec(struct sock *sk, struct tls_rec *rec) 344 { 345 sk_msg_free(sk, &rec->msg_encrypted); 346 sk_msg_free(sk, &rec->msg_plaintext); 347 kfree(rec); 348 } 349 350 static void tls_free_open_rec(struct sock *sk) 351 { 352 struct tls_context *tls_ctx = tls_get_ctx(sk); 353 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 354 struct tls_rec *rec = ctx->open_rec; 355 356 if (rec) { 357 tls_free_rec(sk, rec); 358 ctx->open_rec = NULL; 359 } 360 } 361 362 int tls_tx_records(struct sock *sk, int flags) 363 { 364 struct tls_context *tls_ctx = tls_get_ctx(sk); 365 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 366 struct tls_rec *rec, *tmp; 367 struct sk_msg *msg_en; 368 int tx_flags, rc = 0; 369 370 if (tls_is_partially_sent_record(tls_ctx)) { 371 rec = list_first_entry(&ctx->tx_list, 372 struct tls_rec, list); 373 374 if (flags == -1) 375 tx_flags = rec->tx_flags; 376 else 377 tx_flags = flags; 378 379 rc = tls_push_partial_record(sk, tls_ctx, tx_flags); 380 if (rc) 381 goto tx_err; 382 383 /* Full record has been transmitted. 
384 * Remove the head of tx_list 385 */ 386 list_del(&rec->list); 387 sk_msg_free(sk, &rec->msg_plaintext); 388 kfree(rec); 389 } 390 391 /* Tx all ready records */ 392 list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) { 393 if (READ_ONCE(rec->tx_ready)) { 394 if (flags == -1) 395 tx_flags = rec->tx_flags; 396 else 397 tx_flags = flags; 398 399 msg_en = &rec->msg_encrypted; 400 rc = tls_push_sg(sk, tls_ctx, 401 &msg_en->sg.data[msg_en->sg.curr], 402 0, tx_flags); 403 if (rc) 404 goto tx_err; 405 406 list_del(&rec->list); 407 sk_msg_free(sk, &rec->msg_plaintext); 408 kfree(rec); 409 } else { 410 break; 411 } 412 } 413 414 tx_err: 415 if (rc < 0 && rc != -EAGAIN) 416 tls_err_abort(sk, -EBADMSG); 417 418 return rc; 419 } 420 421 static void tls_encrypt_done(struct crypto_async_request *req, int err) 422 { 423 struct aead_request *aead_req = (struct aead_request *)req; 424 struct sock *sk = req->data; 425 struct tls_context *tls_ctx = tls_get_ctx(sk); 426 struct tls_prot_info *prot = &tls_ctx->prot_info; 427 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 428 struct scatterlist *sge; 429 struct sk_msg *msg_en; 430 struct tls_rec *rec; 431 bool ready = false; 432 int pending; 433 434 rec = container_of(aead_req, struct tls_rec, aead_req); 435 msg_en = &rec->msg_encrypted; 436 437 sge = sk_msg_elem(msg_en, msg_en->sg.curr); 438 sge->offset -= prot->prepend_size; 439 sge->length += prot->prepend_size; 440 441 /* Check if error is previously set on socket */ 442 if (err || sk->sk_err) { 443 rec = NULL; 444 445 /* If err is already set on socket, return the same code */ 446 if (sk->sk_err) { 447 ctx->async_wait.err = -sk->sk_err; 448 } else { 449 ctx->async_wait.err = err; 450 tls_err_abort(sk, err); 451 } 452 } 453 454 if (rec) { 455 struct tls_rec *first_rec; 456 457 /* Mark the record as ready for transmission */ 458 smp_store_mb(rec->tx_ready, true); 459 460 /* If received record is at head of tx_list, schedule tx */ 461 first_rec = 
list_first_entry(&ctx->tx_list, 462 struct tls_rec, list); 463 if (rec == first_rec) 464 ready = true; 465 } 466 467 spin_lock_bh(&ctx->encrypt_compl_lock); 468 pending = atomic_dec_return(&ctx->encrypt_pending); 469 470 if (!pending && ctx->async_notify) 471 complete(&ctx->async_wait.completion); 472 spin_unlock_bh(&ctx->encrypt_compl_lock); 473 474 if (!ready) 475 return; 476 477 /* Schedule the transmission */ 478 if (!test_and_set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) 479 schedule_delayed_work(&ctx->tx_work.work, 1); 480 } 481 482 static int tls_do_encryption(struct sock *sk, 483 struct tls_context *tls_ctx, 484 struct tls_sw_context_tx *ctx, 485 struct aead_request *aead_req, 486 size_t data_len, u32 start) 487 { 488 struct tls_prot_info *prot = &tls_ctx->prot_info; 489 struct tls_rec *rec = ctx->open_rec; 490 struct sk_msg *msg_en = &rec->msg_encrypted; 491 struct scatterlist *sge = sk_msg_elem(msg_en, start); 492 int rc, iv_offset = 0; 493 494 /* For CCM based ciphers, first byte of IV is a constant */ 495 switch (prot->cipher_type) { 496 case TLS_CIPHER_AES_CCM_128: 497 rec->iv_data[0] = TLS_AES_CCM_IV_B0_BYTE; 498 iv_offset = 1; 499 break; 500 case TLS_CIPHER_SM4_CCM: 501 rec->iv_data[0] = TLS_SM4_CCM_IV_B0_BYTE; 502 iv_offset = 1; 503 break; 504 } 505 506 memcpy(&rec->iv_data[iv_offset], tls_ctx->tx.iv, 507 prot->iv_size + prot->salt_size); 508 509 tls_xor_iv_with_seq(prot, rec->iv_data + iv_offset, 510 tls_ctx->tx.rec_seq); 511 512 sge->offset += prot->prepend_size; 513 sge->length -= prot->prepend_size; 514 515 msg_en->sg.curr = start; 516 517 aead_request_set_tfm(aead_req, ctx->aead_send); 518 aead_request_set_ad(aead_req, prot->aad_size); 519 aead_request_set_crypt(aead_req, rec->sg_aead_in, 520 rec->sg_aead_out, 521 data_len, rec->iv_data); 522 523 aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG, 524 tls_encrypt_done, sk); 525 526 /* Add the record in tx_list */ 527 list_add_tail((struct list_head *)&rec->list, &ctx->tx_list); 528 
atomic_inc(&ctx->encrypt_pending); 529 530 rc = crypto_aead_encrypt(aead_req); 531 if (!rc || rc != -EINPROGRESS) { 532 atomic_dec(&ctx->encrypt_pending); 533 sge->offset -= prot->prepend_size; 534 sge->length += prot->prepend_size; 535 } 536 537 if (!rc) { 538 WRITE_ONCE(rec->tx_ready, true); 539 } else if (rc != -EINPROGRESS) { 540 list_del(&rec->list); 541 return rc; 542 } 543 544 /* Unhook the record from context if encryption is not failure */ 545 ctx->open_rec = NULL; 546 tls_advance_record_sn(sk, prot, &tls_ctx->tx); 547 return rc; 548 } 549 550 static int tls_split_open_record(struct sock *sk, struct tls_rec *from, 551 struct tls_rec **to, struct sk_msg *msg_opl, 552 struct sk_msg *msg_oen, u32 split_point, 553 u32 tx_overhead_size, u32 *orig_end) 554 { 555 u32 i, j, bytes = 0, apply = msg_opl->apply_bytes; 556 struct scatterlist *sge, *osge, *nsge; 557 u32 orig_size = msg_opl->sg.size; 558 struct scatterlist tmp = { }; 559 struct sk_msg *msg_npl; 560 struct tls_rec *new; 561 int ret; 562 563 new = tls_get_rec(sk); 564 if (!new) 565 return -ENOMEM; 566 ret = sk_msg_alloc(sk, &new->msg_encrypted, msg_opl->sg.size + 567 tx_overhead_size, 0); 568 if (ret < 0) { 569 tls_free_rec(sk, new); 570 return ret; 571 } 572 573 *orig_end = msg_opl->sg.end; 574 i = msg_opl->sg.start; 575 sge = sk_msg_elem(msg_opl, i); 576 while (apply && sge->length) { 577 if (sge->length > apply) { 578 u32 len = sge->length - apply; 579 580 get_page(sg_page(sge)); 581 sg_set_page(&tmp, sg_page(sge), len, 582 sge->offset + apply); 583 sge->length = apply; 584 bytes += apply; 585 apply = 0; 586 } else { 587 apply -= sge->length; 588 bytes += sge->length; 589 } 590 591 sk_msg_iter_var_next(i); 592 if (i == msg_opl->sg.end) 593 break; 594 sge = sk_msg_elem(msg_opl, i); 595 } 596 597 msg_opl->sg.end = i; 598 msg_opl->sg.curr = i; 599 msg_opl->sg.copybreak = 0; 600 msg_opl->apply_bytes = 0; 601 msg_opl->sg.size = bytes; 602 603 msg_npl = &new->msg_plaintext; 604 msg_npl->apply_bytes = apply; 
605 msg_npl->sg.size = orig_size - bytes; 606 607 j = msg_npl->sg.start; 608 nsge = sk_msg_elem(msg_npl, j); 609 if (tmp.length) { 610 memcpy(nsge, &tmp, sizeof(*nsge)); 611 sk_msg_iter_var_next(j); 612 nsge = sk_msg_elem(msg_npl, j); 613 } 614 615 osge = sk_msg_elem(msg_opl, i); 616 while (osge->length) { 617 memcpy(nsge, osge, sizeof(*nsge)); 618 sg_unmark_end(nsge); 619 sk_msg_iter_var_next(i); 620 sk_msg_iter_var_next(j); 621 if (i == *orig_end) 622 break; 623 osge = sk_msg_elem(msg_opl, i); 624 nsge = sk_msg_elem(msg_npl, j); 625 } 626 627 msg_npl->sg.end = j; 628 msg_npl->sg.curr = j; 629 msg_npl->sg.copybreak = 0; 630 631 *to = new; 632 return 0; 633 } 634 635 static void tls_merge_open_record(struct sock *sk, struct tls_rec *to, 636 struct tls_rec *from, u32 orig_end) 637 { 638 struct sk_msg *msg_npl = &from->msg_plaintext; 639 struct sk_msg *msg_opl = &to->msg_plaintext; 640 struct scatterlist *osge, *nsge; 641 u32 i, j; 642 643 i = msg_opl->sg.end; 644 sk_msg_iter_var_prev(i); 645 j = msg_npl->sg.start; 646 647 osge = sk_msg_elem(msg_opl, i); 648 nsge = sk_msg_elem(msg_npl, j); 649 650 if (sg_page(osge) == sg_page(nsge) && 651 osge->offset + osge->length == nsge->offset) { 652 osge->length += nsge->length; 653 put_page(sg_page(nsge)); 654 } 655 656 msg_opl->sg.end = orig_end; 657 msg_opl->sg.curr = orig_end; 658 msg_opl->sg.copybreak = 0; 659 msg_opl->apply_bytes = msg_opl->sg.size + msg_npl->sg.size; 660 msg_opl->sg.size += msg_npl->sg.size; 661 662 sk_msg_free(sk, &to->msg_encrypted); 663 sk_msg_xfer_full(&to->msg_encrypted, &from->msg_encrypted); 664 665 kfree(from); 666 } 667 668 static int tls_push_record(struct sock *sk, int flags, 669 unsigned char record_type) 670 { 671 struct tls_context *tls_ctx = tls_get_ctx(sk); 672 struct tls_prot_info *prot = &tls_ctx->prot_info; 673 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 674 struct tls_rec *rec = ctx->open_rec, *tmp = NULL; 675 u32 i, split_point, orig_end; 676 struct sk_msg *msg_pl, 
*msg_en; 677 struct aead_request *req; 678 bool split; 679 int rc; 680 681 if (!rec) 682 return 0; 683 684 msg_pl = &rec->msg_plaintext; 685 msg_en = &rec->msg_encrypted; 686 687 split_point = msg_pl->apply_bytes; 688 split = split_point && split_point < msg_pl->sg.size; 689 if (unlikely((!split && 690 msg_pl->sg.size + 691 prot->overhead_size > msg_en->sg.size) || 692 (split && 693 split_point + 694 prot->overhead_size > msg_en->sg.size))) { 695 split = true; 696 split_point = msg_en->sg.size; 697 } 698 if (split) { 699 rc = tls_split_open_record(sk, rec, &tmp, msg_pl, msg_en, 700 split_point, prot->overhead_size, 701 &orig_end); 702 if (rc < 0) 703 return rc; 704 /* This can happen if above tls_split_open_record allocates 705 * a single large encryption buffer instead of two smaller 706 * ones. In this case adjust pointers and continue without 707 * split. 708 */ 709 if (!msg_pl->sg.size) { 710 tls_merge_open_record(sk, rec, tmp, orig_end); 711 msg_pl = &rec->msg_plaintext; 712 msg_en = &rec->msg_encrypted; 713 split = false; 714 } 715 sk_msg_trim(sk, msg_en, msg_pl->sg.size + 716 prot->overhead_size); 717 } 718 719 rec->tx_flags = flags; 720 req = &rec->aead_req; 721 722 i = msg_pl->sg.end; 723 sk_msg_iter_var_prev(i); 724 725 rec->content_type = record_type; 726 if (prot->version == TLS_1_3_VERSION) { 727 /* Add content type to end of message. 
No padding added */ 728 sg_set_buf(&rec->sg_content_type, &rec->content_type, 1); 729 sg_mark_end(&rec->sg_content_type); 730 sg_chain(msg_pl->sg.data, msg_pl->sg.end + 1, 731 &rec->sg_content_type); 732 } else { 733 sg_mark_end(sk_msg_elem(msg_pl, i)); 734 } 735 736 if (msg_pl->sg.end < msg_pl->sg.start) { 737 sg_chain(&msg_pl->sg.data[msg_pl->sg.start], 738 MAX_SKB_FRAGS - msg_pl->sg.start + 1, 739 msg_pl->sg.data); 740 } 741 742 i = msg_pl->sg.start; 743 sg_chain(rec->sg_aead_in, 2, &msg_pl->sg.data[i]); 744 745 i = msg_en->sg.end; 746 sk_msg_iter_var_prev(i); 747 sg_mark_end(sk_msg_elem(msg_en, i)); 748 749 i = msg_en->sg.start; 750 sg_chain(rec->sg_aead_out, 2, &msg_en->sg.data[i]); 751 752 tls_make_aad(rec->aad_space, msg_pl->sg.size + prot->tail_size, 753 tls_ctx->tx.rec_seq, record_type, prot); 754 755 tls_fill_prepend(tls_ctx, 756 page_address(sg_page(&msg_en->sg.data[i])) + 757 msg_en->sg.data[i].offset, 758 msg_pl->sg.size + prot->tail_size, 759 record_type); 760 761 tls_ctx->pending_open_record_frags = false; 762 763 rc = tls_do_encryption(sk, tls_ctx, ctx, req, 764 msg_pl->sg.size + prot->tail_size, i); 765 if (rc < 0) { 766 if (rc != -EINPROGRESS) { 767 tls_err_abort(sk, -EBADMSG); 768 if (split) { 769 tls_ctx->pending_open_record_frags = true; 770 tls_merge_open_record(sk, rec, tmp, orig_end); 771 } 772 } 773 ctx->async_capable = 1; 774 return rc; 775 } else if (split) { 776 msg_pl = &tmp->msg_plaintext; 777 msg_en = &tmp->msg_encrypted; 778 sk_msg_trim(sk, msg_en, msg_pl->sg.size + prot->overhead_size); 779 tls_ctx->pending_open_record_frags = true; 780 ctx->open_rec = tmp; 781 } 782 783 return tls_tx_records(sk, flags); 784 } 785 786 static int bpf_exec_tx_verdict(struct sk_msg *msg, struct sock *sk, 787 bool full_record, u8 record_type, 788 ssize_t *copied, int flags) 789 { 790 struct tls_context *tls_ctx = tls_get_ctx(sk); 791 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 792 struct sk_msg msg_redir = { }; 793 struct sk_psock *psock; 
794 struct sock *sk_redir; 795 struct tls_rec *rec; 796 bool enospc, policy, redir_ingress; 797 int err = 0, send; 798 u32 delta = 0; 799 800 policy = !(flags & MSG_SENDPAGE_NOPOLICY); 801 psock = sk_psock_get(sk); 802 if (!psock || !policy) { 803 err = tls_push_record(sk, flags, record_type); 804 if (err && sk->sk_err == EBADMSG) { 805 *copied -= sk_msg_free(sk, msg); 806 tls_free_open_rec(sk); 807 err = -sk->sk_err; 808 } 809 if (psock) 810 sk_psock_put(sk, psock); 811 return err; 812 } 813 more_data: 814 enospc = sk_msg_full(msg); 815 if (psock->eval == __SK_NONE) { 816 delta = msg->sg.size; 817 psock->eval = sk_psock_msg_verdict(sk, psock, msg); 818 delta -= msg->sg.size; 819 } 820 if (msg->cork_bytes && msg->cork_bytes > msg->sg.size && 821 !enospc && !full_record) { 822 err = -ENOSPC; 823 goto out_err; 824 } 825 msg->cork_bytes = 0; 826 send = msg->sg.size; 827 if (msg->apply_bytes && msg->apply_bytes < send) 828 send = msg->apply_bytes; 829 830 switch (psock->eval) { 831 case __SK_PASS: 832 err = tls_push_record(sk, flags, record_type); 833 if (err && sk->sk_err == EBADMSG) { 834 *copied -= sk_msg_free(sk, msg); 835 tls_free_open_rec(sk); 836 err = -sk->sk_err; 837 goto out_err; 838 } 839 break; 840 case __SK_REDIRECT: 841 redir_ingress = psock->redir_ingress; 842 sk_redir = psock->sk_redir; 843 memcpy(&msg_redir, msg, sizeof(*msg)); 844 if (msg->apply_bytes < send) 845 msg->apply_bytes = 0; 846 else 847 msg->apply_bytes -= send; 848 sk_msg_return_zero(sk, msg, send); 849 msg->sg.size -= send; 850 release_sock(sk); 851 err = tcp_bpf_sendmsg_redir(sk_redir, redir_ingress, 852 &msg_redir, send, flags); 853 lock_sock(sk); 854 if (err < 0) { 855 *copied -= sk_msg_free_nocharge(sk, &msg_redir); 856 msg->sg.size = 0; 857 } 858 if (msg->sg.size == 0) 859 tls_free_open_rec(sk); 860 break; 861 case __SK_DROP: 862 default: 863 sk_msg_free_partial(sk, msg, send); 864 if (msg->apply_bytes < send) 865 msg->apply_bytes = 0; 866 else 867 msg->apply_bytes -= send; 868 if 
(msg->sg.size == 0) 869 tls_free_open_rec(sk); 870 *copied -= (send + delta); 871 err = -EACCES; 872 } 873 874 if (likely(!err)) { 875 bool reset_eval = !ctx->open_rec; 876 877 rec = ctx->open_rec; 878 if (rec) { 879 msg = &rec->msg_plaintext; 880 if (!msg->apply_bytes) 881 reset_eval = true; 882 } 883 if (reset_eval) { 884 psock->eval = __SK_NONE; 885 if (psock->sk_redir) { 886 sock_put(psock->sk_redir); 887 psock->sk_redir = NULL; 888 } 889 } 890 if (rec) 891 goto more_data; 892 } 893 out_err: 894 sk_psock_put(sk, psock); 895 return err; 896 } 897 898 static int tls_sw_push_pending_record(struct sock *sk, int flags) 899 { 900 struct tls_context *tls_ctx = tls_get_ctx(sk); 901 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 902 struct tls_rec *rec = ctx->open_rec; 903 struct sk_msg *msg_pl; 904 size_t copied; 905 906 if (!rec) 907 return 0; 908 909 msg_pl = &rec->msg_plaintext; 910 copied = msg_pl->sg.size; 911 if (!copied) 912 return 0; 913 914 return bpf_exec_tx_verdict(msg_pl, sk, true, TLS_RECORD_TYPE_DATA, 915 &copied, flags); 916 } 917 918 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) 919 { 920 long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); 921 struct tls_context *tls_ctx = tls_get_ctx(sk); 922 struct tls_prot_info *prot = &tls_ctx->prot_info; 923 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 924 bool async_capable = ctx->async_capable; 925 unsigned char record_type = TLS_RECORD_TYPE_DATA; 926 bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); 927 bool eor = !(msg->msg_flags & MSG_MORE); 928 size_t try_to_copy; 929 ssize_t copied = 0; 930 struct sk_msg *msg_pl, *msg_en; 931 struct tls_rec *rec; 932 int required_size; 933 int num_async = 0; 934 bool full_record; 935 int record_room; 936 int num_zc = 0; 937 int orig_size; 938 int ret = 0; 939 int pending; 940 941 if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | 942 MSG_CMSG_COMPAT)) 943 return -EOPNOTSUPP; 944 945 
mutex_lock(&tls_ctx->tx_lock); 946 lock_sock(sk); 947 948 if (unlikely(msg->msg_controllen)) { 949 ret = tls_process_cmsg(sk, msg, &record_type); 950 if (ret) { 951 if (ret == -EINPROGRESS) 952 num_async++; 953 else if (ret != -EAGAIN) 954 goto send_end; 955 } 956 } 957 958 while (msg_data_left(msg)) { 959 if (sk->sk_err) { 960 ret = -sk->sk_err; 961 goto send_end; 962 } 963 964 if (ctx->open_rec) 965 rec = ctx->open_rec; 966 else 967 rec = ctx->open_rec = tls_get_rec(sk); 968 if (!rec) { 969 ret = -ENOMEM; 970 goto send_end; 971 } 972 973 msg_pl = &rec->msg_plaintext; 974 msg_en = &rec->msg_encrypted; 975 976 orig_size = msg_pl->sg.size; 977 full_record = false; 978 try_to_copy = msg_data_left(msg); 979 record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size; 980 if (try_to_copy >= record_room) { 981 try_to_copy = record_room; 982 full_record = true; 983 } 984 985 required_size = msg_pl->sg.size + try_to_copy + 986 prot->overhead_size; 987 988 if (!sk_stream_memory_free(sk)) 989 goto wait_for_sndbuf; 990 991 alloc_encrypted: 992 ret = tls_alloc_encrypted_msg(sk, required_size); 993 if (ret) { 994 if (ret != -ENOSPC) 995 goto wait_for_memory; 996 997 /* Adjust try_to_copy according to the amount that was 998 * actually allocated. 
The difference is due 999 * to max sg elements limit 1000 */ 1001 try_to_copy -= required_size - msg_en->sg.size; 1002 full_record = true; 1003 } 1004 1005 if (!is_kvec && (full_record || eor) && !async_capable) { 1006 u32 first = msg_pl->sg.end; 1007 1008 ret = sk_msg_zerocopy_from_iter(sk, &msg->msg_iter, 1009 msg_pl, try_to_copy); 1010 if (ret) 1011 goto fallback_to_reg_send; 1012 1013 num_zc++; 1014 copied += try_to_copy; 1015 1016 sk_msg_sg_copy_set(msg_pl, first); 1017 ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, 1018 record_type, &copied, 1019 msg->msg_flags); 1020 if (ret) { 1021 if (ret == -EINPROGRESS) 1022 num_async++; 1023 else if (ret == -ENOMEM) 1024 goto wait_for_memory; 1025 else if (ctx->open_rec && ret == -ENOSPC) 1026 goto rollback_iter; 1027 else if (ret != -EAGAIN) 1028 goto send_end; 1029 } 1030 continue; 1031 rollback_iter: 1032 copied -= try_to_copy; 1033 sk_msg_sg_copy_clear(msg_pl, first); 1034 iov_iter_revert(&msg->msg_iter, 1035 msg_pl->sg.size - orig_size); 1036 fallback_to_reg_send: 1037 sk_msg_trim(sk, msg_pl, orig_size); 1038 } 1039 1040 required_size = msg_pl->sg.size + try_to_copy; 1041 1042 ret = tls_clone_plaintext_msg(sk, required_size); 1043 if (ret) { 1044 if (ret != -ENOSPC) 1045 goto send_end; 1046 1047 /* Adjust try_to_copy according to the amount that was 1048 * actually allocated. The difference is due 1049 * to max sg elements limit 1050 */ 1051 try_to_copy -= required_size - msg_pl->sg.size; 1052 full_record = true; 1053 sk_msg_trim(sk, msg_en, 1054 msg_pl->sg.size + prot->overhead_size); 1055 } 1056 1057 if (try_to_copy) { 1058 ret = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, 1059 msg_pl, try_to_copy); 1060 if (ret < 0) 1061 goto trim_sgl; 1062 } 1063 1064 /* Open records defined only if successfully copied, otherwise 1065 * we would trim the sg but not reset the open record frags. 
1066 */ 1067 tls_ctx->pending_open_record_frags = true; 1068 copied += try_to_copy; 1069 if (full_record || eor) { 1070 ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, 1071 record_type, &copied, 1072 msg->msg_flags); 1073 if (ret) { 1074 if (ret == -EINPROGRESS) 1075 num_async++; 1076 else if (ret == -ENOMEM) 1077 goto wait_for_memory; 1078 else if (ret != -EAGAIN) { 1079 if (ret == -ENOSPC) 1080 ret = 0; 1081 goto send_end; 1082 } 1083 } 1084 } 1085 1086 continue; 1087 1088 wait_for_sndbuf: 1089 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); 1090 wait_for_memory: 1091 ret = sk_stream_wait_memory(sk, &timeo); 1092 if (ret) { 1093 trim_sgl: 1094 if (ctx->open_rec) 1095 tls_trim_both_msgs(sk, orig_size); 1096 goto send_end; 1097 } 1098 1099 if (ctx->open_rec && msg_en->sg.size < required_size) 1100 goto alloc_encrypted; 1101 } 1102 1103 if (!num_async) { 1104 goto send_end; 1105 } else if (num_zc) { 1106 /* Wait for pending encryptions to get completed */ 1107 spin_lock_bh(&ctx->encrypt_compl_lock); 1108 ctx->async_notify = true; 1109 1110 pending = atomic_read(&ctx->encrypt_pending); 1111 spin_unlock_bh(&ctx->encrypt_compl_lock); 1112 if (pending) 1113 crypto_wait_req(-EINPROGRESS, &ctx->async_wait); 1114 else 1115 reinit_completion(&ctx->async_wait.completion); 1116 1117 /* There can be no concurrent accesses, since we have no 1118 * pending encrypt operations 1119 */ 1120 WRITE_ONCE(ctx->async_notify, false); 1121 1122 if (ctx->async_wait.err) { 1123 ret = ctx->async_wait.err; 1124 copied = 0; 1125 } 1126 } 1127 1128 /* Transmit if any encryptions have completed */ 1129 if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) { 1130 cancel_delayed_work(&ctx->tx_work.work); 1131 tls_tx_records(sk, msg->msg_flags); 1132 } 1133 1134 send_end: 1135 ret = sk_stream_error(sk, msg->msg_flags, ret); 1136 1137 release_sock(sk); 1138 mutex_unlock(&tls_ctx->tx_lock); 1139 return copied > 0 ? 
copied : ret; 1140 } 1141 1142 static int tls_sw_do_sendpage(struct sock *sk, struct page *page, 1143 int offset, size_t size, int flags) 1144 { 1145 long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); 1146 struct tls_context *tls_ctx = tls_get_ctx(sk); 1147 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 1148 struct tls_prot_info *prot = &tls_ctx->prot_info; 1149 unsigned char record_type = TLS_RECORD_TYPE_DATA; 1150 struct sk_msg *msg_pl; 1151 struct tls_rec *rec; 1152 int num_async = 0; 1153 ssize_t copied = 0; 1154 bool full_record; 1155 int record_room; 1156 int ret = 0; 1157 bool eor; 1158 1159 eor = !(flags & MSG_SENDPAGE_NOTLAST); 1160 sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk); 1161 1162 /* Call the sk_stream functions to manage the sndbuf mem. */ 1163 while (size > 0) { 1164 size_t copy, required_size; 1165 1166 if (sk->sk_err) { 1167 ret = -sk->sk_err; 1168 goto sendpage_end; 1169 } 1170 1171 if (ctx->open_rec) 1172 rec = ctx->open_rec; 1173 else 1174 rec = ctx->open_rec = tls_get_rec(sk); 1175 if (!rec) { 1176 ret = -ENOMEM; 1177 goto sendpage_end; 1178 } 1179 1180 msg_pl = &rec->msg_plaintext; 1181 1182 full_record = false; 1183 record_room = TLS_MAX_PAYLOAD_SIZE - msg_pl->sg.size; 1184 copy = size; 1185 if (copy >= record_room) { 1186 copy = record_room; 1187 full_record = true; 1188 } 1189 1190 required_size = msg_pl->sg.size + copy + prot->overhead_size; 1191 1192 if (!sk_stream_memory_free(sk)) 1193 goto wait_for_sndbuf; 1194 alloc_payload: 1195 ret = tls_alloc_encrypted_msg(sk, required_size); 1196 if (ret) { 1197 if (ret != -ENOSPC) 1198 goto wait_for_memory; 1199 1200 /* Adjust copy according to the amount that was 1201 * actually allocated. 
The difference is due 1202 * to max sg elements limit 1203 */ 1204 copy -= required_size - msg_pl->sg.size; 1205 full_record = true; 1206 } 1207 1208 sk_msg_page_add(msg_pl, page, copy, offset); 1209 sk_mem_charge(sk, copy); 1210 1211 offset += copy; 1212 size -= copy; 1213 copied += copy; 1214 1215 tls_ctx->pending_open_record_frags = true; 1216 if (full_record || eor || sk_msg_full(msg_pl)) { 1217 ret = bpf_exec_tx_verdict(msg_pl, sk, full_record, 1218 record_type, &copied, flags); 1219 if (ret) { 1220 if (ret == -EINPROGRESS) 1221 num_async++; 1222 else if (ret == -ENOMEM) 1223 goto wait_for_memory; 1224 else if (ret != -EAGAIN) { 1225 if (ret == -ENOSPC) 1226 ret = 0; 1227 goto sendpage_end; 1228 } 1229 } 1230 } 1231 continue; 1232 wait_for_sndbuf: 1233 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); 1234 wait_for_memory: 1235 ret = sk_stream_wait_memory(sk, &timeo); 1236 if (ret) { 1237 if (ctx->open_rec) 1238 tls_trim_both_msgs(sk, msg_pl->sg.size); 1239 goto sendpage_end; 1240 } 1241 1242 if (ctx->open_rec) 1243 goto alloc_payload; 1244 } 1245 1246 if (num_async) { 1247 /* Transmit if any encryptions have completed */ 1248 if (test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) { 1249 cancel_delayed_work(&ctx->tx_work.work); 1250 tls_tx_records(sk, flags); 1251 } 1252 } 1253 sendpage_end: 1254 ret = sk_stream_error(sk, flags, ret); 1255 return copied > 0 ? 
copied : ret; 1256 } 1257 1258 int tls_sw_sendpage_locked(struct sock *sk, struct page *page, 1259 int offset, size_t size, int flags) 1260 { 1261 if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | 1262 MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY | 1263 MSG_NO_SHARED_FRAGS)) 1264 return -EOPNOTSUPP; 1265 1266 return tls_sw_do_sendpage(sk, page, offset, size, flags); 1267 } 1268 1269 int tls_sw_sendpage(struct sock *sk, struct page *page, 1270 int offset, size_t size, int flags) 1271 { 1272 struct tls_context *tls_ctx = tls_get_ctx(sk); 1273 int ret; 1274 1275 if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | 1276 MSG_SENDPAGE_NOTLAST | MSG_SENDPAGE_NOPOLICY)) 1277 return -EOPNOTSUPP; 1278 1279 mutex_lock(&tls_ctx->tx_lock); 1280 lock_sock(sk); 1281 ret = tls_sw_do_sendpage(sk, page, offset, size, flags); 1282 release_sock(sk); 1283 mutex_unlock(&tls_ctx->tx_lock); 1284 return ret; 1285 } 1286 1287 static int 1288 tls_rx_rec_wait(struct sock *sk, struct sk_psock *psock, bool nonblock, 1289 bool released) 1290 { 1291 struct tls_context *tls_ctx = tls_get_ctx(sk); 1292 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1293 DEFINE_WAIT_FUNC(wait, woken_wake_function); 1294 long timeo; 1295 1296 timeo = sock_rcvtimeo(sk, nonblock); 1297 1298 while (!tls_strp_msg_ready(ctx)) { 1299 if (!sk_psock_queue_empty(psock)) 1300 return 0; 1301 1302 if (sk->sk_err) 1303 return sock_error(sk); 1304 1305 if (!skb_queue_empty(&sk->sk_receive_queue)) { 1306 tls_strp_check_rcv(&ctx->strp); 1307 if (tls_strp_msg_ready(ctx)) 1308 break; 1309 } 1310 1311 if (sk->sk_shutdown & RCV_SHUTDOWN) 1312 return 0; 1313 1314 if (sock_flag(sk, SOCK_DONE)) 1315 return 0; 1316 1317 if (!timeo) 1318 return -EAGAIN; 1319 1320 released = true; 1321 add_wait_queue(sk_sleep(sk), &wait); 1322 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); 1323 sk_wait_event(sk, &timeo, 1324 tls_strp_msg_ready(ctx) || 1325 !sk_psock_queue_empty(psock), 1326 &wait); 1327 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); 1328 
remove_wait_queue(sk_sleep(sk), &wait); 1329 1330 /* Handle signals */ 1331 if (signal_pending(current)) 1332 return sock_intr_errno(timeo); 1333 } 1334 1335 tls_strp_msg_load(&ctx->strp, released); 1336 1337 return 1; 1338 } 1339 1340 static int tls_setup_from_iter(struct iov_iter *from, 1341 int length, int *pages_used, 1342 struct scatterlist *to, 1343 int to_max_pages) 1344 { 1345 int rc = 0, i = 0, num_elem = *pages_used, maxpages; 1346 struct page *pages[MAX_SKB_FRAGS]; 1347 unsigned int size = 0; 1348 ssize_t copied, use; 1349 size_t offset; 1350 1351 while (length > 0) { 1352 i = 0; 1353 maxpages = to_max_pages - num_elem; 1354 if (maxpages == 0) { 1355 rc = -EFAULT; 1356 goto out; 1357 } 1358 copied = iov_iter_get_pages2(from, pages, 1359 length, 1360 maxpages, &offset); 1361 if (copied <= 0) { 1362 rc = -EFAULT; 1363 goto out; 1364 } 1365 1366 length -= copied; 1367 size += copied; 1368 while (copied) { 1369 use = min_t(int, copied, PAGE_SIZE - offset); 1370 1371 sg_set_page(&to[num_elem], 1372 pages[i], use, offset); 1373 sg_unmark_end(&to[num_elem]); 1374 /* We do not uncharge memory from this API */ 1375 1376 offset = 0; 1377 copied -= use; 1378 1379 i++; 1380 num_elem++; 1381 } 1382 } 1383 /* Mark the end in the last sg entry if newly added */ 1384 if (num_elem > *pages_used) 1385 sg_mark_end(&to[num_elem - 1]); 1386 out: 1387 if (rc) 1388 iov_iter_revert(from, size); 1389 *pages_used = num_elem; 1390 1391 return rc; 1392 } 1393 1394 static struct sk_buff * 1395 tls_alloc_clrtxt_skb(struct sock *sk, struct sk_buff *skb, 1396 unsigned int full_len) 1397 { 1398 struct strp_msg *clr_rxm; 1399 struct sk_buff *clr_skb; 1400 int err; 1401 1402 clr_skb = alloc_skb_with_frags(0, full_len, TLS_PAGE_ORDER, 1403 &err, sk->sk_allocation); 1404 if (!clr_skb) 1405 return NULL; 1406 1407 skb_copy_header(clr_skb, skb); 1408 clr_skb->len = full_len; 1409 clr_skb->data_len = full_len; 1410 1411 clr_rxm = strp_msg(clr_skb); 1412 clr_rxm->offset = 0; 1413 1414 return 
clr_skb; 1415 } 1416 1417 /* Decrypt handlers 1418 * 1419 * tls_decrypt_sw() and tls_decrypt_device() are decrypt handlers. 1420 * They must transform the darg in/out argument are as follows: 1421 * | Input | Output 1422 * ------------------------------------------------------------------- 1423 * zc | Zero-copy decrypt allowed | Zero-copy performed 1424 * async | Async decrypt allowed | Async crypto used / in progress 1425 * skb | * | Output skb 1426 * 1427 * If ZC decryption was performed darg.skb will point to the input skb. 1428 */ 1429 1430 /* This function decrypts the input skb into either out_iov or in out_sg 1431 * or in skb buffers itself. The input parameter 'darg->zc' indicates if 1432 * zero-copy mode needs to be tried or not. With zero-copy mode, either 1433 * out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are 1434 * NULL, then the decryption happens inside skb buffers itself, i.e. 1435 * zero-copy gets disabled and 'darg->zc' is updated. 1436 */ 1437 static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov, 1438 struct scatterlist *out_sg, 1439 struct tls_decrypt_arg *darg) 1440 { 1441 struct tls_context *tls_ctx = tls_get_ctx(sk); 1442 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1443 struct tls_prot_info *prot = &tls_ctx->prot_info; 1444 int n_sgin, n_sgout, aead_size, err, pages = 0; 1445 struct sk_buff *skb = tls_strp_msg(ctx); 1446 const struct strp_msg *rxm = strp_msg(skb); 1447 const struct tls_msg *tlm = tls_msg(skb); 1448 struct aead_request *aead_req; 1449 struct scatterlist *sgin = NULL; 1450 struct scatterlist *sgout = NULL; 1451 const int data_len = rxm->full_len - prot->overhead_size; 1452 int tail_pages = !!prot->tail_size; 1453 struct tls_decrypt_ctx *dctx; 1454 struct sk_buff *clear_skb; 1455 int iv_offset = 0; 1456 u8 *mem; 1457 1458 n_sgin = skb_nsg(skb, rxm->offset + prot->prepend_size, 1459 rxm->full_len - prot->prepend_size); 1460 if (n_sgin < 1) 1461 return n_sgin ?: -EBADMSG; 1462 1463 
if (darg->zc && (out_iov || out_sg)) { 1464 clear_skb = NULL; 1465 1466 if (out_iov) 1467 n_sgout = 1 + tail_pages + 1468 iov_iter_npages_cap(out_iov, INT_MAX, data_len); 1469 else 1470 n_sgout = sg_nents(out_sg); 1471 } else { 1472 darg->zc = false; 1473 1474 clear_skb = tls_alloc_clrtxt_skb(sk, skb, rxm->full_len); 1475 if (!clear_skb) 1476 return -ENOMEM; 1477 1478 n_sgout = 1 + skb_shinfo(clear_skb)->nr_frags; 1479 } 1480 1481 /* Increment to accommodate AAD */ 1482 n_sgin = n_sgin + 1; 1483 1484 /* Allocate a single block of memory which contains 1485 * aead_req || tls_decrypt_ctx. 1486 * Both structs are variable length. 1487 */ 1488 aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv); 1489 mem = kmalloc(aead_size + struct_size(dctx, sg, n_sgin + n_sgout), 1490 sk->sk_allocation); 1491 if (!mem) { 1492 err = -ENOMEM; 1493 goto exit_free_skb; 1494 } 1495 1496 /* Segment the allocated memory */ 1497 aead_req = (struct aead_request *)mem; 1498 dctx = (struct tls_decrypt_ctx *)(mem + aead_size); 1499 sgin = &dctx->sg[0]; 1500 sgout = &dctx->sg[n_sgin]; 1501 1502 /* For CCM based ciphers, first byte of nonce+iv is a constant */ 1503 switch (prot->cipher_type) { 1504 case TLS_CIPHER_AES_CCM_128: 1505 dctx->iv[0] = TLS_AES_CCM_IV_B0_BYTE; 1506 iv_offset = 1; 1507 break; 1508 case TLS_CIPHER_SM4_CCM: 1509 dctx->iv[0] = TLS_SM4_CCM_IV_B0_BYTE; 1510 iv_offset = 1; 1511 break; 1512 } 1513 1514 /* Prepare IV */ 1515 if (prot->version == TLS_1_3_VERSION || 1516 prot->cipher_type == TLS_CIPHER_CHACHA20_POLY1305) { 1517 memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv, 1518 prot->iv_size + prot->salt_size); 1519 } else { 1520 err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE, 1521 &dctx->iv[iv_offset] + prot->salt_size, 1522 prot->iv_size); 1523 if (err < 0) 1524 goto exit_free; 1525 memcpy(&dctx->iv[iv_offset], tls_ctx->rx.iv, prot->salt_size); 1526 } 1527 tls_xor_iv_with_seq(prot, &dctx->iv[iv_offset], tls_ctx->rx.rec_seq); 1528 1529 /* Prepare AAD */ 1530 
tls_make_aad(dctx->aad, rxm->full_len - prot->overhead_size + 1531 prot->tail_size, 1532 tls_ctx->rx.rec_seq, tlm->control, prot); 1533 1534 /* Prepare sgin */ 1535 sg_init_table(sgin, n_sgin); 1536 sg_set_buf(&sgin[0], dctx->aad, prot->aad_size); 1537 err = skb_to_sgvec(skb, &sgin[1], 1538 rxm->offset + prot->prepend_size, 1539 rxm->full_len - prot->prepend_size); 1540 if (err < 0) 1541 goto exit_free; 1542 1543 if (clear_skb) { 1544 sg_init_table(sgout, n_sgout); 1545 sg_set_buf(&sgout[0], dctx->aad, prot->aad_size); 1546 1547 err = skb_to_sgvec(clear_skb, &sgout[1], prot->prepend_size, 1548 data_len + prot->tail_size); 1549 if (err < 0) 1550 goto exit_free; 1551 } else if (out_iov) { 1552 sg_init_table(sgout, n_sgout); 1553 sg_set_buf(&sgout[0], dctx->aad, prot->aad_size); 1554 1555 err = tls_setup_from_iter(out_iov, data_len, &pages, &sgout[1], 1556 (n_sgout - 1 - tail_pages)); 1557 if (err < 0) 1558 goto exit_free_pages; 1559 1560 if (prot->tail_size) { 1561 sg_unmark_end(&sgout[pages]); 1562 sg_set_buf(&sgout[pages + 1], &dctx->tail, 1563 prot->tail_size); 1564 sg_mark_end(&sgout[pages + 1]); 1565 } 1566 } else if (out_sg) { 1567 memcpy(sgout, out_sg, n_sgout * sizeof(*sgout)); 1568 } 1569 1570 /* Prepare and submit AEAD request */ 1571 err = tls_do_decryption(sk, sgin, sgout, dctx->iv, 1572 data_len + prot->tail_size, aead_req, darg); 1573 if (err) 1574 goto exit_free_pages; 1575 1576 darg->skb = clear_skb ?: tls_strp_msg(ctx); 1577 clear_skb = NULL; 1578 1579 if (unlikely(darg->async)) { 1580 err = tls_strp_msg_hold(&ctx->strp, &ctx->async_hold); 1581 if (err) 1582 __skb_queue_tail(&ctx->async_hold, darg->skb); 1583 return err; 1584 } 1585 1586 if (prot->tail_size) 1587 darg->tail = dctx->tail; 1588 1589 exit_free_pages: 1590 /* Release the pages in case iov was mapped to pages */ 1591 for (; pages > 0; pages--) 1592 put_page(sg_page(&sgout[pages])); 1593 exit_free: 1594 kfree(mem); 1595 exit_free_skb: 1596 consume_skb(clear_skb); 1597 return err; 1598 } 
1599 1600 static int 1601 tls_decrypt_sw(struct sock *sk, struct tls_context *tls_ctx, 1602 struct msghdr *msg, struct tls_decrypt_arg *darg) 1603 { 1604 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1605 struct tls_prot_info *prot = &tls_ctx->prot_info; 1606 struct strp_msg *rxm; 1607 int pad, err; 1608 1609 err = tls_decrypt_sg(sk, &msg->msg_iter, NULL, darg); 1610 if (err < 0) { 1611 if (err == -EBADMSG) 1612 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR); 1613 return err; 1614 } 1615 /* keep going even for ->async, the code below is TLS 1.3 */ 1616 1617 /* If opportunistic TLS 1.3 ZC failed retry without ZC */ 1618 if (unlikely(darg->zc && prot->version == TLS_1_3_VERSION && 1619 darg->tail != TLS_RECORD_TYPE_DATA)) { 1620 darg->zc = false; 1621 if (!darg->tail) 1622 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXNOPADVIOL); 1623 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY); 1624 return tls_decrypt_sw(sk, tls_ctx, msg, darg); 1625 } 1626 1627 pad = tls_padding_length(prot, darg->skb, darg); 1628 if (pad < 0) { 1629 if (darg->skb != tls_strp_msg(ctx)) 1630 consume_skb(darg->skb); 1631 return pad; 1632 } 1633 1634 rxm = strp_msg(darg->skb); 1635 rxm->full_len -= pad; 1636 1637 return 0; 1638 } 1639 1640 static int 1641 tls_decrypt_device(struct sock *sk, struct msghdr *msg, 1642 struct tls_context *tls_ctx, struct tls_decrypt_arg *darg) 1643 { 1644 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1645 struct tls_prot_info *prot = &tls_ctx->prot_info; 1646 struct strp_msg *rxm; 1647 int pad, err; 1648 1649 if (tls_ctx->rx_conf != TLS_HW) 1650 return 0; 1651 1652 err = tls_device_decrypted(sk, tls_ctx); 1653 if (err <= 0) 1654 return err; 1655 1656 pad = tls_padding_length(prot, tls_strp_msg(ctx), darg); 1657 if (pad < 0) 1658 return pad; 1659 1660 darg->async = false; 1661 darg->skb = tls_strp_msg(ctx); 1662 /* ->zc downgrade check, in case TLS 1.3 gets here */ 1663 darg->zc &= !(prot->version == TLS_1_3_VERSION && 1664 
tls_msg(darg->skb)->control != TLS_RECORD_TYPE_DATA); 1665 1666 rxm = strp_msg(darg->skb); 1667 rxm->full_len -= pad; 1668 1669 if (!darg->zc) { 1670 /* Non-ZC case needs a real skb */ 1671 darg->skb = tls_strp_msg_detach(ctx); 1672 if (!darg->skb) 1673 return -ENOMEM; 1674 } else { 1675 unsigned int off, len; 1676 1677 /* In ZC case nobody cares about the output skb. 1678 * Just copy the data here. Note the skb is not fully trimmed. 1679 */ 1680 off = rxm->offset + prot->prepend_size; 1681 len = rxm->full_len - prot->overhead_size; 1682 1683 err = skb_copy_datagram_msg(darg->skb, off, msg, len); 1684 if (err) 1685 return err; 1686 } 1687 return 1; 1688 } 1689 1690 static int tls_rx_one_record(struct sock *sk, struct msghdr *msg, 1691 struct tls_decrypt_arg *darg) 1692 { 1693 struct tls_context *tls_ctx = tls_get_ctx(sk); 1694 struct tls_prot_info *prot = &tls_ctx->prot_info; 1695 struct strp_msg *rxm; 1696 int err; 1697 1698 err = tls_decrypt_device(sk, msg, tls_ctx, darg); 1699 if (!err) 1700 err = tls_decrypt_sw(sk, tls_ctx, msg, darg); 1701 if (err < 0) 1702 return err; 1703 1704 rxm = strp_msg(darg->skb); 1705 rxm->offset += prot->prepend_size; 1706 rxm->full_len -= prot->overhead_size; 1707 tls_advance_record_sn(sk, prot, &tls_ctx->rx); 1708 1709 return 0; 1710 } 1711 1712 int decrypt_skb(struct sock *sk, struct scatterlist *sgout) 1713 { 1714 struct tls_decrypt_arg darg = { .zc = true, }; 1715 1716 return tls_decrypt_sg(sk, NULL, sgout, &darg); 1717 } 1718 1719 static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm, 1720 u8 *control) 1721 { 1722 int err; 1723 1724 if (!*control) { 1725 *control = tlm->control; 1726 if (!*control) 1727 return -EBADMSG; 1728 1729 err = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE, 1730 sizeof(*control), control); 1731 if (*control != TLS_RECORD_TYPE_DATA) { 1732 if (err || msg->msg_flags & MSG_CTRUNC) 1733 return -EIO; 1734 } 1735 } else if (*control != tlm->control) { 1736 return 0; 1737 } 1738 1739 return 
1; 1740 } 1741 1742 static void tls_rx_rec_done(struct tls_sw_context_rx *ctx) 1743 { 1744 tls_strp_msg_done(&ctx->strp); 1745 } 1746 1747 /* This function traverses the rx_list in tls receive context to copies the 1748 * decrypted records into the buffer provided by caller zero copy is not 1749 * true. Further, the records are removed from the rx_list if it is not a peek 1750 * case and the record has been consumed completely. 1751 */ 1752 static int process_rx_list(struct tls_sw_context_rx *ctx, 1753 struct msghdr *msg, 1754 u8 *control, 1755 size_t skip, 1756 size_t len, 1757 bool is_peek) 1758 { 1759 struct sk_buff *skb = skb_peek(&ctx->rx_list); 1760 struct tls_msg *tlm; 1761 ssize_t copied = 0; 1762 int err; 1763 1764 while (skip && skb) { 1765 struct strp_msg *rxm = strp_msg(skb); 1766 tlm = tls_msg(skb); 1767 1768 err = tls_record_content_type(msg, tlm, control); 1769 if (err <= 0) 1770 goto out; 1771 1772 if (skip < rxm->full_len) 1773 break; 1774 1775 skip = skip - rxm->full_len; 1776 skb = skb_peek_next(skb, &ctx->rx_list); 1777 } 1778 1779 while (len && skb) { 1780 struct sk_buff *next_skb; 1781 struct strp_msg *rxm = strp_msg(skb); 1782 int chunk = min_t(unsigned int, rxm->full_len - skip, len); 1783 1784 tlm = tls_msg(skb); 1785 1786 err = tls_record_content_type(msg, tlm, control); 1787 if (err <= 0) 1788 goto out; 1789 1790 err = skb_copy_datagram_msg(skb, rxm->offset + skip, 1791 msg, chunk); 1792 if (err < 0) 1793 goto out; 1794 1795 len = len - chunk; 1796 copied = copied + chunk; 1797 1798 /* Consume the data from record if it is non-peek case*/ 1799 if (!is_peek) { 1800 rxm->offset = rxm->offset + chunk; 1801 rxm->full_len = rxm->full_len - chunk; 1802 1803 /* Return if there is unconsumed data in the record */ 1804 if (rxm->full_len - skip) 1805 break; 1806 } 1807 1808 /* The remaining skip-bytes must lie in 1st record in rx_list. 1809 * So from the 2nd record, 'skip' should be 0. 
1810 */ 1811 skip = 0; 1812 1813 if (msg) 1814 msg->msg_flags |= MSG_EOR; 1815 1816 next_skb = skb_peek_next(skb, &ctx->rx_list); 1817 1818 if (!is_peek) { 1819 __skb_unlink(skb, &ctx->rx_list); 1820 consume_skb(skb); 1821 } 1822 1823 skb = next_skb; 1824 } 1825 err = 0; 1826 1827 out: 1828 return copied ? : err; 1829 } 1830 1831 static bool 1832 tls_read_flush_backlog(struct sock *sk, struct tls_prot_info *prot, 1833 size_t len_left, size_t decrypted, ssize_t done, 1834 size_t *flushed_at) 1835 { 1836 size_t max_rec; 1837 1838 if (len_left <= decrypted) 1839 return false; 1840 1841 max_rec = prot->overhead_size - prot->tail_size + TLS_MAX_PAYLOAD_SIZE; 1842 if (done - *flushed_at < SZ_128K && tcp_inq(sk) > max_rec) 1843 return false; 1844 1845 *flushed_at = done; 1846 return sk_flush_backlog(sk); 1847 } 1848 1849 static int tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx, 1850 bool nonblock) 1851 { 1852 long timeo; 1853 int err; 1854 1855 lock_sock(sk); 1856 1857 timeo = sock_rcvtimeo(sk, nonblock); 1858 1859 while (unlikely(ctx->reader_present)) { 1860 DEFINE_WAIT_FUNC(wait, woken_wake_function); 1861 1862 ctx->reader_contended = 1; 1863 1864 add_wait_queue(&ctx->wq, &wait); 1865 sk_wait_event(sk, &timeo, 1866 !READ_ONCE(ctx->reader_present), &wait); 1867 remove_wait_queue(&ctx->wq, &wait); 1868 1869 if (timeo <= 0) { 1870 err = -EAGAIN; 1871 goto err_unlock; 1872 } 1873 if (signal_pending(current)) { 1874 err = sock_intr_errno(timeo); 1875 goto err_unlock; 1876 } 1877 } 1878 1879 WRITE_ONCE(ctx->reader_present, 1); 1880 1881 return 0; 1882 1883 err_unlock: 1884 release_sock(sk); 1885 return err; 1886 } 1887 1888 static void tls_rx_reader_unlock(struct sock *sk, struct tls_sw_context_rx *ctx) 1889 { 1890 if (unlikely(ctx->reader_contended)) { 1891 if (wq_has_sleeper(&ctx->wq)) 1892 wake_up(&ctx->wq); 1893 else 1894 ctx->reader_contended = 0; 1895 1896 WARN_ON_ONCE(!ctx->reader_present); 1897 } 1898 1899 WRITE_ONCE(ctx->reader_present, 0); 1900 
release_sock(sk); 1901 } 1902 1903 int tls_sw_recvmsg(struct sock *sk, 1904 struct msghdr *msg, 1905 size_t len, 1906 int flags, 1907 int *addr_len) 1908 { 1909 struct tls_context *tls_ctx = tls_get_ctx(sk); 1910 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1911 struct tls_prot_info *prot = &tls_ctx->prot_info; 1912 ssize_t decrypted = 0, async_copy_bytes = 0; 1913 struct sk_psock *psock; 1914 unsigned char control = 0; 1915 size_t flushed_at = 0; 1916 struct strp_msg *rxm; 1917 struct tls_msg *tlm; 1918 ssize_t copied = 0; 1919 bool async = false; 1920 int target, err; 1921 bool is_kvec = iov_iter_is_kvec(&msg->msg_iter); 1922 bool is_peek = flags & MSG_PEEK; 1923 bool released = true; 1924 bool bpf_strp_enabled; 1925 bool zc_capable; 1926 1927 if (unlikely(flags & MSG_ERRQUEUE)) 1928 return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR); 1929 1930 psock = sk_psock_get(sk); 1931 err = tls_rx_reader_lock(sk, ctx, flags & MSG_DONTWAIT); 1932 if (err < 0) 1933 return err; 1934 bpf_strp_enabled = sk_psock_strp_enabled(psock); 1935 1936 /* If crypto failed the connection is broken */ 1937 err = ctx->async_wait.err; 1938 if (err) 1939 goto end; 1940 1941 /* Process pending decrypted records. 
It must be non-zero-copy */ 1942 err = process_rx_list(ctx, msg, &control, 0, len, is_peek); 1943 if (err < 0) 1944 goto end; 1945 1946 copied = err; 1947 if (len <= copied) 1948 goto end; 1949 1950 target = sock_rcvlowat(sk, flags & MSG_WAITALL, len); 1951 len = len - copied; 1952 1953 zc_capable = !bpf_strp_enabled && !is_kvec && !is_peek && 1954 ctx->zc_capable; 1955 decrypted = 0; 1956 while (len && (decrypted + copied < target || tls_strp_msg_ready(ctx))) { 1957 struct tls_decrypt_arg darg; 1958 int to_decrypt, chunk; 1959 1960 err = tls_rx_rec_wait(sk, psock, flags & MSG_DONTWAIT, 1961 released); 1962 if (err <= 0) { 1963 if (psock) { 1964 chunk = sk_msg_recvmsg(sk, psock, msg, len, 1965 flags); 1966 if (chunk > 0) { 1967 decrypted += chunk; 1968 len -= chunk; 1969 continue; 1970 } 1971 } 1972 goto recv_end; 1973 } 1974 1975 memset(&darg.inargs, 0, sizeof(darg.inargs)); 1976 1977 rxm = strp_msg(tls_strp_msg(ctx)); 1978 tlm = tls_msg(tls_strp_msg(ctx)); 1979 1980 to_decrypt = rxm->full_len - prot->overhead_size; 1981 1982 if (zc_capable && to_decrypt <= len && 1983 tlm->control == TLS_RECORD_TYPE_DATA) 1984 darg.zc = true; 1985 1986 /* Do not use async mode if record is non-data */ 1987 if (tlm->control == TLS_RECORD_TYPE_DATA && !bpf_strp_enabled) 1988 darg.async = ctx->async_capable; 1989 else 1990 darg.async = false; 1991 1992 err = tls_rx_one_record(sk, msg, &darg); 1993 if (err < 0) { 1994 tls_err_abort(sk, -EBADMSG); 1995 goto recv_end; 1996 } 1997 1998 async |= darg.async; 1999 2000 /* If the type of records being processed is not known yet, 2001 * set it to record type just dequeued. If it is already known, 2002 * but does not match the record type just dequeued, go to end. 2003 * We always get record type here since for tls1.2, record type 2004 * is known just after record is dequeued from stream parser. 2005 * For tls1.3, we disable async. 
2006 */ 2007 err = tls_record_content_type(msg, tls_msg(darg.skb), &control); 2008 if (err <= 0) { 2009 DEBUG_NET_WARN_ON_ONCE(darg.zc); 2010 tls_rx_rec_done(ctx); 2011 put_on_rx_list_err: 2012 __skb_queue_tail(&ctx->rx_list, darg.skb); 2013 goto recv_end; 2014 } 2015 2016 /* periodically flush backlog, and feed strparser */ 2017 released = tls_read_flush_backlog(sk, prot, len, to_decrypt, 2018 decrypted + copied, 2019 &flushed_at); 2020 2021 /* TLS 1.3 may have updated the length by more than overhead */ 2022 rxm = strp_msg(darg.skb); 2023 chunk = rxm->full_len; 2024 tls_rx_rec_done(ctx); 2025 2026 if (!darg.zc) { 2027 bool partially_consumed = chunk > len; 2028 struct sk_buff *skb = darg.skb; 2029 2030 DEBUG_NET_WARN_ON_ONCE(darg.skb == ctx->strp.anchor); 2031 2032 if (async) { 2033 /* TLS 1.2-only, to_decrypt must be text len */ 2034 chunk = min_t(int, to_decrypt, len); 2035 async_copy_bytes += chunk; 2036 put_on_rx_list: 2037 decrypted += chunk; 2038 len -= chunk; 2039 __skb_queue_tail(&ctx->rx_list, skb); 2040 continue; 2041 } 2042 2043 if (bpf_strp_enabled) { 2044 released = true; 2045 err = sk_psock_tls_strp_read(psock, skb); 2046 if (err != __SK_PASS) { 2047 rxm->offset = rxm->offset + rxm->full_len; 2048 rxm->full_len = 0; 2049 if (err == __SK_DROP) 2050 consume_skb(skb); 2051 continue; 2052 } 2053 } 2054 2055 if (partially_consumed) 2056 chunk = len; 2057 2058 err = skb_copy_datagram_msg(skb, rxm->offset, 2059 msg, chunk); 2060 if (err < 0) 2061 goto put_on_rx_list_err; 2062 2063 if (is_peek) 2064 goto put_on_rx_list; 2065 2066 if (partially_consumed) { 2067 rxm->offset += chunk; 2068 rxm->full_len -= chunk; 2069 goto put_on_rx_list; 2070 } 2071 2072 consume_skb(skb); 2073 } 2074 2075 decrypted += chunk; 2076 len -= chunk; 2077 2078 /* Return full control message to userspace before trying 2079 * to parse another message type 2080 */ 2081 msg->msg_flags |= MSG_EOR; 2082 if (control != TLS_RECORD_TYPE_DATA) 2083 break; 2084 } 2085 2086 recv_end: 2087 if 
(async) { 2088 int ret, pending; 2089 2090 /* Wait for all previously submitted records to be decrypted */ 2091 spin_lock_bh(&ctx->decrypt_compl_lock); 2092 reinit_completion(&ctx->async_wait.completion); 2093 pending = atomic_read(&ctx->decrypt_pending); 2094 spin_unlock_bh(&ctx->decrypt_compl_lock); 2095 ret = 0; 2096 if (pending) 2097 ret = crypto_wait_req(-EINPROGRESS, &ctx->async_wait); 2098 __skb_queue_purge(&ctx->async_hold); 2099 2100 if (ret) { 2101 if (err >= 0 || err == -EINPROGRESS) 2102 err = ret; 2103 decrypted = 0; 2104 goto end; 2105 } 2106 2107 /* Drain records from the rx_list & copy if required */ 2108 if (is_peek || is_kvec) 2109 err = process_rx_list(ctx, msg, &control, copied, 2110 decrypted, is_peek); 2111 else 2112 err = process_rx_list(ctx, msg, &control, 0, 2113 async_copy_bytes, is_peek); 2114 decrypted = max(err, 0); 2115 } 2116 2117 copied += decrypted; 2118 2119 end: 2120 tls_rx_reader_unlock(sk, ctx); 2121 if (psock) 2122 sk_psock_put(sk, psock); 2123 return copied ? 
: err; 2124 } 2125 2126 ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, 2127 struct pipe_inode_info *pipe, 2128 size_t len, unsigned int flags) 2129 { 2130 struct tls_context *tls_ctx = tls_get_ctx(sock->sk); 2131 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 2132 struct strp_msg *rxm = NULL; 2133 struct sock *sk = sock->sk; 2134 struct tls_msg *tlm; 2135 struct sk_buff *skb; 2136 ssize_t copied = 0; 2137 int chunk; 2138 int err; 2139 2140 err = tls_rx_reader_lock(sk, ctx, flags & SPLICE_F_NONBLOCK); 2141 if (err < 0) 2142 return err; 2143 2144 if (!skb_queue_empty(&ctx->rx_list)) { 2145 skb = __skb_dequeue(&ctx->rx_list); 2146 } else { 2147 struct tls_decrypt_arg darg; 2148 2149 err = tls_rx_rec_wait(sk, NULL, flags & SPLICE_F_NONBLOCK, 2150 true); 2151 if (err <= 0) 2152 goto splice_read_end; 2153 2154 memset(&darg.inargs, 0, sizeof(darg.inargs)); 2155 2156 err = tls_rx_one_record(sk, NULL, &darg); 2157 if (err < 0) { 2158 tls_err_abort(sk, -EBADMSG); 2159 goto splice_read_end; 2160 } 2161 2162 tls_rx_rec_done(ctx); 2163 skb = darg.skb; 2164 } 2165 2166 rxm = strp_msg(skb); 2167 tlm = tls_msg(skb); 2168 2169 /* splice does not support reading control messages */ 2170 if (tlm->control != TLS_RECORD_TYPE_DATA) { 2171 err = -EINVAL; 2172 goto splice_requeue; 2173 } 2174 2175 chunk = min_t(unsigned int, rxm->full_len, len); 2176 copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags); 2177 if (copied < 0) 2178 goto splice_requeue; 2179 2180 if (chunk < rxm->full_len) { 2181 rxm->offset += len; 2182 rxm->full_len -= len; 2183 goto splice_requeue; 2184 } 2185 2186 consume_skb(skb); 2187 2188 splice_read_end: 2189 tls_rx_reader_unlock(sk, ctx); 2190 return copied ? 
: err; 2191 2192 splice_requeue: 2193 __skb_queue_head(&ctx->rx_list, skb); 2194 goto splice_read_end; 2195 } 2196 2197 bool tls_sw_sock_is_readable(struct sock *sk) 2198 { 2199 struct tls_context *tls_ctx = tls_get_ctx(sk); 2200 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 2201 bool ingress_empty = true; 2202 struct sk_psock *psock; 2203 2204 rcu_read_lock(); 2205 psock = sk_psock(sk); 2206 if (psock) 2207 ingress_empty = list_empty(&psock->ingress_msg); 2208 rcu_read_unlock(); 2209 2210 return !ingress_empty || tls_strp_msg_ready(ctx) || 2211 !skb_queue_empty(&ctx->rx_list); 2212 } 2213 2214 int tls_rx_msg_size(struct tls_strparser *strp, struct sk_buff *skb) 2215 { 2216 struct tls_context *tls_ctx = tls_get_ctx(strp->sk); 2217 struct tls_prot_info *prot = &tls_ctx->prot_info; 2218 char header[TLS_HEADER_SIZE + MAX_IV_SIZE]; 2219 size_t cipher_overhead; 2220 size_t data_len = 0; 2221 int ret; 2222 2223 /* Verify that we have a full TLS header, or wait for more data */ 2224 if (strp->stm.offset + prot->prepend_size > skb->len) 2225 return 0; 2226 2227 /* Sanity-check size of on-stack buffer. 
*/ 2228 if (WARN_ON(prot->prepend_size > sizeof(header))) { 2229 ret = -EINVAL; 2230 goto read_failure; 2231 } 2232 2233 /* Linearize header to local buffer */ 2234 ret = skb_copy_bits(skb, strp->stm.offset, header, prot->prepend_size); 2235 if (ret < 0) 2236 goto read_failure; 2237 2238 strp->mark = header[0]; 2239 2240 data_len = ((header[4] & 0xFF) | (header[3] << 8)); 2241 2242 cipher_overhead = prot->tag_size; 2243 if (prot->version != TLS_1_3_VERSION && 2244 prot->cipher_type != TLS_CIPHER_CHACHA20_POLY1305) 2245 cipher_overhead += prot->iv_size; 2246 2247 if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead + 2248 prot->tail_size) { 2249 ret = -EMSGSIZE; 2250 goto read_failure; 2251 } 2252 if (data_len < cipher_overhead) { 2253 ret = -EBADMSG; 2254 goto read_failure; 2255 } 2256 2257 /* Note that both TLS1.3 and TLS1.2 use TLS_1_2 version here */ 2258 if (header[1] != TLS_1_2_VERSION_MINOR || 2259 header[2] != TLS_1_2_VERSION_MAJOR) { 2260 ret = -EINVAL; 2261 goto read_failure; 2262 } 2263 2264 tls_device_rx_resync_new_rec(strp->sk, data_len + TLS_HEADER_SIZE, 2265 TCP_SKB_CB(skb)->seq + strp->stm.offset); 2266 return data_len + TLS_HEADER_SIZE; 2267 2268 read_failure: 2269 tls_err_abort(strp->sk, ret); 2270 2271 return ret; 2272 } 2273 2274 void tls_rx_msg_ready(struct tls_strparser *strp) 2275 { 2276 struct tls_sw_context_rx *ctx; 2277 2278 ctx = container_of(strp, struct tls_sw_context_rx, strp); 2279 ctx->saved_data_ready(strp->sk); 2280 } 2281 2282 static void tls_data_ready(struct sock *sk) 2283 { 2284 struct tls_context *tls_ctx = tls_get_ctx(sk); 2285 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 2286 struct sk_psock *psock; 2287 2288 trace_sk_data_ready(sk); 2289 2290 tls_strp_data_ready(&ctx->strp); 2291 2292 psock = sk_psock_get(sk); 2293 if (psock) { 2294 if (!list_empty(&psock->ingress_msg)) 2295 ctx->saved_data_ready(sk); 2296 sk_psock_put(sk, psock); 2297 } 2298 } 2299 2300 void tls_sw_cancel_work_tx(struct tls_context *tls_ctx) 
2301 { 2302 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 2303 2304 set_bit(BIT_TX_CLOSING, &ctx->tx_bitmask); 2305 set_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask); 2306 cancel_delayed_work_sync(&ctx->tx_work.work); 2307 } 2308 2309 void tls_sw_release_resources_tx(struct sock *sk) 2310 { 2311 struct tls_context *tls_ctx = tls_get_ctx(sk); 2312 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 2313 struct tls_rec *rec, *tmp; 2314 int pending; 2315 2316 /* Wait for any pending async encryptions to complete */ 2317 spin_lock_bh(&ctx->encrypt_compl_lock); 2318 ctx->async_notify = true; 2319 pending = atomic_read(&ctx->encrypt_pending); 2320 spin_unlock_bh(&ctx->encrypt_compl_lock); 2321 2322 if (pending) 2323 crypto_wait_req(-EINPROGRESS, &ctx->async_wait); 2324 2325 tls_tx_records(sk, -1); 2326 2327 /* Free up un-sent records in tx_list. First, free 2328 * the partially sent record if any at head of tx_list. 2329 */ 2330 if (tls_ctx->partially_sent_record) { 2331 tls_free_partial_record(sk, tls_ctx); 2332 rec = list_first_entry(&ctx->tx_list, 2333 struct tls_rec, list); 2334 list_del(&rec->list); 2335 sk_msg_free(sk, &rec->msg_plaintext); 2336 kfree(rec); 2337 } 2338 2339 list_for_each_entry_safe(rec, tmp, &ctx->tx_list, list) { 2340 list_del(&rec->list); 2341 sk_msg_free(sk, &rec->msg_encrypted); 2342 sk_msg_free(sk, &rec->msg_plaintext); 2343 kfree(rec); 2344 } 2345 2346 crypto_free_aead(ctx->aead_send); 2347 tls_free_open_rec(sk); 2348 } 2349 2350 void tls_sw_free_ctx_tx(struct tls_context *tls_ctx) 2351 { 2352 struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 2353 2354 kfree(ctx); 2355 } 2356 2357 void tls_sw_release_resources_rx(struct sock *sk) 2358 { 2359 struct tls_context *tls_ctx = tls_get_ctx(sk); 2360 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 2361 2362 kfree(tls_ctx->rx.rec_seq); 2363 kfree(tls_ctx->rx.iv); 2364 2365 if (ctx->aead_recv) { 2366 __skb_queue_purge(&ctx->rx_list); 2367 crypto_free_aead(ctx->aead_recv); 2368 
tls_strp_stop(&ctx->strp); 2369 /* If tls_sw_strparser_arm() was not called (cleanup paths) 2370 * we still want to tls_strp_stop(), but sk->sk_data_ready was 2371 * never swapped. 2372 */ 2373 if (ctx->saved_data_ready) { 2374 write_lock_bh(&sk->sk_callback_lock); 2375 sk->sk_data_ready = ctx->saved_data_ready; 2376 write_unlock_bh(&sk->sk_callback_lock); 2377 } 2378 } 2379 } 2380 2381 void tls_sw_strparser_done(struct tls_context *tls_ctx) 2382 { 2383 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 2384 2385 tls_strp_done(&ctx->strp); 2386 } 2387 2388 void tls_sw_free_ctx_rx(struct tls_context *tls_ctx) 2389 { 2390 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 2391 2392 kfree(ctx); 2393 } 2394 2395 void tls_sw_free_resources_rx(struct sock *sk) 2396 { 2397 struct tls_context *tls_ctx = tls_get_ctx(sk); 2398 2399 tls_sw_release_resources_rx(sk); 2400 tls_sw_free_ctx_rx(tls_ctx); 2401 } 2402 2403 /* The work handler to transmitt the encrypted records in tx_list */ 2404 static void tx_work_handler(struct work_struct *work) 2405 { 2406 struct delayed_work *delayed_work = to_delayed_work(work); 2407 struct tx_work *tx_work = container_of(delayed_work, 2408 struct tx_work, work); 2409 struct sock *sk = tx_work->sk; 2410 struct tls_context *tls_ctx = tls_get_ctx(sk); 2411 struct tls_sw_context_tx *ctx; 2412 2413 if (unlikely(!tls_ctx)) 2414 return; 2415 2416 ctx = tls_sw_ctx_tx(tls_ctx); 2417 if (test_bit(BIT_TX_CLOSING, &ctx->tx_bitmask)) 2418 return; 2419 2420 if (!test_and_clear_bit(BIT_TX_SCHEDULED, &ctx->tx_bitmask)) 2421 return; 2422 mutex_lock(&tls_ctx->tx_lock); 2423 lock_sock(sk); 2424 tls_tx_records(sk, -1); 2425 release_sock(sk); 2426 mutex_unlock(&tls_ctx->tx_lock); 2427 } 2428 2429 static bool tls_is_tx_ready(struct tls_sw_context_tx *ctx) 2430 { 2431 struct tls_rec *rec; 2432 2433 rec = list_first_entry_or_null(&ctx->tx_list, struct tls_rec, list); 2434 if (!rec) 2435 return false; 2436 2437 return READ_ONCE(rec->tx_ready); 2438 } 2439 
2440 void tls_sw_write_space(struct sock *sk, struct tls_context *ctx) 2441 { 2442 struct tls_sw_context_tx *tx_ctx = tls_sw_ctx_tx(ctx); 2443 2444 /* Schedule the transmission if tx list is ready */ 2445 if (tls_is_tx_ready(tx_ctx) && 2446 !test_and_set_bit(BIT_TX_SCHEDULED, &tx_ctx->tx_bitmask)) 2447 schedule_delayed_work(&tx_ctx->tx_work.work, 0); 2448 } 2449 2450 void tls_sw_strparser_arm(struct sock *sk, struct tls_context *tls_ctx) 2451 { 2452 struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx); 2453 2454 write_lock_bh(&sk->sk_callback_lock); 2455 rx_ctx->saved_data_ready = sk->sk_data_ready; 2456 sk->sk_data_ready = tls_data_ready; 2457 write_unlock_bh(&sk->sk_callback_lock); 2458 } 2459 2460 void tls_update_rx_zc_capable(struct tls_context *tls_ctx) 2461 { 2462 struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_ctx); 2463 2464 rx_ctx->zc_capable = tls_ctx->rx_no_pad || 2465 tls_ctx->prot_info.version != TLS_1_3_VERSION; 2466 } 2467 2468 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) 2469 { 2470 struct tls_context *tls_ctx = tls_get_ctx(sk); 2471 struct tls_prot_info *prot = &tls_ctx->prot_info; 2472 struct tls_crypto_info *crypto_info; 2473 struct tls_sw_context_tx *sw_ctx_tx = NULL; 2474 struct tls_sw_context_rx *sw_ctx_rx = NULL; 2475 struct cipher_context *cctx; 2476 struct crypto_aead **aead; 2477 u16 nonce_size, tag_size, iv_size, rec_seq_size, salt_size; 2478 struct crypto_tfm *tfm; 2479 char *iv, *rec_seq, *key, *salt, *cipher_name; 2480 size_t keysize; 2481 int rc = 0; 2482 2483 if (!ctx) { 2484 rc = -EINVAL; 2485 goto out; 2486 } 2487 2488 if (tx) { 2489 if (!ctx->priv_ctx_tx) { 2490 sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL); 2491 if (!sw_ctx_tx) { 2492 rc = -ENOMEM; 2493 goto out; 2494 } 2495 ctx->priv_ctx_tx = sw_ctx_tx; 2496 } else { 2497 sw_ctx_tx = 2498 (struct tls_sw_context_tx *)ctx->priv_ctx_tx; 2499 } 2500 } else { 2501 if (!ctx->priv_ctx_rx) { 2502 sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), 
GFP_KERNEL); 2503 if (!sw_ctx_rx) { 2504 rc = -ENOMEM; 2505 goto out; 2506 } 2507 ctx->priv_ctx_rx = sw_ctx_rx; 2508 } else { 2509 sw_ctx_rx = 2510 (struct tls_sw_context_rx *)ctx->priv_ctx_rx; 2511 } 2512 } 2513 2514 if (tx) { 2515 crypto_init_wait(&sw_ctx_tx->async_wait); 2516 spin_lock_init(&sw_ctx_tx->encrypt_compl_lock); 2517 crypto_info = &ctx->crypto_send.info; 2518 cctx = &ctx->tx; 2519 aead = &sw_ctx_tx->aead_send; 2520 INIT_LIST_HEAD(&sw_ctx_tx->tx_list); 2521 INIT_DELAYED_WORK(&sw_ctx_tx->tx_work.work, tx_work_handler); 2522 sw_ctx_tx->tx_work.sk = sk; 2523 } else { 2524 crypto_init_wait(&sw_ctx_rx->async_wait); 2525 spin_lock_init(&sw_ctx_rx->decrypt_compl_lock); 2526 init_waitqueue_head(&sw_ctx_rx->wq); 2527 crypto_info = &ctx->crypto_recv.info; 2528 cctx = &ctx->rx; 2529 skb_queue_head_init(&sw_ctx_rx->rx_list); 2530 skb_queue_head_init(&sw_ctx_rx->async_hold); 2531 aead = &sw_ctx_rx->aead_recv; 2532 } 2533 2534 switch (crypto_info->cipher_type) { 2535 case TLS_CIPHER_AES_GCM_128: { 2536 struct tls12_crypto_info_aes_gcm_128 *gcm_128_info; 2537 2538 gcm_128_info = (void *)crypto_info; 2539 nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; 2540 tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE; 2541 iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; 2542 iv = gcm_128_info->iv; 2543 rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE; 2544 rec_seq = gcm_128_info->rec_seq; 2545 keysize = TLS_CIPHER_AES_GCM_128_KEY_SIZE; 2546 key = gcm_128_info->key; 2547 salt = gcm_128_info->salt; 2548 salt_size = TLS_CIPHER_AES_GCM_128_SALT_SIZE; 2549 cipher_name = "gcm(aes)"; 2550 break; 2551 } 2552 case TLS_CIPHER_AES_GCM_256: { 2553 struct tls12_crypto_info_aes_gcm_256 *gcm_256_info; 2554 2555 gcm_256_info = (void *)crypto_info; 2556 nonce_size = TLS_CIPHER_AES_GCM_256_IV_SIZE; 2557 tag_size = TLS_CIPHER_AES_GCM_256_TAG_SIZE; 2558 iv_size = TLS_CIPHER_AES_GCM_256_IV_SIZE; 2559 iv = gcm_256_info->iv; 2560 rec_seq_size = TLS_CIPHER_AES_GCM_256_REC_SEQ_SIZE; 2561 rec_seq = 
gcm_256_info->rec_seq; 2562 keysize = TLS_CIPHER_AES_GCM_256_KEY_SIZE; 2563 key = gcm_256_info->key; 2564 salt = gcm_256_info->salt; 2565 salt_size = TLS_CIPHER_AES_GCM_256_SALT_SIZE; 2566 cipher_name = "gcm(aes)"; 2567 break; 2568 } 2569 case TLS_CIPHER_AES_CCM_128: { 2570 struct tls12_crypto_info_aes_ccm_128 *ccm_128_info; 2571 2572 ccm_128_info = (void *)crypto_info; 2573 nonce_size = TLS_CIPHER_AES_CCM_128_IV_SIZE; 2574 tag_size = TLS_CIPHER_AES_CCM_128_TAG_SIZE; 2575 iv_size = TLS_CIPHER_AES_CCM_128_IV_SIZE; 2576 iv = ccm_128_info->iv; 2577 rec_seq_size = TLS_CIPHER_AES_CCM_128_REC_SEQ_SIZE; 2578 rec_seq = ccm_128_info->rec_seq; 2579 keysize = TLS_CIPHER_AES_CCM_128_KEY_SIZE; 2580 key = ccm_128_info->key; 2581 salt = ccm_128_info->salt; 2582 salt_size = TLS_CIPHER_AES_CCM_128_SALT_SIZE; 2583 cipher_name = "ccm(aes)"; 2584 break; 2585 } 2586 case TLS_CIPHER_CHACHA20_POLY1305: { 2587 struct tls12_crypto_info_chacha20_poly1305 *chacha20_poly1305_info; 2588 2589 chacha20_poly1305_info = (void *)crypto_info; 2590 nonce_size = 0; 2591 tag_size = TLS_CIPHER_CHACHA20_POLY1305_TAG_SIZE; 2592 iv_size = TLS_CIPHER_CHACHA20_POLY1305_IV_SIZE; 2593 iv = chacha20_poly1305_info->iv; 2594 rec_seq_size = TLS_CIPHER_CHACHA20_POLY1305_REC_SEQ_SIZE; 2595 rec_seq = chacha20_poly1305_info->rec_seq; 2596 keysize = TLS_CIPHER_CHACHA20_POLY1305_KEY_SIZE; 2597 key = chacha20_poly1305_info->key; 2598 salt = chacha20_poly1305_info->salt; 2599 salt_size = TLS_CIPHER_CHACHA20_POLY1305_SALT_SIZE; 2600 cipher_name = "rfc7539(chacha20,poly1305)"; 2601 break; 2602 } 2603 case TLS_CIPHER_SM4_GCM: { 2604 struct tls12_crypto_info_sm4_gcm *sm4_gcm_info; 2605 2606 sm4_gcm_info = (void *)crypto_info; 2607 nonce_size = TLS_CIPHER_SM4_GCM_IV_SIZE; 2608 tag_size = TLS_CIPHER_SM4_GCM_TAG_SIZE; 2609 iv_size = TLS_CIPHER_SM4_GCM_IV_SIZE; 2610 iv = sm4_gcm_info->iv; 2611 rec_seq_size = TLS_CIPHER_SM4_GCM_REC_SEQ_SIZE; 2612 rec_seq = sm4_gcm_info->rec_seq; 2613 keysize = TLS_CIPHER_SM4_GCM_KEY_SIZE; 2614 key 
= sm4_gcm_info->key; 2615 salt = sm4_gcm_info->salt; 2616 salt_size = TLS_CIPHER_SM4_GCM_SALT_SIZE; 2617 cipher_name = "gcm(sm4)"; 2618 break; 2619 } 2620 case TLS_CIPHER_SM4_CCM: { 2621 struct tls12_crypto_info_sm4_ccm *sm4_ccm_info; 2622 2623 sm4_ccm_info = (void *)crypto_info; 2624 nonce_size = TLS_CIPHER_SM4_CCM_IV_SIZE; 2625 tag_size = TLS_CIPHER_SM4_CCM_TAG_SIZE; 2626 iv_size = TLS_CIPHER_SM4_CCM_IV_SIZE; 2627 iv = sm4_ccm_info->iv; 2628 rec_seq_size = TLS_CIPHER_SM4_CCM_REC_SEQ_SIZE; 2629 rec_seq = sm4_ccm_info->rec_seq; 2630 keysize = TLS_CIPHER_SM4_CCM_KEY_SIZE; 2631 key = sm4_ccm_info->key; 2632 salt = sm4_ccm_info->salt; 2633 salt_size = TLS_CIPHER_SM4_CCM_SALT_SIZE; 2634 cipher_name = "ccm(sm4)"; 2635 break; 2636 } 2637 case TLS_CIPHER_ARIA_GCM_128: { 2638 struct tls12_crypto_info_aria_gcm_128 *aria_gcm_128_info; 2639 2640 aria_gcm_128_info = (void *)crypto_info; 2641 nonce_size = TLS_CIPHER_ARIA_GCM_128_IV_SIZE; 2642 tag_size = TLS_CIPHER_ARIA_GCM_128_TAG_SIZE; 2643 iv_size = TLS_CIPHER_ARIA_GCM_128_IV_SIZE; 2644 iv = aria_gcm_128_info->iv; 2645 rec_seq_size = TLS_CIPHER_ARIA_GCM_128_REC_SEQ_SIZE; 2646 rec_seq = aria_gcm_128_info->rec_seq; 2647 keysize = TLS_CIPHER_ARIA_GCM_128_KEY_SIZE; 2648 key = aria_gcm_128_info->key; 2649 salt = aria_gcm_128_info->salt; 2650 salt_size = TLS_CIPHER_ARIA_GCM_128_SALT_SIZE; 2651 cipher_name = "gcm(aria)"; 2652 break; 2653 } 2654 case TLS_CIPHER_ARIA_GCM_256: { 2655 struct tls12_crypto_info_aria_gcm_256 *gcm_256_info; 2656 2657 gcm_256_info = (void *)crypto_info; 2658 nonce_size = TLS_CIPHER_ARIA_GCM_256_IV_SIZE; 2659 tag_size = TLS_CIPHER_ARIA_GCM_256_TAG_SIZE; 2660 iv_size = TLS_CIPHER_ARIA_GCM_256_IV_SIZE; 2661 iv = gcm_256_info->iv; 2662 rec_seq_size = TLS_CIPHER_ARIA_GCM_256_REC_SEQ_SIZE; 2663 rec_seq = gcm_256_info->rec_seq; 2664 keysize = TLS_CIPHER_ARIA_GCM_256_KEY_SIZE; 2665 key = gcm_256_info->key; 2666 salt = gcm_256_info->salt; 2667 salt_size = TLS_CIPHER_ARIA_GCM_256_SALT_SIZE; 2668 cipher_name = 
"gcm(aria)"; 2669 break; 2670 } 2671 default: 2672 rc = -EINVAL; 2673 goto free_priv; 2674 } 2675 2676 if (crypto_info->version == TLS_1_3_VERSION) { 2677 nonce_size = 0; 2678 prot->aad_size = TLS_HEADER_SIZE; 2679 prot->tail_size = 1; 2680 } else { 2681 prot->aad_size = TLS_AAD_SPACE_SIZE; 2682 prot->tail_size = 0; 2683 } 2684 2685 /* Sanity-check the sizes for stack allocations. */ 2686 if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE || 2687 rec_seq_size > TLS_MAX_REC_SEQ_SIZE || tag_size != TLS_TAG_SIZE || 2688 prot->aad_size > TLS_MAX_AAD_SIZE) { 2689 rc = -EINVAL; 2690 goto free_priv; 2691 } 2692 2693 prot->version = crypto_info->version; 2694 prot->cipher_type = crypto_info->cipher_type; 2695 prot->prepend_size = TLS_HEADER_SIZE + nonce_size; 2696 prot->tag_size = tag_size; 2697 prot->overhead_size = prot->prepend_size + 2698 prot->tag_size + prot->tail_size; 2699 prot->iv_size = iv_size; 2700 prot->salt_size = salt_size; 2701 cctx->iv = kmalloc(iv_size + salt_size, GFP_KERNEL); 2702 if (!cctx->iv) { 2703 rc = -ENOMEM; 2704 goto free_priv; 2705 } 2706 /* Note: 128 & 256 bit salt are the same size */ 2707 prot->rec_seq_size = rec_seq_size; 2708 memcpy(cctx->iv, salt, salt_size); 2709 memcpy(cctx->iv + salt_size, iv, iv_size); 2710 cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL); 2711 if (!cctx->rec_seq) { 2712 rc = -ENOMEM; 2713 goto free_iv; 2714 } 2715 2716 if (!*aead) { 2717 *aead = crypto_alloc_aead(cipher_name, 0, 0); 2718 if (IS_ERR(*aead)) { 2719 rc = PTR_ERR(*aead); 2720 *aead = NULL; 2721 goto free_rec_seq; 2722 } 2723 } 2724 2725 ctx->push_pending_record = tls_sw_push_pending_record; 2726 2727 rc = crypto_aead_setkey(*aead, key, keysize); 2728 2729 if (rc) 2730 goto free_aead; 2731 2732 rc = crypto_aead_setauthsize(*aead, prot->tag_size); 2733 if (rc) 2734 goto free_aead; 2735 2736 if (sw_ctx_rx) { 2737 tfm = crypto_aead_tfm(sw_ctx_rx->aead_recv); 2738 2739 tls_update_rx_zc_capable(ctx); 2740 sw_ctx_rx->async_capable = 2741 
crypto_info->version != TLS_1_3_VERSION && 2742 !!(tfm->__crt_alg->cra_flags & CRYPTO_ALG_ASYNC); 2743 2744 rc = tls_strp_init(&sw_ctx_rx->strp, sk); 2745 if (rc) 2746 goto free_aead; 2747 } 2748 2749 goto out; 2750 2751 free_aead: 2752 crypto_free_aead(*aead); 2753 *aead = NULL; 2754 free_rec_seq: 2755 kfree(cctx->rec_seq); 2756 cctx->rec_seq = NULL; 2757 free_iv: 2758 kfree(cctx->iv); 2759 cctx->iv = NULL; 2760 free_priv: 2761 if (tx) { 2762 kfree(ctx->priv_ctx_tx); 2763 ctx->priv_ctx_tx = NULL; 2764 } else { 2765 kfree(ctx->priv_ctx_rx); 2766 ctx->priv_ctx_rx = NULL; 2767 } 2768 out: 2769 return rc; 2770 } 2771