1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */ 3 4 #include <linux/skmsg.h> 5 #include <linux/skbuff.h> 6 #include <linux/scatterlist.h> 7 8 #include <net/sock.h> 9 #include <net/tcp.h> 10 11 static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce) 12 { 13 if (msg->sg.end > msg->sg.start && 14 elem_first_coalesce < msg->sg.end) 15 return true; 16 17 if (msg->sg.end < msg->sg.start && 18 (elem_first_coalesce > msg->sg.start || 19 elem_first_coalesce < msg->sg.end)) 20 return true; 21 22 return false; 23 } 24 25 int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len, 26 int elem_first_coalesce) 27 { 28 struct page_frag *pfrag = sk_page_frag(sk); 29 int ret = 0; 30 31 len -= msg->sg.size; 32 while (len > 0) { 33 struct scatterlist *sge; 34 u32 orig_offset; 35 int use, i; 36 37 if (!sk_page_frag_refill(sk, pfrag)) 38 return -ENOMEM; 39 40 orig_offset = pfrag->offset; 41 use = min_t(int, len, pfrag->size - orig_offset); 42 if (!sk_wmem_schedule(sk, use)) 43 return -ENOMEM; 44 45 i = msg->sg.end; 46 sk_msg_iter_var_prev(i); 47 sge = &msg->sg.data[i]; 48 49 if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) && 50 sg_page(sge) == pfrag->page && 51 sge->offset + sge->length == orig_offset) { 52 sge->length += use; 53 } else { 54 if (sk_msg_full(msg)) { 55 ret = -ENOSPC; 56 break; 57 } 58 59 sge = &msg->sg.data[msg->sg.end]; 60 sg_unmark_end(sge); 61 sg_set_page(sge, pfrag->page, use, orig_offset); 62 get_page(pfrag->page); 63 sk_msg_iter_next(msg, end); 64 } 65 66 sk_mem_charge(sk, use); 67 msg->sg.size += use; 68 pfrag->offset += use; 69 len -= use; 70 } 71 72 return ret; 73 } 74 EXPORT_SYMBOL_GPL(sk_msg_alloc); 75 76 int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src, 77 u32 off, u32 len) 78 { 79 int i = src->sg.start; 80 struct scatterlist *sge = sk_msg_elem(src, i); 81 struct scatterlist *sgd = NULL; 82 u32 sge_len, sge_off; 83 84 while (off) { 85 if (sge->length > off) 86 break; 87 off -= sge->length; 88 sk_msg_iter_var_next(i); 89 if (i == src->sg.end && off) 90 return -ENOSPC; 91 sge = sk_msg_elem(src, i); 92 } 93 94 while (len) { 95 sge_len = sge->length - off; 96 if (sge_len > len) 97 sge_len = len; 98 99 if (dst->sg.end) 100 sgd = sk_msg_elem(dst, dst->sg.end - 1); 101 102 if (sgd && 103 (sg_page(sge) == sg_page(sgd)) && 104 (sg_virt(sge) + off == sg_virt(sgd) + sgd->length)) { 105 sgd->length += sge_len; 106 dst->sg.size += sge_len; 107 } else if (!sk_msg_full(dst)) { 108 sge_off = sge->offset + off; 109 sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off); 110 } else { 111 return -ENOSPC; 112 } 113 114 off = 0; 115 len -= sge_len; 116 sk_mem_charge(sk, sge_len); 117 sk_msg_iter_var_next(i); 118 if (i == src->sg.end && len) 119 return -ENOSPC; 120 sge = sk_msg_elem(src, i); 121 } 122 123 return 0; 124 } 125 EXPORT_SYMBOL_GPL(sk_msg_clone); 126 127 void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes) 128 { 129 int i = msg->sg.start; 130 131 do { 132 struct scatterlist *sge = sk_msg_elem(msg, i); 133 134 if (bytes < sge->length) { 135 sge->length -= bytes; 136 sge->offset += bytes; 137 sk_mem_uncharge(sk, bytes); 138 break; 139 } 140 141 sk_mem_uncharge(sk, sge->length); 142 bytes -= sge->length; 143 sge->length = 0; 144 sge->offset = 0; 145 sk_msg_iter_var_next(i); 146 } while (bytes && i != msg->sg.end); 147 msg->sg.start = i; 148 } 149 EXPORT_SYMBOL_GPL(sk_msg_return_zero); 150 151 void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes) 152 { 153 int i = msg->sg.start; 154 155 do { 156 struct scatterlist *sge = &msg->sg.data[i]; 157 int uncharge = (bytes < sge->length) ? bytes : sge->length; 158 159 sk_mem_uncharge(sk, uncharge); 160 bytes -= uncharge; 161 sk_msg_iter_var_next(i); 162 } while (i != msg->sg.end); 163 } 164 EXPORT_SYMBOL_GPL(sk_msg_return); 165 166 static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i, 167 bool charge) 168 { 169 struct scatterlist *sge = sk_msg_elem(msg, i); 170 u32 len = sge->length; 171 172 if (charge) 173 sk_mem_uncharge(sk, len); 174 if (!msg->skb) 175 put_page(sg_page(sge)); 176 memset(sge, 0, sizeof(*sge)); 177 return len; 178 } 179 180 static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i, 181 bool charge) 182 { 183 struct scatterlist *sge = sk_msg_elem(msg, i); 184 int freed = 0; 185 186 while (msg->sg.size) { 187 msg->sg.size -= sge->length; 188 freed += sk_msg_free_elem(sk, msg, i, charge); 189 sk_msg_iter_var_next(i); 190 sk_msg_check_to_free(msg, i, msg->sg.size); 191 sge = sk_msg_elem(msg, i); 192 } 193 if (msg->skb) 194 consume_skb(msg->skb); 195 sk_msg_init(msg); 196 return freed; 197 } 198 199 int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg) 200 { 201 return __sk_msg_free(sk, msg, msg->sg.start, false); 202 } 203 EXPORT_SYMBOL_GPL(sk_msg_free_nocharge); 204 205 int sk_msg_free(struct sock *sk, struct sk_msg *msg) 206 { 207 return __sk_msg_free(sk, msg, msg->sg.start, true); 208 } 209 EXPORT_SYMBOL_GPL(sk_msg_free); 210 211 static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, 212 u32 bytes, bool charge) 213 { 214 struct scatterlist *sge; 215 u32 i = msg->sg.start; 216 217 while (bytes) { 218 sge = sk_msg_elem(msg, i); 219 if (!sge->length) 220 break; 221 if (bytes < sge->length) { 222 if (charge) 223 sk_mem_uncharge(sk, bytes); 224 sge->length -= bytes; 225 sge->offset += bytes; 226 msg->sg.size -= bytes; 227 break; 228 } 229 230 msg->sg.size -= sge->length; 231 bytes -= sge->length; 232 sk_msg_free_elem(sk, msg, i, charge); 233 sk_msg_iter_var_next(i); 234 sk_msg_check_to_free(msg, i, bytes); 235 } 236 msg->sg.start = i; 237 } 238 239 void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes) 240 { 241 __sk_msg_free_partial(sk, msg, bytes, true); 242 } 243 EXPORT_SYMBOL_GPL(sk_msg_free_partial); 244 245 void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg, 246 u32 bytes) 247 { 248 __sk_msg_free_partial(sk, msg, bytes, false); 249 } 250 251 void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len) 252 { 253 int trim = msg->sg.size - len; 254 u32 i = msg->sg.end; 255 256 if (trim <= 0) { 257 WARN_ON(trim < 0); 258 return; 259 } 260 261 sk_msg_iter_var_prev(i); 262 msg->sg.size = len; 263 while (msg->sg.data[i].length && 264 trim >= msg->sg.data[i].length) { 265 trim -= msg->sg.data[i].length; 266 sk_msg_free_elem(sk, msg, i, true); 267 sk_msg_iter_var_prev(i); 268 if (!trim) 269 goto out; 270 } 271 272 msg->sg.data[i].length -= trim; 273 sk_mem_uncharge(sk, trim); 274 out: 275 /* If we trim data before curr pointer update copybreak and current 276 * so that any future copy operations start at new copy location. 277 * However trimed data that has not yet been used in a copy op 278 * does not require an update. 279 */ 280 if (msg->sg.curr >= i) { 281 msg->sg.curr = i; 282 msg->sg.copybreak = msg->sg.data[i].length; 283 } 284 sk_msg_iter_var_next(i); 285 msg->sg.end = i; 286 } 287 EXPORT_SYMBOL_GPL(sk_msg_trim); 288 289 int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from, 290 struct sk_msg *msg, u32 bytes) 291 { 292 int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg); 293 const int to_max_pages = MAX_MSG_FRAGS; 294 struct page *pages[MAX_MSG_FRAGS]; 295 ssize_t orig, copied, use, offset; 296 297 orig = msg->sg.size; 298 while (bytes > 0) { 299 i = 0; 300 maxpages = to_max_pages - num_elems; 301 if (maxpages == 0) { 302 ret = -EFAULT; 303 goto out; 304 } 305 306 copied = iov_iter_get_pages(from, pages, bytes, maxpages, 307 &offset); 308 if (copied <= 0) { 309 ret = -EFAULT; 310 goto out; 311 } 312 313 iov_iter_advance(from, copied); 314 bytes -= copied; 315 msg->sg.size += copied; 316 317 while (copied) { 318 use = min_t(int, copied, PAGE_SIZE - offset); 319 sg_set_page(&msg->sg.data[msg->sg.end], 320 pages[i], use, offset); 321 sg_unmark_end(&msg->sg.data[msg->sg.end]); 322 sk_mem_charge(sk, use); 323 324 offset = 0; 325 copied -= use; 326 sk_msg_iter_next(msg, end); 327 num_elems++; 328 i++; 329 } 330 /* When zerocopy is mixed with sk_msg_*copy* operations we 331 * may have a copybreak set in this case clear and prefer 332 * zerocopy remainder when possible. 333 */ 334 msg->sg.copybreak = 0; 335 msg->sg.curr = msg->sg.end; 336 } 337 out: 338 /* Revert iov_iter updates, msg will need to use 'trim' later if it 339 * also needs to be cleared. 340 */ 341 if (ret) 342 iov_iter_revert(from, msg->sg.size - orig); 343 return ret; 344 } 345 EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter); 346 347 int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from, 348 struct sk_msg *msg, u32 bytes) 349 { 350 int ret = -ENOSPC, i = msg->sg.curr; 351 struct scatterlist *sge; 352 u32 copy, buf_size; 353 void *to; 354 355 do { 356 sge = sk_msg_elem(msg, i); 357 /* This is possible if a trim operation shrunk the buffer */ 358 if (msg->sg.copybreak >= sge->length) { 359 msg->sg.copybreak = 0; 360 sk_msg_iter_var_next(i); 361 if (i == msg->sg.end) 362 break; 363 sge = sk_msg_elem(msg, i); 364 } 365 366 buf_size = sge->length - msg->sg.copybreak; 367 copy = (buf_size > bytes) ? bytes : buf_size; 368 to = sg_virt(sge) + msg->sg.copybreak; 369 msg->sg.copybreak += copy; 370 if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY) 371 ret = copy_from_iter_nocache(to, copy, from); 372 else 373 ret = copy_from_iter(to, copy, from); 374 if (ret != copy) { 375 ret = -EFAULT; 376 goto out; 377 } 378 bytes -= copy; 379 if (!bytes) 380 break; 381 msg->sg.copybreak = 0; 382 sk_msg_iter_var_next(i); 383 } while (i != msg->sg.end); 384 out: 385 msg->sg.curr = i; 386 return ret; 387 } 388 EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter); 389 390 static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb) 391 { 392 struct sock *sk = psock->sk; 393 int copied = 0, num_sge; 394 struct sk_msg *msg; 395 396 msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC); 397 if (unlikely(!msg)) 398 return -EAGAIN; 399 if (!sk_rmem_schedule(sk, skb, skb->len)) { 400 kfree(msg); 401 return -EAGAIN; 402 } 403 404 sk_msg_init(msg); 405 num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len); 406 if (unlikely(num_sge < 0)) { 407 kfree(msg); 408 return num_sge; 409 } 410 411 sk_mem_charge(sk, skb->len); 412 copied = skb->len; 413 msg->sg.start = 0; 414 msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge; 415 msg->skb = skb; 416 417 sk_psock_queue_msg(psock, msg); 418 sk_psock_data_ready(sk, psock); 419 return copied; 420 } 421 422 static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb, 423 u32 off, u32 len, bool ingress) 424 { 425 if (ingress) 426 return sk_psock_skb_ingress(psock, skb); 427 else 428 return skb_send_sock_locked(psock->sk, skb, off, len); 429 } 430 431 static void sk_psock_backlog(struct work_struct *work) 432 { 433 struct sk_psock *psock = container_of(work, struct sk_psock, work); 434 struct sk_psock_work_state *state = &psock->work_state; 435 struct sk_buff *skb; 436 bool ingress; 437 u32 len, off; 438 int ret; 439 440 /* Lock sock to avoid losing sk_socket during loop. */ 441 lock_sock(psock->sk); 442 if (state->skb) { 443 skb = state->skb; 444 len = state->len; 445 off = state->off; 446 state->skb = NULL; 447 goto start; 448 } 449 450 while ((skb = skb_dequeue(&psock->ingress_skb))) { 451 len = skb->len; 452 off = 0; 453 start: 454 ingress = tcp_skb_bpf_ingress(skb); 455 do { 456 ret = -EIO; 457 if (likely(psock->sk->sk_socket)) 458 ret = sk_psock_handle_skb(psock, skb, off, 459 len, ingress); 460 if (ret <= 0) { 461 if (ret == -EAGAIN) { 462 state->skb = skb; 463 state->len = len; 464 state->off = off; 465 goto end; 466 } 467 /* Hard errors break pipe and stop xmit. */ 468 sk_psock_report_error(psock, ret ? -ret : EPIPE); 469 sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED); 470 kfree_skb(skb); 471 goto end; 472 } 473 off += ret; 474 len -= ret; 475 } while (len); 476 477 if (!ingress) 478 kfree_skb(skb); 479 } 480 end: 481 release_sock(psock->sk); 482 } 483 484 struct sk_psock *sk_psock_init(struct sock *sk, int node) 485 { 486 struct sk_psock *psock = kzalloc_node(sizeof(*psock), 487 GFP_ATOMIC | __GFP_NOWARN, 488 node); 489 if (!psock) 490 return NULL; 491 492 psock->sk = sk; 493 psock->eval = __SK_NONE; 494 495 INIT_LIST_HEAD(&psock->link); 496 spin_lock_init(&psock->link_lock); 497 498 INIT_WORK(&psock->work, sk_psock_backlog); 499 INIT_LIST_HEAD(&psock->ingress_msg); 500 skb_queue_head_init(&psock->ingress_skb); 501 502 sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED); 503 refcount_set(&psock->refcnt, 1); 504 505 rcu_assign_sk_user_data(sk, psock); 506 sock_hold(sk); 507 508 return psock; 509 } 510 EXPORT_SYMBOL_GPL(sk_psock_init); 511 512 struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock) 513 { 514 struct sk_psock_link *link; 515 516 spin_lock_bh(&psock->link_lock); 517 link = list_first_entry_or_null(&psock->link, struct sk_psock_link, 518 list); 519 if (link) 520 list_del(&link->list); 521 spin_unlock_bh(&psock->link_lock); 522 return link; 523 } 524 525 void __sk_psock_purge_ingress_msg(struct sk_psock *psock) 526 { 527 struct sk_msg *msg, *tmp; 528 529 list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) { 530 list_del(&msg->list); 531 sk_msg_free(psock->sk, msg); 532 kfree(msg); 533 } 534 } 535 536 static void sk_psock_zap_ingress(struct sk_psock *psock) 537 { 538 __skb_queue_purge(&psock->ingress_skb); 539 __sk_psock_purge_ingress_msg(psock); 540 } 541 542 static void sk_psock_link_destroy(struct sk_psock *psock) 543 { 544 struct sk_psock_link *link, *tmp; 545 546 list_for_each_entry_safe(link, tmp, &psock->link, list) { 547 list_del(&link->list); 548 sk_psock_free_link(link); 549 } 550 } 551 552 static void sk_psock_destroy_deferred(struct work_struct *gc) 553 { 554 struct sk_psock *psock = container_of(gc, struct sk_psock, gc); 555 556 /* No sk_callback_lock since already detached. */ 557 strp_stop(&psock->parser.strp); 558 strp_done(&psock->parser.strp); 559 560 cancel_work_sync(&psock->work); 561 562 psock_progs_drop(&psock->progs); 563 564 sk_psock_link_destroy(psock); 565 sk_psock_cork_free(psock); 566 sk_psock_zap_ingress(psock); 567 568 if (psock->sk_redir) 569 sock_put(psock->sk_redir); 570 sock_put(psock->sk); 571 kfree(psock); 572 } 573 574 void sk_psock_destroy(struct rcu_head *rcu) 575 { 576 struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu); 577 578 INIT_WORK(&psock->gc, sk_psock_destroy_deferred); 579 schedule_work(&psock->gc); 580 } 581 EXPORT_SYMBOL_GPL(sk_psock_destroy); 582 583 void sk_psock_drop(struct sock *sk, struct sk_psock *psock) 584 { 585 rcu_assign_sk_user_data(sk, NULL); 586 sk_psock_cork_free(psock); 587 sk_psock_zap_ingress(psock); 588 sk_psock_restore_proto(sk, psock); 589 590 write_lock_bh(&sk->sk_callback_lock); 591 if (psock->progs.skb_parser) 592 sk_psock_stop_strp(sk, psock); 593 write_unlock_bh(&sk->sk_callback_lock); 594 sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED); 595 596 call_rcu(&psock->rcu, sk_psock_destroy); 597 } 598 EXPORT_SYMBOL_GPL(sk_psock_drop); 599 600 static int sk_psock_map_verd(int verdict, bool redir) 601 { 602 switch (verdict) { 603 case SK_PASS: 604 return redir ? __SK_REDIRECT : __SK_PASS; 605 case SK_DROP: 606 default: 607 break; 608 } 609 610 return __SK_DROP; 611 } 612 613 int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock, 614 struct sk_msg *msg) 615 { 616 struct bpf_prog *prog; 617 int ret; 618 619 preempt_disable(); 620 rcu_read_lock(); 621 prog = READ_ONCE(psock->progs.msg_parser); 622 if (unlikely(!prog)) { 623 ret = __SK_PASS; 624 goto out; 625 } 626 627 sk_msg_compute_data_pointers(msg); 628 msg->sk = sk; 629 ret = BPF_PROG_RUN(prog, msg); 630 ret = sk_psock_map_verd(ret, msg->sk_redir); 631 psock->apply_bytes = msg->apply_bytes; 632 if (ret == __SK_REDIRECT) { 633 if (psock->sk_redir) 634 sock_put(psock->sk_redir); 635 psock->sk_redir = msg->sk_redir; 636 if (!psock->sk_redir) { 637 ret = __SK_DROP; 638 goto out; 639 } 640 sock_hold(psock->sk_redir); 641 } 642 out: 643 rcu_read_unlock(); 644 preempt_enable(); 645 return ret; 646 } 647 EXPORT_SYMBOL_GPL(sk_psock_msg_verdict); 648 649 static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog, 650 struct sk_buff *skb) 651 { 652 int ret; 653 654 skb->sk = psock->sk; 655 bpf_compute_data_end_sk_skb(skb); 656 preempt_disable(); 657 ret = BPF_PROG_RUN(prog, skb); 658 preempt_enable(); 659 /* strparser clones the skb before handing it to a upper layer, 660 * meaning skb_orphan has been called. We NULL sk on the way out 661 * to ensure we don't trigger a BUG_ON() in skb/sk operations 662 * later and because we are not charging the memory of this skb 663 * to any socket yet. 664 */ 665 skb->sk = NULL; 666 return ret; 667 } 668 669 static struct sk_psock *sk_psock_from_strp(struct strparser *strp) 670 { 671 struct sk_psock_parser *parser; 672 673 parser = container_of(strp, struct sk_psock_parser, strp); 674 return container_of(parser, struct sk_psock, parser); 675 } 676 677 static void sk_psock_verdict_apply(struct sk_psock *psock, 678 struct sk_buff *skb, int verdict) 679 { 680 struct sk_psock *psock_other; 681 struct sock *sk_other; 682 bool ingress; 683 684 switch (verdict) { 685 case __SK_PASS: 686 sk_other = psock->sk; 687 if (sock_flag(sk_other, SOCK_DEAD) || 688 !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) { 689 goto out_free; 690 } 691 if (atomic_read(&sk_other->sk_rmem_alloc) <= 692 sk_other->sk_rcvbuf) { 693 struct tcp_skb_cb *tcp = TCP_SKB_CB(skb); 694 695 tcp->bpf.flags |= BPF_F_INGRESS; 696 skb_queue_tail(&psock->ingress_skb, skb); 697 schedule_work(&psock->work); 698 break; 699 } 700 goto out_free; 701 case __SK_REDIRECT: 702 sk_other = tcp_skb_bpf_redirect_fetch(skb); 703 if (unlikely(!sk_other)) 704 goto out_free; 705 psock_other = sk_psock(sk_other); 706 if (!psock_other || sock_flag(sk_other, SOCK_DEAD) || 707 !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) 708 goto out_free; 709 ingress = tcp_skb_bpf_ingress(skb); 710 if ((!ingress && sock_writeable(sk_other)) || 711 (ingress && 712 atomic_read(&sk_other->sk_rmem_alloc) <= 713 sk_other->sk_rcvbuf)) { 714 if (!ingress) 715 skb_set_owner_w(skb, sk_other); 716 skb_queue_tail(&psock_other->ingress_skb, skb); 717 schedule_work(&psock_other->work); 718 break; 719 } 720 /* fall-through */ 721 case __SK_DROP: 722 /* fall-through */ 723 default: 724 out_free: 725 kfree_skb(skb); 726 } 727 } 728 729 static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb) 730 { 731 struct sk_psock *psock = sk_psock_from_strp(strp); 732 struct bpf_prog *prog; 733 int ret = __SK_DROP; 734 735 rcu_read_lock(); 736 prog = READ_ONCE(psock->progs.skb_verdict); 737 if (likely(prog)) { 738 skb_orphan(skb); 739 tcp_skb_bpf_redirect_clear(skb); 740 ret = sk_psock_bpf_run(psock, prog, skb); 741 ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb)); 742 } 743 rcu_read_unlock(); 744 sk_psock_verdict_apply(psock, skb, ret); 745 } 746 747 static int sk_psock_strp_read_done(struct strparser *strp, int err) 748 { 749 return err; 750 } 751 752 static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb) 753 { 754 struct sk_psock *psock = sk_psock_from_strp(strp); 755 struct bpf_prog *prog; 756 int ret = skb->len; 757 758 rcu_read_lock(); 759 prog = READ_ONCE(psock->progs.skb_parser); 760 if (likely(prog)) 761 ret = sk_psock_bpf_run(psock, prog, skb); 762 rcu_read_unlock(); 763 return ret; 764 } 765 766 /* Called with socket lock held. */ 767 static void sk_psock_strp_data_ready(struct sock *sk) 768 { 769 struct sk_psock *psock; 770 771 rcu_read_lock(); 772 psock = sk_psock(sk); 773 if (likely(psock)) { 774 write_lock_bh(&sk->sk_callback_lock); 775 strp_data_ready(&psock->parser.strp); 776 write_unlock_bh(&sk->sk_callback_lock); 777 } 778 rcu_read_unlock(); 779 } 780 781 static void sk_psock_write_space(struct sock *sk) 782 { 783 struct sk_psock *psock; 784 void (*write_space)(struct sock *sk); 785 786 rcu_read_lock(); 787 psock = sk_psock(sk); 788 if (likely(psock && sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))) 789 schedule_work(&psock->work); 790 write_space = psock->saved_write_space; 791 rcu_read_unlock(); 792 write_space(sk); 793 } 794 795 int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock) 796 { 797 static const struct strp_callbacks cb = { 798 .rcv_msg = sk_psock_strp_read, 799 .read_sock_done = sk_psock_strp_read_done, 800 .parse_msg = sk_psock_strp_parse, 801 }; 802 803 psock->parser.enabled = false; 804 return strp_init(&psock->parser.strp, sk, &cb); 805 } 806 807 void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock) 808 { 809 struct sk_psock_parser *parser = &psock->parser; 810 811 if (parser->enabled) 812 return; 813 814 parser->saved_data_ready = sk->sk_data_ready; 815 sk->sk_data_ready = sk_psock_strp_data_ready; 816 sk->sk_write_space = sk_psock_write_space; 817 parser->enabled = true; 818 } 819 820 void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock) 821 { 822 struct sk_psock_parser *parser = &psock->parser; 823 824 if (!parser->enabled) 825 return; 826 827 sk->sk_data_ready = parser->saved_data_ready; 828 parser->saved_data_ready = NULL; 829 strp_stop(&parser->strp); 830 parser->enabled = false; 831 } 832