1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */ 3 4 #include <linux/skmsg.h> 5 #include <linux/skbuff.h> 6 #include <linux/scatterlist.h> 7 8 #include <net/sock.h> 9 #include <net/tcp.h> 10 11 static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce) 12 { 13 if (msg->sg.end > msg->sg.start && 14 elem_first_coalesce < msg->sg.end) 15 return true; 16 17 if (msg->sg.end < msg->sg.start && 18 (elem_first_coalesce > msg->sg.start || 19 elem_first_coalesce < msg->sg.end)) 20 return true; 21 22 return false; 23 } 24 25 int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len, 26 int elem_first_coalesce) 27 { 28 struct page_frag *pfrag = sk_page_frag(sk); 29 int ret = 0; 30 31 len -= msg->sg.size; 32 while (len > 0) { 33 struct scatterlist *sge; 34 u32 orig_offset; 35 int use, i; 36 37 if (!sk_page_frag_refill(sk, pfrag)) 38 return -ENOMEM; 39 40 orig_offset = pfrag->offset; 41 use = min_t(int, len, pfrag->size - orig_offset); 42 if (!sk_wmem_schedule(sk, use)) 43 return -ENOMEM; 44 45 i = msg->sg.end; 46 sk_msg_iter_var_prev(i); 47 sge = &msg->sg.data[i]; 48 49 if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) && 50 sg_page(sge) == pfrag->page && 51 sge->offset + sge->length == orig_offset) { 52 sge->length += use; 53 } else { 54 if (sk_msg_full(msg)) { 55 ret = -ENOSPC; 56 break; 57 } 58 59 sge = &msg->sg.data[msg->sg.end]; 60 sg_unmark_end(sge); 61 sg_set_page(sge, pfrag->page, use, orig_offset); 62 get_page(pfrag->page); 63 sk_msg_iter_next(msg, end); 64 } 65 66 sk_mem_charge(sk, use); 67 msg->sg.size += use; 68 pfrag->offset += use; 69 len -= use; 70 } 71 72 return ret; 73 } 74 EXPORT_SYMBOL_GPL(sk_msg_alloc); 75 76 int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src, 77 u32 off, u32 len) 78 { 79 int i = src->sg.start; 80 struct scatterlist *sge = sk_msg_elem(src, i); 81 struct scatterlist *sgd = NULL; 82 u32 sge_len, sge_off; 83 84 while (off) { 85 if (sge->length > off) 86 break; 87 off -= sge->length; 88 sk_msg_iter_var_next(i); 89 if (i == src->sg.end && off) 90 return -ENOSPC; 91 sge = sk_msg_elem(src, i); 92 } 93 94 while (len) { 95 sge_len = sge->length - off; 96 if (sge_len > len) 97 sge_len = len; 98 99 if (dst->sg.end) 100 sgd = sk_msg_elem(dst, dst->sg.end - 1); 101 102 if (sgd && 103 (sg_page(sge) == sg_page(sgd)) && 104 (sg_virt(sge) + off == sg_virt(sgd) + sgd->length)) { 105 sgd->length += sge_len; 106 dst->sg.size += sge_len; 107 } else if (!sk_msg_full(dst)) { 108 sge_off = sge->offset + off; 109 sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off); 110 } else { 111 return -ENOSPC; 112 } 113 114 off = 0; 115 len -= sge_len; 116 sk_mem_charge(sk, sge_len); 117 sk_msg_iter_var_next(i); 118 if (i == src->sg.end && len) 119 return -ENOSPC; 120 sge = sk_msg_elem(src, i); 121 } 122 123 return 0; 124 } 125 EXPORT_SYMBOL_GPL(sk_msg_clone); 126 127 void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes) 128 { 129 int i = msg->sg.start; 130 131 do { 132 struct scatterlist *sge = sk_msg_elem(msg, i); 133 134 if (bytes < sge->length) { 135 sge->length -= bytes; 136 sge->offset += bytes; 137 sk_mem_uncharge(sk, bytes); 138 break; 139 } 140 141 sk_mem_uncharge(sk, sge->length); 142 bytes -= sge->length; 143 sge->length = 0; 144 sge->offset = 0; 145 sk_msg_iter_var_next(i); 146 } while (bytes && i != msg->sg.end); 147 msg->sg.start = i; 148 } 149 EXPORT_SYMBOL_GPL(sk_msg_return_zero); 150 151 void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes) 152 { 153 int i = msg->sg.start; 154 155 do { 156 struct scatterlist *sge = &msg->sg.data[i]; 157 int uncharge = (bytes < sge->length) ? bytes : sge->length; 158 159 sk_mem_uncharge(sk, uncharge); 160 bytes -= uncharge; 161 sk_msg_iter_var_next(i); 162 } while (i != msg->sg.end); 163 } 164 EXPORT_SYMBOL_GPL(sk_msg_return); 165 166 static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i, 167 bool charge) 168 { 169 struct scatterlist *sge = sk_msg_elem(msg, i); 170 u32 len = sge->length; 171 172 if (charge) 173 sk_mem_uncharge(sk, len); 174 if (!msg->skb) 175 put_page(sg_page(sge)); 176 memset(sge, 0, sizeof(*sge)); 177 return len; 178 } 179 180 static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i, 181 bool charge) 182 { 183 struct scatterlist *sge = sk_msg_elem(msg, i); 184 int freed = 0; 185 186 while (msg->sg.size) { 187 msg->sg.size -= sge->length; 188 freed += sk_msg_free_elem(sk, msg, i, charge); 189 sk_msg_iter_var_next(i); 190 sk_msg_check_to_free(msg, i, msg->sg.size); 191 sge = sk_msg_elem(msg, i); 192 } 193 if (msg->skb) 194 consume_skb(msg->skb); 195 sk_msg_init(msg); 196 return freed; 197 } 198 199 int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg) 200 { 201 return __sk_msg_free(sk, msg, msg->sg.start, false); 202 } 203 EXPORT_SYMBOL_GPL(sk_msg_free_nocharge); 204 205 int sk_msg_free(struct sock *sk, struct sk_msg *msg) 206 { 207 return __sk_msg_free(sk, msg, msg->sg.start, true); 208 } 209 EXPORT_SYMBOL_GPL(sk_msg_free); 210 211 static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, 212 u32 bytes, bool charge) 213 { 214 struct scatterlist *sge; 215 u32 i = msg->sg.start; 216 217 while (bytes) { 218 sge = sk_msg_elem(msg, i); 219 if (!sge->length) 220 break; 221 if (bytes < sge->length) { 222 if (charge) 223 sk_mem_uncharge(sk, bytes); 224 sge->length -= bytes; 225 sge->offset += bytes; 226 msg->sg.size -= bytes; 227 break; 228 } 229 230 msg->sg.size -= sge->length; 231 bytes -= sge->length; 232 sk_msg_free_elem(sk, msg, i, charge); 233 sk_msg_iter_var_next(i); 234 sk_msg_check_to_free(msg, i, bytes); 235 } 236 msg->sg.start = i; 237 } 238 239 void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes) 240 { 241 __sk_msg_free_partial(sk, msg, bytes, true); 242 } 243 EXPORT_SYMBOL_GPL(sk_msg_free_partial); 244 245 void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg, 246 u32 bytes) 247 { 248 __sk_msg_free_partial(sk, msg, bytes, false); 249 } 250 251 void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len) 252 { 253 int trim = msg->sg.size - len; 254 u32 i = msg->sg.end; 255 256 if (trim <= 0) { 257 WARN_ON(trim < 0); 258 return; 259 } 260 261 sk_msg_iter_var_prev(i); 262 msg->sg.size = len; 263 while (msg->sg.data[i].length && 264 trim >= msg->sg.data[i].length) { 265 trim -= msg->sg.data[i].length; 266 sk_msg_free_elem(sk, msg, i, true); 267 sk_msg_iter_var_prev(i); 268 if (!trim) 269 goto out; 270 } 271 272 msg->sg.data[i].length -= trim; 273 sk_mem_uncharge(sk, trim); 274 out: 275 /* If we trim data before curr pointer update copybreak and current 276 * so that any future copy operations start at new copy location. 277 * However trimed data that has not yet been used in a copy op 278 * does not require an update. 279 */ 280 if (msg->sg.curr >= i) { 281 msg->sg.curr = i; 282 msg->sg.copybreak = msg->sg.data[i].length; 283 } 284 sk_msg_iter_var_next(i); 285 msg->sg.end = i; 286 } 287 EXPORT_SYMBOL_GPL(sk_msg_trim); 288 289 int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from, 290 struct sk_msg *msg, u32 bytes) 291 { 292 int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg); 293 const int to_max_pages = MAX_MSG_FRAGS; 294 struct page *pages[MAX_MSG_FRAGS]; 295 ssize_t orig, copied, use, offset; 296 297 orig = msg->sg.size; 298 while (bytes > 0) { 299 i = 0; 300 maxpages = to_max_pages - num_elems; 301 if (maxpages == 0) { 302 ret = -EFAULT; 303 goto out; 304 } 305 306 copied = iov_iter_get_pages(from, pages, bytes, maxpages, 307 &offset); 308 if (copied <= 0) { 309 ret = -EFAULT; 310 goto out; 311 } 312 313 iov_iter_advance(from, copied); 314 bytes -= copied; 315 msg->sg.size += copied; 316 317 while (copied) { 318 use = min_t(int, copied, PAGE_SIZE - offset); 319 sg_set_page(&msg->sg.data[msg->sg.end], 320 pages[i], use, offset); 321 sg_unmark_end(&msg->sg.data[msg->sg.end]); 322 sk_mem_charge(sk, use); 323 324 offset = 0; 325 copied -= use; 326 sk_msg_iter_next(msg, end); 327 num_elems++; 328 i++; 329 } 330 /* When zerocopy is mixed with sk_msg_*copy* operations we 331 * may have a copybreak set in this case clear and prefer 332 * zerocopy remainder when possible. 333 */ 334 msg->sg.copybreak = 0; 335 msg->sg.curr = msg->sg.end; 336 } 337 out: 338 /* Revert iov_iter updates, msg will need to use 'trim' later if it 339 * also needs to be cleared. 340 */ 341 if (ret) 342 iov_iter_revert(from, msg->sg.size - orig); 343 return ret; 344 } 345 EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter); 346 347 int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from, 348 struct sk_msg *msg, u32 bytes) 349 { 350 int ret = -ENOSPC, i = msg->sg.curr; 351 struct scatterlist *sge; 352 u32 copy, buf_size; 353 void *to; 354 355 do { 356 sge = sk_msg_elem(msg, i); 357 /* This is possible if a trim operation shrunk the buffer */ 358 if (msg->sg.copybreak >= sge->length) { 359 msg->sg.copybreak = 0; 360 sk_msg_iter_var_next(i); 361 if (i == msg->sg.end) 362 break; 363 sge = sk_msg_elem(msg, i); 364 } 365 366 buf_size = sge->length - msg->sg.copybreak; 367 copy = (buf_size > bytes) ? bytes : buf_size; 368 to = sg_virt(sge) + msg->sg.copybreak; 369 msg->sg.copybreak += copy; 370 if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY) 371 ret = copy_from_iter_nocache(to, copy, from); 372 else 373 ret = copy_from_iter(to, copy, from); 374 if (ret != copy) { 375 ret = -EFAULT; 376 goto out; 377 } 378 bytes -= copy; 379 if (!bytes) 380 break; 381 msg->sg.copybreak = 0; 382 sk_msg_iter_var_next(i); 383 } while (i != msg->sg.end); 384 out: 385 msg->sg.curr = i; 386 return ret; 387 } 388 EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter); 389 390 static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb) 391 { 392 struct sock *sk = psock->sk; 393 int copied = 0, num_sge; 394 struct sk_msg *msg; 395 396 msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC); 397 if (unlikely(!msg)) 398 return -EAGAIN; 399 if (!sk_rmem_schedule(sk, skb, skb->len)) { 400 kfree(msg); 401 return -EAGAIN; 402 } 403 404 sk_msg_init(msg); 405 num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len); 406 if (unlikely(num_sge < 0)) { 407 kfree(msg); 408 return num_sge; 409 } 410 411 sk_mem_charge(sk, skb->len); 412 copied = skb->len; 413 msg->sg.start = 0; 414 msg->sg.size = copied; 415 msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge; 416 msg->skb = skb; 417 418 sk_psock_queue_msg(psock, msg); 419 sk_psock_data_ready(sk, psock); 420 return copied; 421 } 422 423 static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb, 424 u32 off, u32 len, bool ingress) 425 { 426 if (ingress) 427 return sk_psock_skb_ingress(psock, skb); 428 else 429 return skb_send_sock_locked(psock->sk, skb, off, len); 430 } 431 432 static void sk_psock_backlog(struct work_struct *work) 433 { 434 struct sk_psock *psock = container_of(work, struct sk_psock, work); 435 struct sk_psock_work_state *state = &psock->work_state; 436 struct sk_buff *skb; 437 bool ingress; 438 u32 len, off; 439 int ret; 440 441 /* Lock sock to avoid losing sk_socket during loop. */ 442 lock_sock(psock->sk); 443 if (state->skb) { 444 skb = state->skb; 445 len = state->len; 446 off = state->off; 447 state->skb = NULL; 448 goto start; 449 } 450 451 while ((skb = skb_dequeue(&psock->ingress_skb))) { 452 len = skb->len; 453 off = 0; 454 start: 455 ingress = tcp_skb_bpf_ingress(skb); 456 do { 457 ret = -EIO; 458 if (likely(psock->sk->sk_socket)) 459 ret = sk_psock_handle_skb(psock, skb, off, 460 len, ingress); 461 if (ret <= 0) { 462 if (ret == -EAGAIN) { 463 state->skb = skb; 464 state->len = len; 465 state->off = off; 466 goto end; 467 } 468 /* Hard errors break pipe and stop xmit. */ 469 sk_psock_report_error(psock, ret ? -ret : EPIPE); 470 sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED); 471 kfree_skb(skb); 472 goto end; 473 } 474 off += ret; 475 len -= ret; 476 } while (len); 477 478 if (!ingress) 479 kfree_skb(skb); 480 } 481 end: 482 release_sock(psock->sk); 483 } 484 485 struct sk_psock *sk_psock_init(struct sock *sk, int node) 486 { 487 struct sk_psock *psock = kzalloc_node(sizeof(*psock), 488 GFP_ATOMIC | __GFP_NOWARN, 489 node); 490 if (!psock) 491 return NULL; 492 493 psock->sk = sk; 494 psock->eval = __SK_NONE; 495 496 INIT_LIST_HEAD(&psock->link); 497 spin_lock_init(&psock->link_lock); 498 499 INIT_WORK(&psock->work, sk_psock_backlog); 500 INIT_LIST_HEAD(&psock->ingress_msg); 501 skb_queue_head_init(&psock->ingress_skb); 502 503 sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED); 504 refcount_set(&psock->refcnt, 1); 505 506 rcu_assign_sk_user_data(sk, psock); 507 sock_hold(sk); 508 509 return psock; 510 } 511 EXPORT_SYMBOL_GPL(sk_psock_init); 512 513 struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock) 514 { 515 struct sk_psock_link *link; 516 517 spin_lock_bh(&psock->link_lock); 518 link = list_first_entry_or_null(&psock->link, struct sk_psock_link, 519 list); 520 if (link) 521 list_del(&link->list); 522 spin_unlock_bh(&psock->link_lock); 523 return link; 524 } 525 526 void __sk_psock_purge_ingress_msg(struct sk_psock *psock) 527 { 528 struct sk_msg *msg, *tmp; 529 530 list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) { 531 list_del(&msg->list); 532 sk_msg_free(psock->sk, msg); 533 kfree(msg); 534 } 535 } 536 537 static void sk_psock_zap_ingress(struct sk_psock *psock) 538 { 539 __skb_queue_purge(&psock->ingress_skb); 540 __sk_psock_purge_ingress_msg(psock); 541 } 542 543 static void sk_psock_link_destroy(struct sk_psock *psock) 544 { 545 struct sk_psock_link *link, *tmp; 546 547 list_for_each_entry_safe(link, tmp, &psock->link, list) { 548 list_del(&link->list); 549 sk_psock_free_link(link); 550 } 551 } 552 553 static void sk_psock_destroy_deferred(struct work_struct *gc) 554 { 555 struct sk_psock *psock = container_of(gc, struct sk_psock, gc); 556 557 /* No sk_callback_lock since already detached. */ 558 559 /* Parser has been stopped */ 560 if (psock->progs.skb_parser) 561 strp_done(&psock->parser.strp); 562 563 cancel_work_sync(&psock->work); 564 565 psock_progs_drop(&psock->progs); 566 567 sk_psock_link_destroy(psock); 568 sk_psock_cork_free(psock); 569 sk_psock_zap_ingress(psock); 570 571 if (psock->sk_redir) 572 sock_put(psock->sk_redir); 573 sock_put(psock->sk); 574 kfree(psock); 575 } 576 577 void sk_psock_destroy(struct rcu_head *rcu) 578 { 579 struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu); 580 581 INIT_WORK(&psock->gc, sk_psock_destroy_deferred); 582 schedule_work(&psock->gc); 583 } 584 EXPORT_SYMBOL_GPL(sk_psock_destroy); 585 586 void sk_psock_drop(struct sock *sk, struct sk_psock *psock) 587 { 588 sk_psock_cork_free(psock); 589 sk_psock_zap_ingress(psock); 590 591 write_lock_bh(&sk->sk_callback_lock); 592 sk_psock_restore_proto(sk, psock); 593 rcu_assign_sk_user_data(sk, NULL); 594 if (psock->progs.skb_parser) 595 sk_psock_stop_strp(sk, psock); 596 write_unlock_bh(&sk->sk_callback_lock); 597 sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED); 598 599 call_rcu(&psock->rcu, sk_psock_destroy); 600 } 601 EXPORT_SYMBOL_GPL(sk_psock_drop); 602 603 static int sk_psock_map_verd(int verdict, bool redir) 604 { 605 switch (verdict) { 606 case SK_PASS: 607 return redir ? __SK_REDIRECT : __SK_PASS; 608 case SK_DROP: 609 default: 610 break; 611 } 612 613 return __SK_DROP; 614 } 615 616 int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock, 617 struct sk_msg *msg) 618 { 619 struct bpf_prog *prog; 620 int ret; 621 622 preempt_disable(); 623 rcu_read_lock(); 624 prog = READ_ONCE(psock->progs.msg_parser); 625 if (unlikely(!prog)) { 626 ret = __SK_PASS; 627 goto out; 628 } 629 630 sk_msg_compute_data_pointers(msg); 631 msg->sk = sk; 632 ret = BPF_PROG_RUN(prog, msg); 633 ret = sk_psock_map_verd(ret, msg->sk_redir); 634 psock->apply_bytes = msg->apply_bytes; 635 if (ret == __SK_REDIRECT) { 636 if (psock->sk_redir) 637 sock_put(psock->sk_redir); 638 psock->sk_redir = msg->sk_redir; 639 if (!psock->sk_redir) { 640 ret = __SK_DROP; 641 goto out; 642 } 643 sock_hold(psock->sk_redir); 644 } 645 out: 646 rcu_read_unlock(); 647 preempt_enable(); 648 return ret; 649 } 650 EXPORT_SYMBOL_GPL(sk_psock_msg_verdict); 651 652 static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog, 653 struct sk_buff *skb) 654 { 655 int ret; 656 657 skb->sk = psock->sk; 658 bpf_compute_data_end_sk_skb(skb); 659 preempt_disable(); 660 ret = BPF_PROG_RUN(prog, skb); 661 preempt_enable(); 662 /* strparser clones the skb before handing it to a upper layer, 663 * meaning skb_orphan has been called. We NULL sk on the way out 664 * to ensure we don't trigger a BUG_ON() in skb/sk operations 665 * later and because we are not charging the memory of this skb 666 * to any socket yet. 667 */ 668 skb->sk = NULL; 669 return ret; 670 } 671 672 static struct sk_psock *sk_psock_from_strp(struct strparser *strp) 673 { 674 struct sk_psock_parser *parser; 675 676 parser = container_of(strp, struct sk_psock_parser, strp); 677 return container_of(parser, struct sk_psock, parser); 678 } 679 680 static void sk_psock_verdict_apply(struct sk_psock *psock, 681 struct sk_buff *skb, int verdict) 682 { 683 struct sk_psock *psock_other; 684 struct sock *sk_other; 685 bool ingress; 686 687 switch (verdict) { 688 case __SK_PASS: 689 sk_other = psock->sk; 690 if (sock_flag(sk_other, SOCK_DEAD) || 691 !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) { 692 goto out_free; 693 } 694 if (atomic_read(&sk_other->sk_rmem_alloc) <= 695 sk_other->sk_rcvbuf) { 696 struct tcp_skb_cb *tcp = TCP_SKB_CB(skb); 697 698 tcp->bpf.flags |= BPF_F_INGRESS; 699 skb_queue_tail(&psock->ingress_skb, skb); 700 schedule_work(&psock->work); 701 break; 702 } 703 goto out_free; 704 case __SK_REDIRECT: 705 sk_other = tcp_skb_bpf_redirect_fetch(skb); 706 if (unlikely(!sk_other)) 707 goto out_free; 708 psock_other = sk_psock(sk_other); 709 if (!psock_other || sock_flag(sk_other, SOCK_DEAD) || 710 !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) 711 goto out_free; 712 ingress = tcp_skb_bpf_ingress(skb); 713 if ((!ingress && sock_writeable(sk_other)) || 714 (ingress && 715 atomic_read(&sk_other->sk_rmem_alloc) <= 716 sk_other->sk_rcvbuf)) { 717 if (!ingress) 718 skb_set_owner_w(skb, sk_other); 719 skb_queue_tail(&psock_other->ingress_skb, skb); 720 schedule_work(&psock_other->work); 721 break; 722 } 723 /* fall-through */ 724 case __SK_DROP: 725 /* fall-through */ 726 default: 727 out_free: 728 kfree_skb(skb); 729 } 730 } 731 732 static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb) 733 { 734 struct sk_psock *psock = sk_psock_from_strp(strp); 735 struct bpf_prog *prog; 736 int ret = __SK_DROP; 737 738 rcu_read_lock(); 739 prog = READ_ONCE(psock->progs.skb_verdict); 740 if (likely(prog)) { 741 skb_orphan(skb); 742 tcp_skb_bpf_redirect_clear(skb); 743 ret = sk_psock_bpf_run(psock, prog, skb); 744 ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb)); 745 } 746 rcu_read_unlock(); 747 sk_psock_verdict_apply(psock, skb, ret); 748 } 749 750 static int sk_psock_strp_read_done(struct strparser *strp, int err) 751 { 752 return err; 753 } 754 755 static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb) 756 { 757 struct sk_psock *psock = sk_psock_from_strp(strp); 758 struct bpf_prog *prog; 759 int ret = skb->len; 760 761 rcu_read_lock(); 762 prog = READ_ONCE(psock->progs.skb_parser); 763 if (likely(prog)) 764 ret = sk_psock_bpf_run(psock, prog, skb); 765 rcu_read_unlock(); 766 return ret; 767 } 768 769 /* Called with socket lock held. */ 770 static void sk_psock_strp_data_ready(struct sock *sk) 771 { 772 struct sk_psock *psock; 773 774 rcu_read_lock(); 775 psock = sk_psock(sk); 776 if (likely(psock)) { 777 write_lock_bh(&sk->sk_callback_lock); 778 strp_data_ready(&psock->parser.strp); 779 write_unlock_bh(&sk->sk_callback_lock); 780 } 781 rcu_read_unlock(); 782 } 783 784 static void sk_psock_write_space(struct sock *sk) 785 { 786 struct sk_psock *psock; 787 void (*write_space)(struct sock *sk); 788 789 rcu_read_lock(); 790 psock = sk_psock(sk); 791 if (likely(psock && sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))) 792 schedule_work(&psock->work); 793 write_space = psock->saved_write_space; 794 rcu_read_unlock(); 795 write_space(sk); 796 } 797 798 int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock) 799 { 800 static const struct strp_callbacks cb = { 801 .rcv_msg = sk_psock_strp_read, 802 .read_sock_done = sk_psock_strp_read_done, 803 .parse_msg = sk_psock_strp_parse, 804 }; 805 806 psock->parser.enabled = false; 807 return strp_init(&psock->parser.strp, sk, &cb); 808 } 809 810 void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock) 811 { 812 struct sk_psock_parser *parser = &psock->parser; 813 814 if (parser->enabled) 815 return; 816 817 parser->saved_data_ready = sk->sk_data_ready; 818 sk->sk_data_ready = sk_psock_strp_data_ready; 819 sk->sk_write_space = sk_psock_write_space; 820 parser->enabled = true; 821 } 822 823 void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock) 824 { 825 struct sk_psock_parser *parser = &psock->parser; 826 827 if (!parser->enabled) 828 return; 829 830 sk->sk_data_ready = parser->saved_data_ready; 831 parser->saved_data_ready = NULL; 832 strp_stop(&parser->strp); 833 parser->enabled = false; 834 } 835