// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/skbuff.h>
#include <linux/scatterlist.h>

#include <net/sock.h>
#include <net/tcp.h>

static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
{
	if (msg->sg.end > msg->sg.start &&
	    elem_first_coalesce < msg->sg.end)
		return true;

	if (msg->sg.end < msg->sg.start &&
	    (elem_first_coalesce > msg->sg.start ||
	     elem_first_coalesce < msg->sg.end))
		return true;

	return false;
}

int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
		 int elem_first_coalesce)
{
	struct page_frag *pfrag = sk_page_frag(sk);
	int ret = 0;

	len -= msg->sg.size;
	while (len > 0) {
		struct scatterlist *sge;
		u32 orig_offset;
		int use, i;

		if (!sk_page_frag_refill(sk, pfrag))
			return -ENOMEM;

		orig_offset = pfrag->offset;
		use = min_t(int, len, pfrag->size - orig_offset);
		if (!sk_wmem_schedule(sk, use))
			return -ENOMEM;

		i = msg->sg.end;
		sk_msg_iter_var_prev(i);
		sge = &msg->sg.data[i];

		if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
		    sg_page(sge) == pfrag->page &&
		    sge->offset + sge->length == orig_offset) {
			sge->length += use;
		} else {
			if (sk_msg_full(msg)) {
				ret = -ENOSPC;
				break;
			}

			sge = &msg->sg.data[msg->sg.end];
			sg_unmark_end(sge);
			sg_set_page(sge, pfrag->page, use, orig_offset);
			get_page(pfrag->page);
			sk_msg_iter_next(msg, end);
		}

		sk_mem_charge(sk, use);
		msg->sg.size += use;
		pfrag->offset += use;
		len -= use;
	}

	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_alloc);
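
/* Illustrative caller sketch (hypothetical, not part of this file): sendmsg
 * paths typically record the old size, grow the ring with sk_msg_alloc(),
 * and trim back on failure, roughly in the style of tcp_bpf_sendmsg():
 *
 *	osize = msg->sg.size;
 *	err = sk_msg_alloc(sk, msg, osize + copy, msg->sg.end - 1);
 *	if (err)
 *		sk_msg_trim(sk, msg, osize);
 */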

int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
		 u32 off, u32 len)
{
	int i = src->sg.start;
	struct scatterlist *sge = sk_msg_elem(src, i);
	struct scatterlist *sgd = NULL;
	u32 sge_len, sge_off;

	while (off) {
		if (sge->length > off)
			break;
		off -= sge->length;
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && off)
			return -ENOSPC;
		sge = sk_msg_elem(src, i);
	}

	while (len) {
		sge_len = sge->length - off;
		if (sge_len > len)
			sge_len = len;

		if (dst->sg.end)
			sgd = sk_msg_elem(dst, dst->sg.end - 1);

		if (sgd &&
		    (sg_page(sge) == sg_page(sgd)) &&
		    (sg_virt(sge) + off == sg_virt(sgd) + sgd->length)) {
			sgd->length += sge_len;
			dst->sg.size += sge_len;
		} else if (!sk_msg_full(dst)) {
			sge_off = sge->offset + off;
			sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);
		} else {
			return -ENOSPC;
		}

		off = 0;
		len -= sge_len;
		sk_mem_charge(sk, sge_len);
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && len)
			return -ENOSPC;
		sge = sk_msg_elem(src, i);
	}

	return 0;
}
EXPORT_SYMBOL_GPL(sk_msg_clone);

void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
{
	int i = msg->sg.start;

	do {
		struct scatterlist *sge = sk_msg_elem(msg, i);

		if (bytes < sge->length) {
			sge->length -= bytes;
			sge->offset += bytes;
			sk_mem_uncharge(sk, bytes);
			break;
		}

		sk_mem_uncharge(sk, sge->length);
		bytes -= sge->length;
		sge->length = 0;
		sge->offset = 0;
		sk_msg_iter_var_next(i);
	} while (bytes && i != msg->sg.end);
	msg->sg.start = i;
}
EXPORT_SYMBOL_GPL(sk_msg_return_zero);

void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
{
	int i = msg->sg.start;

	do {
		struct scatterlist *sge = &msg->sg.data[i];
		int uncharge = (bytes < sge->length) ? bytes : sge->length;

		sk_mem_uncharge(sk, uncharge);
		bytes -= uncharge;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);
}
EXPORT_SYMBOL_GPL(sk_msg_return);

static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
			    bool charge)
{
	struct scatterlist *sge = sk_msg_elem(msg, i);
	u32 len = sge->length;

	if (charge)
		sk_mem_uncharge(sk, len);
	if (!msg->skb)
		put_page(sg_page(sge));
	memset(sge, 0, sizeof(*sge));
	return len;
}

static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
			 bool charge)
{
	struct scatterlist *sge = sk_msg_elem(msg, i);
	int freed = 0;

	while (msg->sg.size) {
		msg->sg.size -= sge->length;
		freed += sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, msg->sg.size);
		sge = sk_msg_elem(msg, i);
	}
	consume_skb(msg->skb);
	sk_msg_init(msg);
	return freed;
}

int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, false);
}
EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);

int sk_msg_free(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free);

static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
				  u32 bytes, bool charge)
{
	struct scatterlist *sge;
	u32 i = msg->sg.start;

	while (bytes) {
		sge = sk_msg_elem(msg, i);
		if (!sge->length)
			break;
		if (bytes < sge->length) {
			if (charge)
				sk_mem_uncharge(sk, bytes);
			sge->length -= bytes;
			sge->offset += bytes;
			msg->sg.size -= bytes;
			break;
		}

		msg->sg.size -= sge->length;
		bytes -= sge->length;
		sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, bytes);
	}
	msg->sg.start = i;
}

void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free_partial);

void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
				  u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, false);
}

void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
{
	int trim = msg->sg.size - len;
	u32 i = msg->sg.end;

	if (trim <= 0) {
		WARN_ON(trim < 0);
		return;
	}

	sk_msg_iter_var_prev(i);
	msg->sg.size = len;
	while (msg->sg.data[i].length &&
	       trim >= msg->sg.data[i].length) {
		trim -= msg->sg.data[i].length;
		sk_msg_free_elem(sk, msg, i, true);
		sk_msg_iter_var_prev(i);
		if (!trim)
			goto out;
	}

	msg->sg.data[i].length -= trim;
	sk_mem_uncharge(sk, trim);
out:
	/* If we trim data before the curr pointer, update copybreak and curr
	 * so that any future copy operations start at the new copy location.
	 * However, trimmed data that has not yet been used in a copy op
	 * does not require an update.
	 */
	if (msg->sg.curr >= i) {
		msg->sg.curr = i;
		msg->sg.copybreak = msg->sg.data[i].length;
	}
	sk_msg_iter_var_next(i);
	msg->sg.end = i;
}
EXPORT_SYMBOL_GPL(sk_msg_trim);
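
/* Hedged usage sketch (hypothetical caller; msg_iter comes from an assumed
 * struct msghdr *m): the revert comment at the end of
 * sk_msg_zerocopy_from_iter() below expects callers to pair a failed
 * zerocopy append with a trim back to the original size:
 *
 *	orig = msg->sg.size;
 *	err = sk_msg_zerocopy_from_iter(sk, &m->msg_iter, msg, bytes);
 *	if (err)
 *		sk_msg_trim(sk, msg, orig);
 */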

int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      struct sk_msg *msg, u32 bytes)
{
	int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
	const int to_max_pages = MAX_MSG_FRAGS;
	struct page *pages[MAX_MSG_FRAGS];
	ssize_t orig, copied, use, offset;

	orig = msg->sg.size;
	while (bytes > 0) {
		i = 0;
		maxpages = to_max_pages - num_elems;
		if (maxpages == 0) {
			ret = -EFAULT;
			goto out;
		}

		copied = iov_iter_get_pages(from, pages, bytes, maxpages,
					    &offset);
		if (copied <= 0) {
			ret = -EFAULT;
			goto out;
		}

		iov_iter_advance(from, copied);
		bytes -= copied;
		msg->sg.size += copied;

		while (copied) {
			use = min_t(int, copied, PAGE_SIZE - offset);
			sg_set_page(&msg->sg.data[msg->sg.end],
				    pages[i], use, offset);
			sg_unmark_end(&msg->sg.data[msg->sg.end]);
			sk_mem_charge(sk, use);

			offset = 0;
			copied -= use;
			sk_msg_iter_next(msg, end);
			num_elems++;
			i++;
		}
		/* When zerocopy is mixed with sk_msg_*copy* operations we
		 * may have a copybreak set; in that case clear it and prefer
		 * the zerocopy remainder when possible.
		 */
		msg->sg.copybreak = 0;
		msg->sg.curr = msg->sg.end;
	}
out:
	/* Revert iov_iter updates, msg will need to use 'trim' later if it
	 * also needs to be cleared.
	 */
	if (ret)
		iov_iter_revert(from, msg->sg.size - orig);
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);

int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     struct sk_msg *msg, u32 bytes)
{
	int ret = -ENOSPC, i = msg->sg.curr;
	struct scatterlist *sge;
	u32 copy, buf_size;
	void *to;

	do {
		sge = sk_msg_elem(msg, i);
		/* This is possible if a trim operation shrunk the buffer */
		if (msg->sg.copybreak >= sge->length) {
			msg->sg.copybreak = 0;
			sk_msg_iter_var_next(i);
			if (i == msg->sg.end)
				break;
			sge = sk_msg_elem(msg, i);
		}

		buf_size = sge->length - msg->sg.copybreak;
		copy = (buf_size > bytes) ? bytes : buf_size;
		to = sg_virt(sge) + msg->sg.copybreak;
		msg->sg.copybreak += copy;
		if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
			ret = copy_from_iter_nocache(to, copy, from);
		else
			ret = copy_from_iter(to, copy, from);
		if (ret != copy) {
			ret = -EFAULT;
			goto out;
		}
		bytes -= copy;
		if (!bytes)
			break;
		msg->sg.copybreak = 0;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);
out:
	msg->sg.curr = i;
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);
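
/* Resume semantics (illustrative, hypothetical sizes): msg->sg.curr and
 * msg->sg.copybreak record where the previous memcopy stopped, so successive
 * calls append rather than overwrite. E.g. with one 4096-byte element:
 *
 *	sk_msg_memcopy_from_iter(sk, from, msg, 100);	copybreak == 100
 *	sk_msg_memcopy_from_iter(sk, from, msg, 50);	writes at offset 100
 */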

static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
{
	struct sock *sk = psock->sk;
	int copied = 0, num_sge;
	struct sk_msg *msg;

	msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
	if (unlikely(!msg))
		return -EAGAIN;
	if (!sk_rmem_schedule(sk, skb, skb->len)) {
		kfree(msg);
		return -EAGAIN;
	}

	sk_msg_init(msg);
	num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
	if (unlikely(num_sge < 0)) {
		kfree(msg);
		return num_sge;
	}

	sk_mem_charge(sk, skb->len);
	copied = skb->len;
	msg->sg.start = 0;
	msg->sg.size = copied;
	msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;
	msg->skb = skb;

	sk_psock_queue_msg(psock, msg);
	sk_psock_data_ready(sk, psock);
	return copied;
}

static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
			       u32 off, u32 len, bool ingress)
{
	if (ingress)
		return sk_psock_skb_ingress(psock, skb);
	else
		return skb_send_sock_locked(psock->sk, skb, off, len);
}

static void sk_psock_backlog(struct work_struct *work)
{
	struct sk_psock *psock = container_of(work, struct sk_psock, work);
	struct sk_psock_work_state *state = &psock->work_state;
	struct sk_buff *skb;
	bool ingress;
	u32 len, off;
	int ret;

	/* Lock sock to avoid losing sk_socket during loop. */
	lock_sock(psock->sk);
	if (state->skb) {
		skb = state->skb;
		len = state->len;
		off = state->off;
		state->skb = NULL;
		goto start;
	}

	while ((skb = skb_dequeue(&psock->ingress_skb))) {
		len = skb->len;
		off = 0;
start:
		ingress = tcp_skb_bpf_ingress(skb);
		do {
			ret = -EIO;
			if (likely(psock->sk->sk_socket))
				ret = sk_psock_handle_skb(psock, skb, off,
							  len, ingress);
			if (ret <= 0) {
				if (ret == -EAGAIN) {
					state->skb = skb;
					state->len = len;
					state->off = off;
					goto end;
				}
				/* Hard errors break pipe and stop xmit. */
				sk_psock_report_error(psock, ret ? -ret : EPIPE);
				sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
				kfree_skb(skb);
				goto end;
			}
			off += ret;
			len -= ret;
		} while (len);

		if (!ingress)
			kfree_skb(skb);
	}
end:
	release_sock(psock->sk);
}

struct sk_psock *sk_psock_init(struct sock *sk, int node)
{
	struct sk_psock *psock = kzalloc_node(sizeof(*psock),
					      GFP_ATOMIC | __GFP_NOWARN,
					      node);
	if (!psock)
		return NULL;

	psock->sk = sk;
	psock->eval = __SK_NONE;

	INIT_LIST_HEAD(&psock->link);
	spin_lock_init(&psock->link_lock);

	INIT_WORK(&psock->work, sk_psock_backlog);
	INIT_LIST_HEAD(&psock->ingress_msg);
	skb_queue_head_init(&psock->ingress_skb);

	sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
	refcount_set(&psock->refcnt, 1);

	rcu_assign_sk_user_data(sk, psock);
	sock_hold(sk);

	return psock;
}
EXPORT_SYMBOL_GPL(sk_psock_init);
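
/* Lifecycle sketch (illustrative; the node argument and the get/put helpers
 * from linux/skmsg.h are assumptions about how callers drive this): a psock
 * is created once per socket, used under a reference, and torn down via RCU:
 *
 *	psock = sk_psock_init(sk, NUMA_NO_NODE);
 *	...
 *	psock = sk_psock_get(sk);	use under reference
 *	sk_psock_put(sk, psock);
 *	...
 *	sk_psock_drop(sk, psock);	final detach, frees via call_rcu()
 */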

struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
{
	struct sk_psock_link *link;

	spin_lock_bh(&psock->link_lock);
	link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
					list);
	if (link)
		list_del(&link->list);
	spin_unlock_bh(&psock->link_lock);
	return link;
}

void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
{
	struct sk_msg *msg, *tmp;

	list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
		list_del(&msg->list);
		sk_msg_free(psock->sk, msg);
		kfree(msg);
	}
}

static void sk_psock_zap_ingress(struct sk_psock *psock)
{
	__skb_queue_purge(&psock->ingress_skb);
	__sk_psock_purge_ingress_msg(psock);
}

static void sk_psock_link_destroy(struct sk_psock *psock)
{
	struct sk_psock_link *link, *tmp;

	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		list_del(&link->list);
		sk_psock_free_link(link);
	}
}

static void sk_psock_destroy_deferred(struct work_struct *gc)
{
	struct sk_psock *psock = container_of(gc, struct sk_psock, gc);

	/* No sk_callback_lock since already detached. */

	/* Parser has been stopped */
	if (psock->progs.skb_parser)
		strp_done(&psock->parser.strp);

	cancel_work_sync(&psock->work);

	psock_progs_drop(&psock->progs);

	sk_psock_link_destroy(psock);
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);

	if (psock->sk_redir)
		sock_put(psock->sk_redir);
	sock_put(psock->sk);
	kfree(psock);
}

void sk_psock_destroy(struct rcu_head *rcu)
{
	struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);

	INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
	schedule_work(&psock->gc);
}
EXPORT_SYMBOL_GPL(sk_psock_destroy);

void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
{
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);

	write_lock_bh(&sk->sk_callback_lock);
	sk_psock_restore_proto(sk, psock);
	rcu_assign_sk_user_data(sk, NULL);
	if (psock->progs.skb_parser)
		sk_psock_stop_strp(sk, psock);
	write_unlock_bh(&sk->sk_callback_lock);
	sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);

	call_rcu(&psock->rcu, sk_psock_destroy);
}
EXPORT_SYMBOL_GPL(sk_psock_drop);

static int sk_psock_map_verd(int verdict, bool redir)
{
	switch (verdict) {
	case SK_PASS:
		return redir ? __SK_REDIRECT : __SK_PASS;
	case SK_DROP:
	default:
		break;
	}

	return __SK_DROP;
}

int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
			 struct sk_msg *msg)
{
	struct bpf_prog *prog;
	int ret;

	preempt_disable();
	rcu_read_lock();
	prog = READ_ONCE(psock->progs.msg_parser);
	if (unlikely(!prog)) {
		ret = __SK_PASS;
		goto out;
	}

	sk_msg_compute_data_pointers(msg);
	msg->sk = sk;
	ret = BPF_PROG_RUN(prog, msg);
	ret = sk_psock_map_verd(ret, msg->sk_redir);
	psock->apply_bytes = msg->apply_bytes;
	if (ret == __SK_REDIRECT) {
		if (psock->sk_redir)
			sock_put(psock->sk_redir);
		psock->sk_redir = msg->sk_redir;
		if (!psock->sk_redir) {
			ret = __SK_DROP;
			goto out;
		}
		sock_hold(psock->sk_redir);
	}
out:
	rcu_read_unlock();
	preempt_enable();
	return ret;
}
EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);
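
/* Consumer sketch (illustrative, in the style of a tcp_bpf send path, not a
 * verbatim copy of one): the verdict returned above is typically dispatched
 * with a switch:
 *
 *	switch (sk_psock_msg_verdict(sk, psock, msg)) {
 *	case __SK_PASS:		transmit on this socket
 *	case __SK_REDIRECT:	send via psock->sk_redir
 *	case __SK_DROP:
 *	default:		free the msg and return an error
 *	}
 */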

static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
			    struct sk_buff *skb)
{
	int ret;

	skb->sk = psock->sk;
	bpf_compute_data_end_sk_skb(skb);
	preempt_disable();
	ret = BPF_PROG_RUN(prog, skb);
	preempt_enable();
	/* strparser clones the skb before handing it to an upper layer,
	 * meaning skb_orphan has been called. We NULL sk on the way out
	 * to ensure we don't trigger a BUG_ON() in skb/sk operations
	 * later and because we are not charging the memory of this skb
	 * to any socket yet.
	 */
	skb->sk = NULL;
	return ret;
}

static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
{
	struct sk_psock_parser *parser;

	parser = container_of(strp, struct sk_psock_parser, strp);
	return container_of(parser, struct sk_psock, parser);
}

static void sk_psock_verdict_apply(struct sk_psock *psock,
				   struct sk_buff *skb, int verdict)
{
	struct sk_psock *psock_other;
	struct sock *sk_other;
	bool ingress;

	switch (verdict) {
	case __SK_PASS:
		sk_other = psock->sk;
		if (sock_flag(sk_other, SOCK_DEAD) ||
		    !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
			goto out_free;
		}
		if (atomic_read(&sk_other->sk_rmem_alloc) <=
		    sk_other->sk_rcvbuf) {
			struct tcp_skb_cb *tcp = TCP_SKB_CB(skb);

			tcp->bpf.flags |= BPF_F_INGRESS;
			skb_queue_tail(&psock->ingress_skb, skb);
			schedule_work(&psock->work);
			break;
		}
		goto out_free;
	case __SK_REDIRECT:
		sk_other = tcp_skb_bpf_redirect_fetch(skb);
		if (unlikely(!sk_other))
			goto out_free;
		psock_other = sk_psock(sk_other);
		if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
		    !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED))
			goto out_free;
		ingress = tcp_skb_bpf_ingress(skb);
		if ((!ingress && sock_writeable(sk_other)) ||
		    (ingress &&
		     atomic_read(&sk_other->sk_rmem_alloc) <=
		     sk_other->sk_rcvbuf)) {
			if (!ingress)
				skb_set_owner_w(skb, sk_other);
			skb_queue_tail(&psock_other->ingress_skb, skb);
			schedule_work(&psock_other->work);
			break;
		}
		/* fall-through */
	case __SK_DROP:
		/* fall-through */
	default:
out_free:
		kfree_skb(skb);
	}
}

static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
{
	struct sk_psock *psock = sk_psock_from_strp(strp);
	struct bpf_prog *prog;
	int ret = __SK_DROP;

	rcu_read_lock();
	prog = READ_ONCE(psock->progs.skb_verdict);
	if (likely(prog)) {
		skb_orphan(skb);
		tcp_skb_bpf_redirect_clear(skb);
		ret = sk_psock_bpf_run(psock, prog, skb);
		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
	}
	rcu_read_unlock();
	sk_psock_verdict_apply(psock, skb, ret);
}

static int sk_psock_strp_read_done(struct strparser *strp, int err)
{
	return err;
}

static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
{
	struct sk_psock *psock = sk_psock_from_strp(strp);
	struct bpf_prog *prog;
	int ret = skb->len;

	rcu_read_lock();
	prog = READ_ONCE(psock->progs.skb_parser);
	if (likely(prog))
		ret = sk_psock_bpf_run(psock, prog, skb);
	rcu_read_unlock();
	return ret;
}
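
/* The three strparser hooks above map onto the strp_callbacks wired up in
 * sk_psock_init_strp() below: sk_psock_strp_parse() lets the BPF skb_parser
 * program return the framed message length, sk_psock_strp_read() runs the
 * skb_verdict program on each complete message, and
 * sk_psock_strp_read_done() simply propagates the error code.
 */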

/* Called with socket lock held. */
static void sk_psock_strp_data_ready(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock)) {
		write_lock_bh(&sk->sk_callback_lock);
		strp_data_ready(&psock->parser.strp);
		write_unlock_bh(&sk->sk_callback_lock);
	}
	rcu_read_unlock();
}

static void sk_psock_write_space(struct sock *sk)
{
	struct sk_psock *psock;
	void (*write_space)(struct sock *sk) = NULL;

	rcu_read_lock();
	psock = sk_psock(sk);
	/* The psock may already have been detached from sk, so only
	 * dereference it under the read lock and with a NULL check.
	 */
	if (likely(psock)) {
		if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))
			schedule_work(&psock->work);
		write_space = psock->saved_write_space;
	}
	rcu_read_unlock();
	if (write_space)
		write_space(sk);
}

int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
{
	static const struct strp_callbacks cb = {
		.rcv_msg	= sk_psock_strp_read,
		.read_sock_done	= sk_psock_strp_read_done,
		.parse_msg	= sk_psock_strp_parse,
	};

	psock->parser.enabled = false;
	return strp_init(&psock->parser.strp, sk, &cb);
}

void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (parser->enabled)
		return;

	parser->saved_data_ready = sk->sk_data_ready;
	sk->sk_data_ready = sk_psock_strp_data_ready;
	sk->sk_write_space = sk_psock_write_space;
	parser->enabled = true;
}

void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (!parser->enabled)
		return;

	sk->sk_data_ready = parser->saved_data_ready;
	parser->saved_data_ready = NULL;
	strp_stop(&parser->strp);
	parser->enabled = false;
}
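
/* Lifecycle sketch (illustrative): the strparser half of a psock is driven as
 *
 *	sk_psock_init_strp(sk, psock);		allocate strparser state
 *	sk_psock_start_strp(sk, psock);		hook sk_data_ready
 *	...
 *	sk_psock_stop_strp(sk, psock);		restore the saved callbacks
 *
 * with strp_done() running later from sk_psock_destroy_deferred() once the
 * psock reference is finally released.
 */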