// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/skbuff.h>
#include <linux/scatterlist.h>

#include <net/sock.h>
#include <net/tcp.h>

/* msg->sg.data[] is a circular buffer: sg.start/sg.end wrap around, so
 * sg.end < sg.start means the ring has wrapped. Only elements appended
 * at or after @elem_first_coalesce may be grown in place; the two
 * checks below handle the unwrapped (end > start) and wrapped
 * (end < start) ring layouts.
 */
static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
{
	if (msg->sg.end > msg->sg.start &&
	    elem_first_coalesce < msg->sg.end)
		return true;

	if (msg->sg.end < msg->sg.start &&
	    (elem_first_coalesce > msg->sg.start ||
	     elem_first_coalesce < msg->sg.end))
		return true;

	return false;
}

int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
		 int elem_first_coalesce)
{
	struct page_frag *pfrag = sk_page_frag(sk);
	int ret = 0;

	len -= msg->sg.size;
	while (len > 0) {
		struct scatterlist *sge;
		u32 orig_offset;
		int use, i;

		if (!sk_page_frag_refill(sk, pfrag))
			return -ENOMEM;

		orig_offset = pfrag->offset;
		use = min_t(int, len, pfrag->size - orig_offset);
		if (!sk_wmem_schedule(sk, use))
			return -ENOMEM;

		i = msg->sg.end;
		sk_msg_iter_var_prev(i);
		sge = &msg->sg.data[i];

		if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
		    sg_page(sge) == pfrag->page &&
		    sge->offset + sge->length == orig_offset) {
			sge->length += use;
		} else {
			if (sk_msg_full(msg)) {
				ret = -ENOSPC;
				break;
			}

			sge = &msg->sg.data[msg->sg.end];
			sg_unmark_end(sge);
			sg_set_page(sge, pfrag->page, use, orig_offset);
			get_page(pfrag->page);
			sk_msg_iter_next(msg, end);
		}

		sk_mem_charge(sk, use);
		msg->sg.size += use;
		pfrag->offset += use;
		len -= use;
	}

	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_alloc);

int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
		 u32 off, u32 len)
{
	int i = src->sg.start;
	struct scatterlist *sge = sk_msg_elem(src, i);
	struct scatterlist *sgd = NULL;
	u32 sge_len, sge_off;

	while (off) {
		if (sge->length > off)
			break;
		off -= sge->length;
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && off)
			return -ENOSPC;
		sge = sk_msg_elem(src, i);
	}

	while (len) {
		sge_len = sge->length - off;
		if (sge_len > len)
			sge_len = len;

		if (dst->sg.end)
			sgd = sk_msg_elem(dst, dst->sg.end - 1);

		if (sgd &&
		    (sg_page(sge) == sg_page(sgd)) &&
		    (sg_virt(sge) + off == sg_virt(sgd) + sgd->length)) {
			sgd->length += sge_len;
			dst->sg.size += sge_len;
		} else if (!sk_msg_full(dst)) {
			sge_off = sge->offset + off;
			sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);
		} else {
			return -ENOSPC;
		}

		off = 0;
		len -= sge_len;
		sk_mem_charge(sk, sge_len);
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && len)
			return -ENOSPC;
		sge = sk_msg_elem(src, i);
	}

	return 0;
}
EXPORT_SYMBOL_GPL(sk_msg_clone);

void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
{
	int i = msg->sg.start;

	do {
		struct scatterlist *sge = sk_msg_elem(msg, i);

		if (bytes < sge->length) {
			sge->length -= bytes;
			sge->offset += bytes;
			sk_mem_uncharge(sk, bytes);
			break;
		}

		sk_mem_uncharge(sk, sge->length);
		bytes -= sge->length;
		sge->length = 0;
		sge->offset = 0;
		sk_msg_iter_var_next(i);
	} while (bytes && i != msg->sg.end);
	msg->sg.start = i;
}
EXPORT_SYMBOL_GPL(sk_msg_return_zero);

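/* Uncharge up to @bytes of the msg from the socket's memory accounting
 * without modifying the ring. Unlike sk_msg_return_zero() above, the
 * elements keep their page references, lengths and offsets; only the
 * charged memory is returned to the socket.
 */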
void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
{
	int i = msg->sg.start;

	do {
		struct scatterlist *sge = &msg->sg.data[i];
		int uncharge = (bytes < sge->length) ? bytes : sge->length;

		sk_mem_uncharge(sk, uncharge);
		bytes -= uncharge;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);
}
EXPORT_SYMBOL_GPL(sk_msg_return);

static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
			    bool charge)
{
	struct scatterlist *sge = sk_msg_elem(msg, i);
	u32 len = sge->length;

	if (charge)
		sk_mem_uncharge(sk, len);
	if (!msg->skb)
		put_page(sg_page(sge));
	memset(sge, 0, sizeof(*sge));
	return len;
}

static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
			 bool charge)
{
	struct scatterlist *sge = sk_msg_elem(msg, i);
	int freed = 0;

	while (msg->sg.size) {
		msg->sg.size -= sge->length;
		freed += sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, msg->sg.size);
		sge = sk_msg_elem(msg, i);
	}
	if (msg->skb)
		consume_skb(msg->skb);
	sk_msg_init(msg);
	return freed;
}

int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, false);
}
EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);

int sk_msg_free(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free);

static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
				  u32 bytes, bool charge)
{
	struct scatterlist *sge;
	u32 i = msg->sg.start;

	while (bytes) {
		sge = sk_msg_elem(msg, i);
		if (!sge->length)
			break;
		if (bytes < sge->length) {
			if (charge)
				sk_mem_uncharge(sk, bytes);
			sge->length -= bytes;
			sge->offset += bytes;
			msg->sg.size -= bytes;
			break;
		}

		msg->sg.size -= sge->length;
		bytes -= sge->length;
		sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, bytes);
	}
	msg->sg.start = i;
}

void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free_partial);

void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
				  u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, false);
}

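/* Trim the msg down to @len bytes by walking backwards from sg.end,
 * freeing whole elements and uncharging their memory, then shortening
 * the first element that straddles the new length. For example (sizes
 * are illustrative): trimming a three-element msg of 4k/4k/1k down to
 * len = 6k frees the 1k element and shortens the middle one to 2k.
 */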
void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
{
	int trim = msg->sg.size - len;
	u32 i = msg->sg.end;

	if (trim <= 0) {
		WARN_ON(trim < 0);
		return;
	}

	sk_msg_iter_var_prev(i);
	msg->sg.size = len;
	while (msg->sg.data[i].length &&
	       trim >= msg->sg.data[i].length) {
		trim -= msg->sg.data[i].length;
		sk_msg_free_elem(sk, msg, i, true);
		sk_msg_iter_var_prev(i);
		if (!trim)
			goto out;
	}

	msg->sg.data[i].length -= trim;
	sk_mem_uncharge(sk, trim);
out:
	/* If we trim data that is before the curr pointer, update copybreak
	 * and curr so that any future copy operations start at the new copy
	 * location. Trimmed data that has not yet been used in a copy op
	 * does not require an update.
	 */
	if (msg->sg.curr >= i) {
		msg->sg.curr = i;
		msg->sg.copybreak = msg->sg.data[i].length;
	}
	sk_msg_iter_var_next(i);
	msg->sg.end = i;
}
EXPORT_SYMBOL_GPL(sk_msg_trim);

int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      struct sk_msg *msg, u32 bytes)
{
	int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
	const int to_max_pages = MAX_MSG_FRAGS;
	struct page *pages[MAX_MSG_FRAGS];
	ssize_t orig, copied, use, offset;

	orig = msg->sg.size;
	while (bytes > 0) {
		i = 0;
		maxpages = to_max_pages - num_elems;
		if (maxpages == 0) {
			ret = -EFAULT;
			goto out;
		}

		copied = iov_iter_get_pages(from, pages, bytes, maxpages,
					    &offset);
		if (copied <= 0) {
			ret = -EFAULT;
			goto out;
		}

		iov_iter_advance(from, copied);
		bytes -= copied;
		msg->sg.size += copied;

		while (copied) {
			use = min_t(int, copied, PAGE_SIZE - offset);
			sg_set_page(&msg->sg.data[msg->sg.end],
				    pages[i], use, offset);
			sg_unmark_end(&msg->sg.data[msg->sg.end]);
			sk_mem_charge(sk, use);

			offset = 0;
			copied -= use;
			sk_msg_iter_next(msg, end);
			num_elems++;
			i++;
		}
		/* When zerocopy is mixed with sk_msg_*copy* operations we
		 * may have a copybreak set; in that case clear it and prefer
		 * the zerocopy remainder when possible.
		 */
		msg->sg.copybreak = 0;
		msg->sg.curr = msg->sg.end;
	}
out:
	/* Revert iov_iter updates on error; the caller will need to 'trim'
	 * the msg if it also needs to be cleared.
	 */
	if (ret)
		iov_iter_revert(from, msg->sg.size - orig);
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);

int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     struct sk_msg *msg, u32 bytes)
{
	int ret = -ENOSPC, i = msg->sg.curr;
	struct scatterlist *sge;
	u32 copy, buf_size;
	void *to;

	do {
		sge = sk_msg_elem(msg, i);
		/* This is possible if a trim operation shrunk the buffer */
		if (msg->sg.copybreak >= sge->length) {
			msg->sg.copybreak = 0;
			sk_msg_iter_var_next(i);
			if (i == msg->sg.end)
				break;
			sge = sk_msg_elem(msg, i);
		}

		buf_size = sge->length - msg->sg.copybreak;
		copy = (buf_size > bytes) ? bytes : buf_size;
		to = sg_virt(sge) + msg->sg.copybreak;
		msg->sg.copybreak += copy;
		if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
			ret = copy_from_iter_nocache(to, copy, from);
		else
			ret = copy_from_iter(to, copy, from);
		if (ret != copy) {
			ret = -EFAULT;
			goto out;
		}
		bytes -= copy;
		if (!bytes)
			break;
		msg->sg.copybreak = 0;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);
out:
	msg->sg.curr = i;
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);

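/* Ingress path: wrap the skb's data in a freshly allocated sk_msg via
 * skb_to_sgvec(), charge it against the socket's receive memory, and
 * queue it on the psock ingress list for readers. The skb itself is
 * kept alive through msg->skb until the msg is freed.
 */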
static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
{
	struct sock *sk = psock->sk;
	int copied = 0, num_sge;
	struct sk_msg *msg;

	msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
	if (unlikely(!msg))
		return -EAGAIN;
	if (!sk_rmem_schedule(sk, skb, skb->len)) {
		kfree(msg);
		return -EAGAIN;
	}

	sk_msg_init(msg);
	num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
	if (unlikely(num_sge < 0)) {
		kfree(msg);
		return num_sge;
	}

	sk_mem_charge(sk, skb->len);
	copied = skb->len;
	msg->sg.start = 0;
	msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;
	msg->skb = skb;

	sk_psock_queue_msg(psock, msg);
	sk_psock_data_ready(sk, psock);
	return copied;
}

static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
			       u32 off, u32 len, bool ingress)
{
	if (ingress)
		return sk_psock_skb_ingress(psock, skb);
	else
		return skb_send_sock_locked(psock->sk, skb, off, len);
}

static void sk_psock_backlog(struct work_struct *work)
{
	struct sk_psock *psock = container_of(work, struct sk_psock, work);
	struct sk_psock_work_state *state = &psock->work_state;
	struct sk_buff *skb;
	bool ingress;
	u32 len, off;
	int ret;

	/* Lock sock to avoid losing sk_socket during loop. */
	lock_sock(psock->sk);
	if (state->skb) {
		skb = state->skb;
		len = state->len;
		off = state->off;
		state->skb = NULL;
		goto start;
	}

	while ((skb = skb_dequeue(&psock->ingress_skb))) {
		len = skb->len;
		off = 0;
start:
		ingress = tcp_skb_bpf_ingress(skb);
		do {
			ret = -EIO;
			if (likely(psock->sk->sk_socket))
				ret = sk_psock_handle_skb(psock, skb, off,
							  len, ingress);
			if (ret <= 0) {
				if (ret == -EAGAIN) {
					state->skb = skb;
					state->len = len;
					state->off = off;
					goto end;
				}
				/* Hard errors break pipe and stop xmit. */
				sk_psock_report_error(psock, ret ? -ret : EPIPE);
				sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
				kfree_skb(skb);
				goto end;
			}
			off += ret;
			len -= ret;
		} while (len);

		if (!ingress)
			kfree_skb(skb);
	}
end:
	release_sock(psock->sk);
}

struct sk_psock *sk_psock_init(struct sock *sk, int node)
{
	struct sk_psock *psock = kzalloc_node(sizeof(*psock),
					      GFP_ATOMIC | __GFP_NOWARN,
					      node);
	if (!psock)
		return NULL;

	psock->sk = sk;
	psock->eval = __SK_NONE;

	INIT_LIST_HEAD(&psock->link);
	spin_lock_init(&psock->link_lock);

	INIT_WORK(&psock->work, sk_psock_backlog);
	INIT_LIST_HEAD(&psock->ingress_msg);
	skb_queue_head_init(&psock->ingress_skb);

	sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
	refcount_set(&psock->refcnt, 1);

	rcu_assign_sk_user_data(sk, psock);
	sock_hold(sk);

	return psock;
}
EXPORT_SYMBOL_GPL(sk_psock_init);

struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
{
	struct sk_psock_link *link;

	spin_lock_bh(&psock->link_lock);
	link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
					list);
	if (link)
		list_del(&link->list);
	spin_unlock_bh(&psock->link_lock);
	return link;
}

void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
{
	struct sk_msg *msg, *tmp;

	list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
		list_del(&msg->list);
		sk_msg_free(psock->sk, msg);
		kfree(msg);
	}
}

static void sk_psock_zap_ingress(struct sk_psock *psock)
{
	__skb_queue_purge(&psock->ingress_skb);
	__sk_psock_purge_ingress_msg(psock);
}

static void sk_psock_link_destroy(struct sk_psock *psock)
{
	struct sk_psock_link *link, *tmp;

	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		list_del(&link->list);
		sk_psock_free_link(link);
	}
}

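/* Tear-down is deferred to a workqueue (see sk_psock_destroy() below)
 * because strp_done() and cancel_work_sync() may sleep and so cannot
 * run from the RCU callback itself.
 */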
static void sk_psock_destroy_deferred(struct work_struct *gc)
{
	struct sk_psock *psock = container_of(gc, struct sk_psock, gc);

	/* No sk_callback_lock since already detached. */
	strp_done(&psock->parser.strp);

	cancel_work_sync(&psock->work);

	psock_progs_drop(&psock->progs);

	sk_psock_link_destroy(psock);
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);

	if (psock->sk_redir)
		sock_put(psock->sk_redir);
	sock_put(psock->sk);
	kfree(psock);
}

void sk_psock_destroy(struct rcu_head *rcu)
{
	struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);

	INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
	schedule_work(&psock->gc);
}
EXPORT_SYMBOL_GPL(sk_psock_destroy);

void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
{
	rcu_assign_sk_user_data(sk, NULL);
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);
	sk_psock_restore_proto(sk, psock);

	write_lock_bh(&sk->sk_callback_lock);
	if (psock->progs.skb_parser)
		sk_psock_stop_strp(sk, psock);
	write_unlock_bh(&sk->sk_callback_lock);
	sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);

	call_rcu(&psock->rcu, sk_psock_destroy);
}
EXPORT_SYMBOL_GPL(sk_psock_drop);

static int sk_psock_map_verd(int verdict, bool redir)
{
	switch (verdict) {
	case SK_PASS:
		return redir ? __SK_REDIRECT : __SK_PASS;
	case SK_DROP:
	default:
		break;
	}

	return __SK_DROP;
}

int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
			 struct sk_msg *msg)
{
	struct bpf_prog *prog;
	int ret;

	preempt_disable();
	rcu_read_lock();
	prog = READ_ONCE(psock->progs.msg_parser);
	if (unlikely(!prog)) {
		ret = __SK_PASS;
		goto out;
	}

	sk_msg_compute_data_pointers(msg);
	msg->sk = sk;
	ret = BPF_PROG_RUN(prog, msg);
	ret = sk_psock_map_verd(ret, msg->sk_redir);
	psock->apply_bytes = msg->apply_bytes;
	if (ret == __SK_REDIRECT) {
		if (psock->sk_redir)
			sock_put(psock->sk_redir);
		psock->sk_redir = msg->sk_redir;
		if (!psock->sk_redir) {
			ret = __SK_DROP;
			goto out;
		}
		sock_hold(psock->sk_redir);
	}
out:
	rcu_read_unlock();
	preempt_enable();
	return ret;
}
EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);

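/* Run the skb program with preemption disabled; BPF programs assume
 * they execute on a single CPU (e.g. for per-CPU maps and scratch
 * buffers), so they must not migrate mid-run.
 */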
static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
			    struct sk_buff *skb)
{
	int ret;

	skb->sk = psock->sk;
	bpf_compute_data_end_sk_skb(skb);
	preempt_disable();
	ret = BPF_PROG_RUN(prog, skb);
	preempt_enable();
	/* strparser clones the skb before handing it to an upper layer,
	 * meaning skb_orphan() has been called. We NULL sk on the way out
	 * to ensure we don't trigger a BUG_ON() in skb/sk operations
	 * later and because we are not charging the memory of this skb
	 * to any socket yet.
	 */
	skb->sk = NULL;
	return ret;
}

static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
{
	struct sk_psock_parser *parser;

	parser = container_of(strp, struct sk_psock_parser, strp);
	return container_of(parser, struct sk_psock, parser);
}

static void sk_psock_verdict_apply(struct sk_psock *psock,
				   struct sk_buff *skb, int verdict)
{
	struct sk_psock *psock_other;
	struct sock *sk_other;
	bool ingress;

	switch (verdict) {
	case __SK_PASS:
		sk_other = psock->sk;
		if (sock_flag(sk_other, SOCK_DEAD) ||
		    !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))
			goto out_free;
		if (atomic_read(&sk_other->sk_rmem_alloc) <=
		    sk_other->sk_rcvbuf) {
			struct tcp_skb_cb *tcp = TCP_SKB_CB(skb);

			tcp->bpf.flags |= BPF_F_INGRESS;
			skb_queue_tail(&psock->ingress_skb, skb);
			schedule_work(&psock->work);
			break;
		}
		goto out_free;
	case __SK_REDIRECT:
		sk_other = tcp_skb_bpf_redirect_fetch(skb);
		if (unlikely(!sk_other))
			goto out_free;
		psock_other = sk_psock(sk_other);
		if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
		    !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED))
			goto out_free;
		ingress = tcp_skb_bpf_ingress(skb);
		if ((!ingress && sock_writeable(sk_other)) ||
		    (ingress &&
		     atomic_read(&sk_other->sk_rmem_alloc) <=
		     sk_other->sk_rcvbuf)) {
			if (!ingress)
				skb_set_owner_w(skb, sk_other);
			skb_queue_tail(&psock_other->ingress_skb, skb);
			schedule_work(&psock_other->work);
			break;
		}
		/* fall-through */
	case __SK_DROP:
		/* fall-through */
	default:
out_free:
		kfree_skb(skb);
	}
}

static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
{
	struct sk_psock *psock = sk_psock_from_strp(strp);
	struct bpf_prog *prog;
	int ret = __SK_DROP;

	rcu_read_lock();
	prog = READ_ONCE(psock->progs.skb_verdict);
	if (likely(prog)) {
		skb_orphan(skb);
		tcp_skb_bpf_redirect_clear(skb);
		ret = sk_psock_bpf_run(psock, prog, skb);
		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
	}
	rcu_read_unlock();
	sk_psock_verdict_apply(psock, skb, ret);
}

static int sk_psock_strp_read_done(struct strparser *strp, int err)
{
	return err;
}

static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
{
	struct sk_psock *psock = sk_psock_from_strp(strp);
	struct bpf_prog *prog;
	int ret = skb->len;

	rcu_read_lock();
	prog = READ_ONCE(psock->progs.skb_parser);
	if (likely(prog))
		ret = sk_psock_bpf_run(psock, prog, skb);
	rcu_read_unlock();
	return ret;
}

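/* The strparser callbacks above are wired up in sk_psock_init_strp():
 * sk_psock_strp_parse() decides how many bytes make up the next message
 * (the whole skb unless a parser program says otherwise) and
 * sk_psock_strp_read() applies the verdict to each parsed message. The
 * helpers below install and remove the hooks on the socket itself.
 */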
/* Called with socket lock held. */
static void sk_psock_strp_data_ready(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock)) {
		write_lock_bh(&sk->sk_callback_lock);
		strp_data_ready(&psock->parser.strp);
		write_unlock_bh(&sk->sk_callback_lock);
	}
	rcu_read_unlock();
}

static void sk_psock_write_space(struct sock *sk)
{
	struct sk_psock *psock;
	void (*write_space)(struct sock *sk) = NULL;

	rcu_read_lock();
	psock = sk_psock(sk);
	/* The psock may already be gone if the socket is being torn down. */
	if (likely(psock)) {
		if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))
			schedule_work(&psock->work);
		write_space = psock->saved_write_space;
	}
	rcu_read_unlock();
	if (write_space)
		write_space(sk);
}

int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
{
	static const struct strp_callbacks cb = {
		.rcv_msg	= sk_psock_strp_read,
		.read_sock_done	= sk_psock_strp_read_done,
		.parse_msg	= sk_psock_strp_parse,
	};

	psock->parser.enabled = false;
	return strp_init(&psock->parser.strp, sk, &cb);
}

void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (parser->enabled)
		return;

	parser->saved_data_ready = sk->sk_data_ready;
	sk->sk_data_ready = sk_psock_strp_data_ready;
	sk->sk_write_space = sk_psock_write_space;
	parser->enabled = true;
}

void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (!parser->enabled)
		return;

	sk->sk_data_ready = parser->saved_data_ready;
	parser->saved_data_ready = NULL;
	strp_stop(&parser->strp);
	parser->enabled = false;
}