1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */ 3 4 #include <linux/skmsg.h> 5 #include <linux/skbuff.h> 6 #include <linux/scatterlist.h> 7 8 #include <net/sock.h> 9 #include <net/tcp.h> 10 11 static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce) 12 { 13 if (msg->sg.end > msg->sg.start && 14 elem_first_coalesce < msg->sg.end) 15 return true; 16 17 if (msg->sg.end < msg->sg.start && 18 (elem_first_coalesce > msg->sg.start || 19 elem_first_coalesce < msg->sg.end)) 20 return true; 21 22 return false; 23 } 24 25 int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len, 26 int elem_first_coalesce) 27 { 28 struct page_frag *pfrag = sk_page_frag(sk); 29 int ret = 0; 30 31 len -= msg->sg.size; 32 while (len > 0) { 33 struct scatterlist *sge; 34 u32 orig_offset; 35 int use, i; 36 37 if (!sk_page_frag_refill(sk, pfrag)) 38 return -ENOMEM; 39 40 orig_offset = pfrag->offset; 41 use = min_t(int, len, pfrag->size - orig_offset); 42 if (!sk_wmem_schedule(sk, use)) 43 return -ENOMEM; 44 45 i = msg->sg.end; 46 sk_msg_iter_var_prev(i); 47 sge = &msg->sg.data[i]; 48 49 if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) && 50 sg_page(sge) == pfrag->page && 51 sge->offset + sge->length == orig_offset) { 52 sge->length += use; 53 } else { 54 if (sk_msg_full(msg)) { 55 ret = -ENOSPC; 56 break; 57 } 58 59 sge = &msg->sg.data[msg->sg.end]; 60 sg_unmark_end(sge); 61 sg_set_page(sge, pfrag->page, use, orig_offset); 62 get_page(pfrag->page); 63 sk_msg_iter_next(msg, end); 64 } 65 66 sk_mem_charge(sk, use); 67 msg->sg.size += use; 68 pfrag->offset += use; 69 len -= use; 70 } 71 72 return ret; 73 } 74 EXPORT_SYMBOL_GPL(sk_msg_alloc); 75 76 int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src, 77 u32 off, u32 len) 78 { 79 int i = src->sg.start; 80 struct scatterlist *sge = sk_msg_elem(src, i); 81 struct scatterlist *sgd = NULL; 82 u32 sge_len, sge_off; 83 84 while (off) { 85 if (sge->length > off) 86 break; 87 off -= sge->length; 88 sk_msg_iter_var_next(i); 89 if (i == src->sg.end && off) 90 return -ENOSPC; 91 sge = sk_msg_elem(src, i); 92 } 93 94 while (len) { 95 sge_len = sge->length - off; 96 if (sge_len > len) 97 sge_len = len; 98 99 if (dst->sg.end) 100 sgd = sk_msg_elem(dst, dst->sg.end - 1); 101 102 if (sgd && 103 (sg_page(sge) == sg_page(sgd)) && 104 (sg_virt(sge) + off == sg_virt(sgd) + sgd->length)) { 105 sgd->length += sge_len; 106 dst->sg.size += sge_len; 107 } else if (!sk_msg_full(dst)) { 108 sge_off = sge->offset + off; 109 sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off); 110 } else { 111 return -ENOSPC; 112 } 113 114 off = 0; 115 len -= sge_len; 116 sk_mem_charge(sk, sge_len); 117 sk_msg_iter_var_next(i); 118 if (i == src->sg.end && len) 119 return -ENOSPC; 120 sge = sk_msg_elem(src, i); 121 } 122 123 return 0; 124 } 125 EXPORT_SYMBOL_GPL(sk_msg_clone); 126 127 void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes) 128 { 129 int i = msg->sg.start; 130 131 do { 132 struct scatterlist *sge = sk_msg_elem(msg, i); 133 134 if (bytes < sge->length) { 135 sge->length -= bytes; 136 sge->offset += bytes; 137 sk_mem_uncharge(sk, bytes); 138 break; 139 } 140 141 sk_mem_uncharge(sk, sge->length); 142 bytes -= sge->length; 143 sge->length = 0; 144 sge->offset = 0; 145 sk_msg_iter_var_next(i); 146 } while (bytes && i != msg->sg.end); 147 msg->sg.start = i; 148 } 149 EXPORT_SYMBOL_GPL(sk_msg_return_zero); 150 151 void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes) 152 { 153 int i = msg->sg.start; 154 155 do { 156 struct scatterlist *sge = &msg->sg.data[i]; 157 int uncharge = (bytes < sge->length) ? bytes : sge->length; 158 159 sk_mem_uncharge(sk, uncharge); 160 bytes -= uncharge; 161 sk_msg_iter_var_next(i); 162 } while (i != msg->sg.end); 163 } 164 EXPORT_SYMBOL_GPL(sk_msg_return); 165 166 static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i, 167 bool charge) 168 { 169 struct scatterlist *sge = sk_msg_elem(msg, i); 170 u32 len = sge->length; 171 172 if (charge) 173 sk_mem_uncharge(sk, len); 174 if (!msg->skb) 175 put_page(sg_page(sge)); 176 memset(sge, 0, sizeof(*sge)); 177 return len; 178 } 179 180 static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i, 181 bool charge) 182 { 183 struct scatterlist *sge = sk_msg_elem(msg, i); 184 int freed = 0; 185 186 while (msg->sg.size) { 187 msg->sg.size -= sge->length; 188 freed += sk_msg_free_elem(sk, msg, i, charge); 189 sk_msg_iter_var_next(i); 190 sk_msg_check_to_free(msg, i, msg->sg.size); 191 sge = sk_msg_elem(msg, i); 192 } 193 consume_skb(msg->skb); 194 sk_msg_init(msg); 195 return freed; 196 } 197 198 int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg) 199 { 200 return __sk_msg_free(sk, msg, msg->sg.start, false); 201 } 202 EXPORT_SYMBOL_GPL(sk_msg_free_nocharge); 203 204 int sk_msg_free(struct sock *sk, struct sk_msg *msg) 205 { 206 return __sk_msg_free(sk, msg, msg->sg.start, true); 207 } 208 EXPORT_SYMBOL_GPL(sk_msg_free); 209 210 static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, 211 u32 bytes, bool charge) 212 { 213 struct scatterlist *sge; 214 u32 i = msg->sg.start; 215 216 while (bytes) { 217 sge = sk_msg_elem(msg, i); 218 if (!sge->length) 219 break; 220 if (bytes < sge->length) { 221 if (charge) 222 sk_mem_uncharge(sk, bytes); 223 sge->length -= bytes; 224 sge->offset += bytes; 225 msg->sg.size -= bytes; 226 break; 227 } 228 229 msg->sg.size -= sge->length; 230 bytes -= sge->length; 231 sk_msg_free_elem(sk, msg, i, charge); 232 sk_msg_iter_var_next(i); 233 sk_msg_check_to_free(msg, i, bytes); 234 } 235 msg->sg.start = i; 236 } 237 238 void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes) 239 { 240 __sk_msg_free_partial(sk, msg, bytes, true); 241 } 242 EXPORT_SYMBOL_GPL(sk_msg_free_partial); 243 244 void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg, 245 u32 bytes) 246 { 247 __sk_msg_free_partial(sk, msg, bytes, false); 248 } 249 250 void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len) 251 { 252 int trim = msg->sg.size - len; 253 u32 i = msg->sg.end; 254 255 if (trim <= 0) { 256 WARN_ON(trim < 0); 257 return; 258 } 259 260 sk_msg_iter_var_prev(i); 261 msg->sg.size = len; 262 while (msg->sg.data[i].length && 263 trim >= msg->sg.data[i].length) { 264 trim -= msg->sg.data[i].length; 265 sk_msg_free_elem(sk, msg, i, true); 266 sk_msg_iter_var_prev(i); 267 if (!trim) 268 goto out; 269 } 270 271 msg->sg.data[i].length -= trim; 272 sk_mem_uncharge(sk, trim); 273 /* Adjust copybreak if it falls into the trimmed part of last buf */ 274 if (msg->sg.curr == i && msg->sg.copybreak > msg->sg.data[i].length) 275 msg->sg.copybreak = msg->sg.data[i].length; 276 out: 277 sk_msg_iter_var_next(i); 278 msg->sg.end = i; 279 280 /* If we trim data a full sg elem before curr pointer update 281 * copybreak and current so that any future copy operations 282 * start at new copy location. 283 * However trimed data that has not yet been used in a copy op 284 * does not require an update. 285 */ 286 if (!msg->sg.size) { 287 msg->sg.curr = msg->sg.start; 288 msg->sg.copybreak = 0; 289 } else if (sk_msg_iter_dist(msg->sg.start, msg->sg.curr) >= 290 sk_msg_iter_dist(msg->sg.start, msg->sg.end)) { 291 sk_msg_iter_var_prev(i); 292 msg->sg.curr = i; 293 msg->sg.copybreak = msg->sg.data[i].length; 294 } 295 } 296 EXPORT_SYMBOL_GPL(sk_msg_trim); 297 298 int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from, 299 struct sk_msg *msg, u32 bytes) 300 { 301 int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg); 302 const int to_max_pages = MAX_MSG_FRAGS; 303 struct page *pages[MAX_MSG_FRAGS]; 304 ssize_t orig, copied, use, offset; 305 306 orig = msg->sg.size; 307 while (bytes > 0) { 308 i = 0; 309 maxpages = to_max_pages - num_elems; 310 if (maxpages == 0) { 311 ret = -EFAULT; 312 goto out; 313 } 314 315 copied = iov_iter_get_pages(from, pages, bytes, maxpages, 316 &offset); 317 if (copied <= 0) { 318 ret = -EFAULT; 319 goto out; 320 } 321 322 iov_iter_advance(from, copied); 323 bytes -= copied; 324 msg->sg.size += copied; 325 326 while (copied) { 327 use = min_t(int, copied, PAGE_SIZE - offset); 328 sg_set_page(&msg->sg.data[msg->sg.end], 329 pages[i], use, offset); 330 sg_unmark_end(&msg->sg.data[msg->sg.end]); 331 sk_mem_charge(sk, use); 332 333 offset = 0; 334 copied -= use; 335 sk_msg_iter_next(msg, end); 336 num_elems++; 337 i++; 338 } 339 /* When zerocopy is mixed with sk_msg_*copy* operations we 340 * may have a copybreak set in this case clear and prefer 341 * zerocopy remainder when possible. 342 */ 343 msg->sg.copybreak = 0; 344 msg->sg.curr = msg->sg.end; 345 } 346 out: 347 /* Revert iov_iter updates, msg will need to use 'trim' later if it 348 * also needs to be cleared. 349 */ 350 if (ret) 351 iov_iter_revert(from, msg->sg.size - orig); 352 return ret; 353 } 354 EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter); 355 356 int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from, 357 struct sk_msg *msg, u32 bytes) 358 { 359 int ret = -ENOSPC, i = msg->sg.curr; 360 struct scatterlist *sge; 361 u32 copy, buf_size; 362 void *to; 363 364 do { 365 sge = sk_msg_elem(msg, i); 366 /* This is possible if a trim operation shrunk the buffer */ 367 if (msg->sg.copybreak >= sge->length) { 368 msg->sg.copybreak = 0; 369 sk_msg_iter_var_next(i); 370 if (i == msg->sg.end) 371 break; 372 sge = sk_msg_elem(msg, i); 373 } 374 375 buf_size = sge->length - msg->sg.copybreak; 376 copy = (buf_size > bytes) ? bytes : buf_size; 377 to = sg_virt(sge) + msg->sg.copybreak; 378 msg->sg.copybreak += copy; 379 if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY) 380 ret = copy_from_iter_nocache(to, copy, from); 381 else 382 ret = copy_from_iter(to, copy, from); 383 if (ret != copy) { 384 ret = -EFAULT; 385 goto out; 386 } 387 bytes -= copy; 388 if (!bytes) 389 break; 390 msg->sg.copybreak = 0; 391 sk_msg_iter_var_next(i); 392 } while (i != msg->sg.end); 393 out: 394 msg->sg.curr = i; 395 return ret; 396 } 397 EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter); 398 399 static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb) 400 { 401 struct sock *sk = psock->sk; 402 int copied = 0, num_sge; 403 struct sk_msg *msg; 404 405 msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC); 406 if (unlikely(!msg)) 407 return -EAGAIN; 408 if (!sk_rmem_schedule(sk, skb, skb->len)) { 409 kfree(msg); 410 return -EAGAIN; 411 } 412 413 sk_msg_init(msg); 414 num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len); 415 if (unlikely(num_sge < 0)) { 416 kfree(msg); 417 return num_sge; 418 } 419 420 sk_mem_charge(sk, skb->len); 421 copied = skb->len; 422 msg->sg.start = 0; 423 msg->sg.size = copied; 424 msg->sg.end = num_sge; 425 msg->skb = skb; 426 427 sk_psock_queue_msg(psock, msg); 428 sk_psock_data_ready(sk, psock); 429 return copied; 430 } 431 432 static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb, 433 u32 off, u32 len, bool ingress) 434 { 435 if (ingress) 436 return sk_psock_skb_ingress(psock, skb); 437 else 438 return skb_send_sock_locked(psock->sk, skb, off, len); 439 } 440 441 static void sk_psock_backlog(struct work_struct *work) 442 { 443 struct sk_psock *psock = container_of(work, struct sk_psock, work); 444 struct sk_psock_work_state *state = &psock->work_state; 445 struct sk_buff *skb; 446 bool ingress; 447 u32 len, off; 448 int ret; 449 450 /* Lock sock to avoid losing sk_socket during loop. */ 451 lock_sock(psock->sk); 452 if (state->skb) { 453 skb = state->skb; 454 len = state->len; 455 off = state->off; 456 state->skb = NULL; 457 goto start; 458 } 459 460 while ((skb = skb_dequeue(&psock->ingress_skb))) { 461 len = skb->len; 462 off = 0; 463 start: 464 ingress = tcp_skb_bpf_ingress(skb); 465 do { 466 ret = -EIO; 467 if (likely(psock->sk->sk_socket)) 468 ret = sk_psock_handle_skb(psock, skb, off, 469 len, ingress); 470 if (ret <= 0) { 471 if (ret == -EAGAIN) { 472 state->skb = skb; 473 state->len = len; 474 state->off = off; 475 goto end; 476 } 477 /* Hard errors break pipe and stop xmit. */ 478 sk_psock_report_error(psock, ret ? -ret : EPIPE); 479 sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED); 480 kfree_skb(skb); 481 goto end; 482 } 483 off += ret; 484 len -= ret; 485 } while (len); 486 487 if (!ingress) 488 kfree_skb(skb); 489 } 490 end: 491 release_sock(psock->sk); 492 } 493 494 struct sk_psock *sk_psock_init(struct sock *sk, int node) 495 { 496 struct sk_psock *psock = kzalloc_node(sizeof(*psock), 497 GFP_ATOMIC | __GFP_NOWARN, 498 node); 499 if (!psock) 500 return NULL; 501 502 psock->sk = sk; 503 psock->eval = __SK_NONE; 504 505 INIT_LIST_HEAD(&psock->link); 506 spin_lock_init(&psock->link_lock); 507 508 INIT_WORK(&psock->work, sk_psock_backlog); 509 INIT_LIST_HEAD(&psock->ingress_msg); 510 skb_queue_head_init(&psock->ingress_skb); 511 512 sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED); 513 refcount_set(&psock->refcnt, 1); 514 515 rcu_assign_sk_user_data_nocopy(sk, psock); 516 sock_hold(sk); 517 518 return psock; 519 } 520 EXPORT_SYMBOL_GPL(sk_psock_init); 521 522 struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock) 523 { 524 struct sk_psock_link *link; 525 526 spin_lock_bh(&psock->link_lock); 527 link = list_first_entry_or_null(&psock->link, struct sk_psock_link, 528 list); 529 if (link) 530 list_del(&link->list); 531 spin_unlock_bh(&psock->link_lock); 532 return link; 533 } 534 535 void __sk_psock_purge_ingress_msg(struct sk_psock *psock) 536 { 537 struct sk_msg *msg, *tmp; 538 539 list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) { 540 list_del(&msg->list); 541 sk_msg_free(psock->sk, msg); 542 kfree(msg); 543 } 544 } 545 546 static void sk_psock_zap_ingress(struct sk_psock *psock) 547 { 548 __skb_queue_purge(&psock->ingress_skb); 549 __sk_psock_purge_ingress_msg(psock); 550 } 551 552 static void sk_psock_link_destroy(struct sk_psock *psock) 553 { 554 struct sk_psock_link *link, *tmp; 555 556 list_for_each_entry_safe(link, tmp, &psock->link, list) { 557 list_del(&link->list); 558 sk_psock_free_link(link); 559 } 560 } 561 562 static void sk_psock_destroy_deferred(struct work_struct *gc) 563 { 564 struct sk_psock *psock = container_of(gc, struct sk_psock, gc); 565 566 /* No sk_callback_lock since already detached. */ 567 568 /* Parser has been stopped */ 569 if (psock->progs.skb_parser) 570 strp_done(&psock->parser.strp); 571 572 cancel_work_sync(&psock->work); 573 574 psock_progs_drop(&psock->progs); 575 576 sk_psock_link_destroy(psock); 577 sk_psock_cork_free(psock); 578 sk_psock_zap_ingress(psock); 579 580 if (psock->sk_redir) 581 sock_put(psock->sk_redir); 582 sock_put(psock->sk); 583 kfree(psock); 584 } 585 586 void sk_psock_destroy(struct rcu_head *rcu) 587 { 588 struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu); 589 590 INIT_WORK(&psock->gc, sk_psock_destroy_deferred); 591 schedule_work(&psock->gc); 592 } 593 EXPORT_SYMBOL_GPL(sk_psock_destroy); 594 595 void sk_psock_drop(struct sock *sk, struct sk_psock *psock) 596 { 597 sk_psock_cork_free(psock); 598 sk_psock_zap_ingress(psock); 599 600 write_lock_bh(&sk->sk_callback_lock); 601 sk_psock_restore_proto(sk, psock); 602 rcu_assign_sk_user_data(sk, NULL); 603 if (psock->progs.skb_parser) 604 sk_psock_stop_strp(sk, psock); 605 write_unlock_bh(&sk->sk_callback_lock); 606 sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED); 607 608 call_rcu(&psock->rcu, sk_psock_destroy); 609 } 610 EXPORT_SYMBOL_GPL(sk_psock_drop); 611 612 static int sk_psock_map_verd(int verdict, bool redir) 613 { 614 switch (verdict) { 615 case SK_PASS: 616 return redir ? __SK_REDIRECT : __SK_PASS; 617 case SK_DROP: 618 default: 619 break; 620 } 621 622 return __SK_DROP; 623 } 624 625 int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock, 626 struct sk_msg *msg) 627 { 628 struct bpf_prog *prog; 629 int ret; 630 631 rcu_read_lock(); 632 prog = READ_ONCE(psock->progs.msg_parser); 633 if (unlikely(!prog)) { 634 ret = __SK_PASS; 635 goto out; 636 } 637 638 sk_msg_compute_data_pointers(msg); 639 msg->sk = sk; 640 ret = bpf_prog_run_pin_on_cpu(prog, msg); 641 ret = sk_psock_map_verd(ret, msg->sk_redir); 642 psock->apply_bytes = msg->apply_bytes; 643 if (ret == __SK_REDIRECT) { 644 if (psock->sk_redir) 645 sock_put(psock->sk_redir); 646 psock->sk_redir = msg->sk_redir; 647 if (!psock->sk_redir) { 648 ret = __SK_DROP; 649 goto out; 650 } 651 sock_hold(psock->sk_redir); 652 } 653 out: 654 rcu_read_unlock(); 655 return ret; 656 } 657 EXPORT_SYMBOL_GPL(sk_psock_msg_verdict); 658 659 static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog, 660 struct sk_buff *skb) 661 { 662 int ret; 663 664 skb->sk = psock->sk; 665 bpf_compute_data_end_sk_skb(skb); 666 ret = bpf_prog_run_pin_on_cpu(prog, skb); 667 /* strparser clones the skb before handing it to a upper layer, 668 * meaning skb_orphan has been called. We NULL sk on the way out 669 * to ensure we don't trigger a BUG_ON() in skb/sk operations 670 * later and because we are not charging the memory of this skb 671 * to any socket yet. 672 */ 673 skb->sk = NULL; 674 return ret; 675 } 676 677 static struct sk_psock *sk_psock_from_strp(struct strparser *strp) 678 { 679 struct sk_psock_parser *parser; 680 681 parser = container_of(strp, struct sk_psock_parser, strp); 682 return container_of(parser, struct sk_psock, parser); 683 } 684 685 static void sk_psock_verdict_apply(struct sk_psock *psock, 686 struct sk_buff *skb, int verdict) 687 { 688 struct sk_psock *psock_other; 689 struct sock *sk_other; 690 bool ingress; 691 692 switch (verdict) { 693 case __SK_PASS: 694 sk_other = psock->sk; 695 if (sock_flag(sk_other, SOCK_DEAD) || 696 !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) { 697 goto out_free; 698 } 699 if (atomic_read(&sk_other->sk_rmem_alloc) <= 700 sk_other->sk_rcvbuf) { 701 struct tcp_skb_cb *tcp = TCP_SKB_CB(skb); 702 703 tcp->bpf.flags |= BPF_F_INGRESS; 704 skb_queue_tail(&psock->ingress_skb, skb); 705 schedule_work(&psock->work); 706 break; 707 } 708 goto out_free; 709 case __SK_REDIRECT: 710 sk_other = tcp_skb_bpf_redirect_fetch(skb); 711 if (unlikely(!sk_other)) 712 goto out_free; 713 psock_other = sk_psock(sk_other); 714 if (!psock_other || sock_flag(sk_other, SOCK_DEAD) || 715 !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) 716 goto out_free; 717 ingress = tcp_skb_bpf_ingress(skb); 718 if ((!ingress && sock_writeable(sk_other)) || 719 (ingress && 720 atomic_read(&sk_other->sk_rmem_alloc) <= 721 sk_other->sk_rcvbuf)) { 722 if (!ingress) 723 skb_set_owner_w(skb, sk_other); 724 skb_queue_tail(&psock_other->ingress_skb, skb); 725 schedule_work(&psock_other->work); 726 break; 727 } 728 /* fall-through */ 729 case __SK_DROP: 730 /* fall-through */ 731 default: 732 out_free: 733 kfree_skb(skb); 734 } 735 } 736 737 static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb) 738 { 739 struct sk_psock *psock = sk_psock_from_strp(strp); 740 struct bpf_prog *prog; 741 int ret = __SK_DROP; 742 743 rcu_read_lock(); 744 prog = READ_ONCE(psock->progs.skb_verdict); 745 if (likely(prog)) { 746 skb_orphan(skb); 747 tcp_skb_bpf_redirect_clear(skb); 748 ret = sk_psock_bpf_run(psock, prog, skb); 749 ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb)); 750 } 751 rcu_read_unlock(); 752 sk_psock_verdict_apply(psock, skb, ret); 753 } 754 755 static int sk_psock_strp_read_done(struct strparser *strp, int err) 756 { 757 return err; 758 } 759 760 static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb) 761 { 762 struct sk_psock *psock = sk_psock_from_strp(strp); 763 struct bpf_prog *prog; 764 int ret = skb->len; 765 766 rcu_read_lock(); 767 prog = READ_ONCE(psock->progs.skb_parser); 768 if (likely(prog)) 769 ret = sk_psock_bpf_run(psock, prog, skb); 770 rcu_read_unlock(); 771 return ret; 772 } 773 774 /* Called with socket lock held. */ 775 static void sk_psock_strp_data_ready(struct sock *sk) 776 { 777 struct sk_psock *psock; 778 779 rcu_read_lock(); 780 psock = sk_psock(sk); 781 if (likely(psock)) { 782 write_lock_bh(&sk->sk_callback_lock); 783 strp_data_ready(&psock->parser.strp); 784 write_unlock_bh(&sk->sk_callback_lock); 785 } 786 rcu_read_unlock(); 787 } 788 789 static void sk_psock_write_space(struct sock *sk) 790 { 791 struct sk_psock *psock; 792 void (*write_space)(struct sock *sk) = NULL; 793 794 rcu_read_lock(); 795 psock = sk_psock(sk); 796 if (likely(psock)) { 797 if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) 798 schedule_work(&psock->work); 799 write_space = psock->saved_write_space; 800 } 801 rcu_read_unlock(); 802 if (write_space) 803 write_space(sk); 804 } 805 806 int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock) 807 { 808 static const struct strp_callbacks cb = { 809 .rcv_msg = sk_psock_strp_read, 810 .read_sock_done = sk_psock_strp_read_done, 811 .parse_msg = sk_psock_strp_parse, 812 }; 813 814 psock->parser.enabled = false; 815 return strp_init(&psock->parser.strp, sk, &cb); 816 } 817 818 void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock) 819 { 820 struct sk_psock_parser *parser = &psock->parser; 821 822 if (parser->enabled) 823 return; 824 825 parser->saved_data_ready = sk->sk_data_ready; 826 sk->sk_data_ready = sk_psock_strp_data_ready; 827 sk->sk_write_space = sk_psock_write_space; 828 parser->enabled = true; 829 } 830 831 void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock) 832 { 833 struct sk_psock_parser *parser = &psock->parser; 834 835 if (!parser->enabled) 836 return; 837 838 sk->sk_data_ready = parser->saved_data_ready; 839 parser->saved_data_ready = NULL; 840 strp_stop(&parser->strp); 841 parser->enabled = false; 842 } 843