// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/skbuff.h>
#include <linux/scatterlist.h>

#include <net/sock.h>
#include <net/tcp.h>

static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce)
{
	if (msg->sg.end > msg->sg.start &&
	    elem_first_coalesce < msg->sg.end)
		return true;

	if (msg->sg.end < msg->sg.start &&
	    (elem_first_coalesce > msg->sg.start ||
	     elem_first_coalesce < msg->sg.end))
		return true;

	return false;
}

int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len,
		 int elem_first_coalesce)
{
	struct page_frag *pfrag = sk_page_frag(sk);
	int ret = 0;

	len -= msg->sg.size;
	while (len > 0) {
		struct scatterlist *sge;
		u32 orig_offset;
		int use, i;

		if (!sk_page_frag_refill(sk, pfrag))
			return -ENOMEM;

		orig_offset = pfrag->offset;
		use = min_t(int, len, pfrag->size - orig_offset);
		if (!sk_wmem_schedule(sk, use))
			return -ENOMEM;

		i = msg->sg.end;
		sk_msg_iter_var_prev(i);
		sge = &msg->sg.data[i];

		if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) &&
		    sg_page(sge) == pfrag->page &&
		    sge->offset + sge->length == orig_offset) {
			sge->length += use;
		} else {
			if (sk_msg_full(msg)) {
				ret = -ENOSPC;
				break;
			}

			sge = &msg->sg.data[msg->sg.end];
			sg_unmark_end(sge);
			sg_set_page(sge, pfrag->page, use, orig_offset);
			get_page(pfrag->page);
			sk_msg_iter_next(msg, end);
		}

		sk_mem_charge(sk, use);
		msg->sg.size += use;
		pfrag->offset += use;
		len -= use;
	}

	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_alloc);

int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src,
		 u32 off, u32 len)
{
	int i = src->sg.start;
	struct scatterlist *sge = sk_msg_elem(src, i);
	u32 sge_len, sge_off;

	if (sk_msg_full(dst))
		return -ENOSPC;

	while (off) {
		if (sge->length > off)
			break;
		off -= sge->length;
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && off)
			return -ENOSPC;
		sge = sk_msg_elem(src, i);
	}

	while (len) {
		if (sk_msg_full(dst))
			return -ENOSPC;

		sge_len = sge->length - off;
		sge_off = sge->offset + off;
		if (sge_len > len)
			sge_len = len;
		off = 0;
		len -= sge_len;
		sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off);
		sk_mem_charge(sk, sge_len);
		sk_msg_iter_var_next(i);
		if (i == src->sg.end && len)
			return -ENOSPC;
		sge = sk_msg_elem(src, i);
	}

	return 0;
}
EXPORT_SYMBOL_GPL(sk_msg_clone);

void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes)
{
	int i = msg->sg.start;

	do {
		struct scatterlist *sge = sk_msg_elem(msg, i);

		if (bytes < sge->length) {
			sge->length -= bytes;
			sge->offset += bytes;
			sk_mem_uncharge(sk, bytes);
			break;
		}

		sk_mem_uncharge(sk, sge->length);
		bytes -= sge->length;
		sge->length = 0;
		sge->offset = 0;
		sk_msg_iter_var_next(i);
	} while (bytes && i != msg->sg.end);
	msg->sg.start = i;
}
EXPORT_SYMBOL_GPL(sk_msg_return_zero);

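/* Walk the sg ring from start towards end and uncharge up to @bytes of
 * socket memory accounting. Unlike sk_msg_return_zero(), the ring entries
 * and msg->sg.start are left untouched.
 */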
void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes)
{
	int i = msg->sg.start;

	do {
		struct scatterlist *sge = &msg->sg.data[i];
		int uncharge = (bytes < sge->length) ? bytes : sge->length;

		sk_mem_uncharge(sk, uncharge);
		bytes -= uncharge;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);
}
EXPORT_SYMBOL_GPL(sk_msg_return);

static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i,
			    bool charge)
{
	struct scatterlist *sge = sk_msg_elem(msg, i);
	u32 len = sge->length;

	if (charge)
		sk_mem_uncharge(sk, len);
	if (!msg->skb)
		put_page(sg_page(sge));
	memset(sge, 0, sizeof(*sge));
	return len;
}

static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i,
			 bool charge)
{
	struct scatterlist *sge = sk_msg_elem(msg, i);
	int freed = 0;

	while (msg->sg.size) {
		msg->sg.size -= sge->length;
		freed += sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, msg->sg.size);
		sge = sk_msg_elem(msg, i);
	}
	if (msg->skb)
		consume_skb(msg->skb);
	sk_msg_init(msg);
	return freed;
}

int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, false);
}
EXPORT_SYMBOL_GPL(sk_msg_free_nocharge);

int sk_msg_free(struct sock *sk, struct sk_msg *msg)
{
	return __sk_msg_free(sk, msg, msg->sg.start, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free);

static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg,
				  u32 bytes, bool charge)
{
	struct scatterlist *sge;
	u32 i = msg->sg.start;

	while (bytes) {
		sge = sk_msg_elem(msg, i);
		if (!sge->length)
			break;
		if (bytes < sge->length) {
			if (charge)
				sk_mem_uncharge(sk, bytes);
			sge->length -= bytes;
			sge->offset += bytes;
			msg->sg.size -= bytes;
			break;
		}

		msg->sg.size -= sge->length;
		bytes -= sge->length;
		sk_msg_free_elem(sk, msg, i, charge);
		sk_msg_iter_var_next(i);
		sk_msg_check_to_free(msg, i, bytes);
	}
	msg->sg.start = i;
}

void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, true);
}
EXPORT_SYMBOL_GPL(sk_msg_free_partial);

void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg,
				  u32 bytes)
{
	__sk_msg_free_partial(sk, msg, bytes, false);
}

void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len)
{
	int trim = msg->sg.size - len;
	u32 i = msg->sg.end;

	if (trim <= 0) {
		WARN_ON(trim < 0);
		return;
	}

	sk_msg_iter_var_prev(i);
	msg->sg.size = len;
	while (msg->sg.data[i].length &&
	       trim >= msg->sg.data[i].length) {
		trim -= msg->sg.data[i].length;
		sk_msg_free_elem(sk, msg, i, true);
		sk_msg_iter_var_prev(i);
		if (!trim)
			goto out;
	}

	msg->sg.data[i].length -= trim;
	sk_mem_uncharge(sk, trim);
out:
	/* If we trim data before the curr pointer, update copybreak and curr
	 * so that any future copy operations start at the new copy location.
	 * However, trimmed data that has not yet been used in a copy op
	 * does not require an update.
	 */
	if (msg->sg.curr >= i) {
		msg->sg.curr = i;
		msg->sg.copybreak = msg->sg.data[i].length;
	}
	sk_msg_iter_var_next(i);
	msg->sg.end = i;
}
EXPORT_SYMBOL_GPL(sk_msg_trim);

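/* Pin up to @bytes of user pages from @from and append them to the @msg
 * scatterlist ring without copying data, charging the socket for every byte
 * added. On failure the iov_iter is reverted; the caller must trim @msg
 * itself if it also needs to be cleared.
 */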
int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
			      struct sk_msg *msg, u32 bytes)
{
	int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg);
	const int to_max_pages = MAX_MSG_FRAGS;
	struct page *pages[MAX_MSG_FRAGS];
	ssize_t orig, copied, use, offset;

	orig = msg->sg.size;
	while (bytes > 0) {
		i = 0;
		maxpages = to_max_pages - num_elems;
		if (maxpages == 0) {
			ret = -EFAULT;
			goto out;
		}

		copied = iov_iter_get_pages(from, pages, bytes, maxpages,
					    &offset);
		if (copied <= 0) {
			ret = -EFAULT;
			goto out;
		}

		iov_iter_advance(from, copied);
		bytes -= copied;
		msg->sg.size += copied;

		while (copied) {
			use = min_t(int, copied, PAGE_SIZE - offset);
			sg_set_page(&msg->sg.data[msg->sg.end],
				    pages[i], use, offset);
			sg_unmark_end(&msg->sg.data[msg->sg.end]);
			sk_mem_charge(sk, use);

			offset = 0;
			copied -= use;
			sk_msg_iter_next(msg, end);
			num_elems++;
			i++;
		}
		/* When zerocopy is mixed with sk_msg_*copy* operations we
		 * may have a copybreak set; in that case clear it and prefer
		 * the zerocopy remainder when possible.
		 */
		msg->sg.copybreak = 0;
		msg->sg.curr = msg->sg.end;
	}
out:
	/* Revert iov_iter updates; the msg will need to use 'trim' later if
	 * it also needs to be cleared.
	 */
	if (ret)
		iov_iter_revert(from, msg->sg.size - orig);
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter);

int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from,
			     struct sk_msg *msg, u32 bytes)
{
	int ret = -ENOSPC, i = msg->sg.curr;
	struct scatterlist *sge;
	u32 copy, buf_size;
	void *to;

	do {
		sge = sk_msg_elem(msg, i);
		/* This is possible if a trim operation shrunk the buffer */
		if (msg->sg.copybreak >= sge->length) {
			msg->sg.copybreak = 0;
			sk_msg_iter_var_next(i);
			if (i == msg->sg.end)
				break;
			sge = sk_msg_elem(msg, i);
		}

		buf_size = sge->length - msg->sg.copybreak;
		copy = (buf_size > bytes) ? bytes : buf_size;
		to = sg_virt(sge) + msg->sg.copybreak;
		msg->sg.copybreak += copy;
		if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
			ret = copy_from_iter_nocache(to, copy, from);
		else
			ret = copy_from_iter(to, copy, from);
		if (ret != copy) {
			ret = -EFAULT;
			goto out;
		}
		bytes -= copy;
		if (!bytes)
			break;
		msg->sg.copybreak = 0;
		sk_msg_iter_var_next(i);
	} while (i != msg->sg.end);
out:
	msg->sg.curr = i;
	return ret;
}
EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter);

static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb)
{
	struct sock *sk = psock->sk;
	int copied = 0, num_sge;
	struct sk_msg *msg;

	msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC);
	if (unlikely(!msg))
		return -EAGAIN;
	if (!sk_rmem_schedule(sk, skb, skb->len)) {
		kfree(msg);
		return -EAGAIN;
	}

	sk_msg_init(msg);
	num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len);
	if (unlikely(num_sge < 0)) {
		kfree(msg);
		return num_sge;
	}

	sk_mem_charge(sk, skb->len);
	copied = skb->len;
	msg->sg.start = 0;
	msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge;
	msg->skb = skb;

	sk_psock_queue_msg(psock, msg);
	sk_psock_data_ready(sk, psock);
	return copied;
}

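/* Deliver one skb either to the local ingress queue of @psock or, in the
 * egress case, directly to the underlying socket via skb_send_sock_locked().
 * Returns the number of bytes consumed or a negative error.
 */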
static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb,
			       u32 off, u32 len, bool ingress)
{
	if (ingress)
		return sk_psock_skb_ingress(psock, skb);
	else
		return skb_send_sock_locked(psock->sk, skb, off, len);
}

static void sk_psock_backlog(struct work_struct *work)
{
	struct sk_psock *psock = container_of(work, struct sk_psock, work);
	struct sk_psock_work_state *state = &psock->work_state;
	struct sk_buff *skb;
	bool ingress;
	u32 len, off;
	int ret;

	/* Lock sock to avoid losing sk_socket during loop. */
	lock_sock(psock->sk);
	if (state->skb) {
		skb = state->skb;
		len = state->len;
		off = state->off;
		state->skb = NULL;
		goto start;
	}

	while ((skb = skb_dequeue(&psock->ingress_skb))) {
		len = skb->len;
		off = 0;
start:
		ingress = tcp_skb_bpf_ingress(skb);
		do {
			ret = -EIO;
			if (likely(psock->sk->sk_socket))
				ret = sk_psock_handle_skb(psock, skb, off,
							  len, ingress);
			if (ret <= 0) {
				if (ret == -EAGAIN) {
					state->skb = skb;
					state->len = len;
					state->off = off;
					goto end;
				}
				/* Hard errors break pipe and stop xmit. */
				sk_psock_report_error(psock, ret ? -ret : EPIPE);
				sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);
				kfree_skb(skb);
				goto end;
			}
			off += ret;
			len -= ret;
		} while (len);

		if (!ingress)
			kfree_skb(skb);
	}
end:
	release_sock(psock->sk);
}

struct sk_psock *sk_psock_init(struct sock *sk, int node)
{
	struct sk_psock *psock = kzalloc_node(sizeof(*psock),
					      GFP_ATOMIC | __GFP_NOWARN,
					      node);
	if (!psock)
		return NULL;

	psock->sk = sk;
	psock->eval = __SK_NONE;

	INIT_LIST_HEAD(&psock->link);
	spin_lock_init(&psock->link_lock);

	INIT_WORK(&psock->work, sk_psock_backlog);
	INIT_LIST_HEAD(&psock->ingress_msg);
	skb_queue_head_init(&psock->ingress_skb);

	sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED);
	refcount_set(&psock->refcnt, 1);

	rcu_assign_sk_user_data(sk, psock);
	sock_hold(sk);

	return psock;
}
EXPORT_SYMBOL_GPL(sk_psock_init);

struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock)
{
	struct sk_psock_link *link;

	spin_lock_bh(&psock->link_lock);
	link = list_first_entry_or_null(&psock->link, struct sk_psock_link,
					list);
	if (link)
		list_del(&link->list);
	spin_unlock_bh(&psock->link_lock);
	return link;
}

void __sk_psock_purge_ingress_msg(struct sk_psock *psock)
{
	struct sk_msg *msg, *tmp;

	list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) {
		list_del(&msg->list);
		sk_msg_free(psock->sk, msg);
		kfree(msg);
	}
}

static void sk_psock_zap_ingress(struct sk_psock *psock)
{
	__skb_queue_purge(&psock->ingress_skb);
	__sk_psock_purge_ingress_msg(psock);
}

static void sk_psock_link_destroy(struct sk_psock *psock)
{
	struct sk_psock_link *link, *tmp;

	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		list_del(&link->list);
		sk_psock_free_link(link);
	}
}

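/* psock teardown happens in two stages: sk_psock_destroy() is the RCU
 * callback that schedules this worker, and the worker below releases the
 * parser, queued data, links and socket references once it can sleep.
 */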
static void sk_psock_destroy_deferred(struct work_struct *gc)
{
	struct sk_psock *psock = container_of(gc, struct sk_psock, gc);

	/* No sk_callback_lock since already detached. */
	if (psock->parser.enabled)
		strp_done(&psock->parser.strp);

	cancel_work_sync(&psock->work);

	psock_progs_drop(&psock->progs);

	sk_psock_link_destroy(psock);
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);

	if (psock->sk_redir)
		sock_put(psock->sk_redir);
	sock_put(psock->sk);
	kfree(psock);
}

void sk_psock_destroy(struct rcu_head *rcu)
{
	struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu);

	INIT_WORK(&psock->gc, sk_psock_destroy_deferred);
	schedule_work(&psock->gc);
}
EXPORT_SYMBOL_GPL(sk_psock_destroy);

void sk_psock_drop(struct sock *sk, struct sk_psock *psock)
{
	rcu_assign_sk_user_data(sk, NULL);
	sk_psock_cork_free(psock);
	sk_psock_zap_ingress(psock);
	sk_psock_restore_proto(sk, psock);

	write_lock_bh(&sk->sk_callback_lock);
	if (psock->progs.skb_parser)
		sk_psock_stop_strp(sk, psock);
	write_unlock_bh(&sk->sk_callback_lock);
	sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED);

	call_rcu(&psock->rcu, sk_psock_destroy);
}
EXPORT_SYMBOL_GPL(sk_psock_drop);

static int sk_psock_map_verd(int verdict, bool redir)
{
	switch (verdict) {
	case SK_PASS:
		return redir ? __SK_REDIRECT : __SK_PASS;
	case SK_DROP:
	default:
		break;
	}

	return __SK_DROP;
}

int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock,
			 struct sk_msg *msg)
{
	struct bpf_prog *prog;
	int ret;

	preempt_disable();
	rcu_read_lock();
	prog = READ_ONCE(psock->progs.msg_parser);
	if (unlikely(!prog)) {
		ret = __SK_PASS;
		goto out;
	}

	sk_msg_compute_data_pointers(msg);
	msg->sk = sk;
	ret = BPF_PROG_RUN(prog, msg);
	ret = sk_psock_map_verd(ret, msg->sk_redir);
	psock->apply_bytes = msg->apply_bytes;
	if (ret == __SK_REDIRECT) {
		if (psock->sk_redir)
			sock_put(psock->sk_redir);
		psock->sk_redir = msg->sk_redir;
		if (!psock->sk_redir) {
			ret = __SK_DROP;
			goto out;
		}
		sock_hold(psock->sk_redir);
	}
out:
	rcu_read_unlock();
	preempt_enable();
	return ret;
}
EXPORT_SYMBOL_GPL(sk_psock_msg_verdict);

static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog,
			    struct sk_buff *skb)
{
	int ret;

	skb->sk = psock->sk;
	bpf_compute_data_end_sk_skb(skb);
	preempt_disable();
	ret = BPF_PROG_RUN(prog, skb);
	preempt_enable();
	/* strparser clones the skb before handing it to an upper layer,
	 * meaning skb_orphan has been called. We NULL sk on the way out
	 * to ensure we don't trigger a BUG_ON() in skb/sk operations
	 * later and because we are not charging the memory of this skb
	 * to any socket yet.
	 */
	skb->sk = NULL;
	return ret;
}

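/* Map a strparser instance back to its owning psock; the parser is embedded
 * in struct sk_psock, so this is just two container_of() steps.
 */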
static struct sk_psock *sk_psock_from_strp(struct strparser *strp)
{
	struct sk_psock_parser *parser;

	parser = container_of(strp, struct sk_psock_parser, strp);
	return container_of(parser, struct sk_psock, parser);
}

static void sk_psock_verdict_apply(struct sk_psock *psock,
				   struct sk_buff *skb, int verdict)
{
	struct sk_psock *psock_other;
	struct sock *sk_other;
	bool ingress;

	switch (verdict) {
	case __SK_PASS:
		sk_other = psock->sk;
		if (sock_flag(sk_other, SOCK_DEAD) ||
		    !sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) {
			goto out_free;
		}
		if (atomic_read(&sk_other->sk_rmem_alloc) <=
		    sk_other->sk_rcvbuf) {
			struct tcp_skb_cb *tcp = TCP_SKB_CB(skb);

			tcp->bpf.flags |= BPF_F_INGRESS;
			skb_queue_tail(&psock->ingress_skb, skb);
			schedule_work(&psock->work);
			break;
		}
		goto out_free;
	case __SK_REDIRECT:
		sk_other = tcp_skb_bpf_redirect_fetch(skb);
		if (unlikely(!sk_other))
			goto out_free;
		psock_other = sk_psock(sk_other);
		if (!psock_other || sock_flag(sk_other, SOCK_DEAD) ||
		    !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED))
			goto out_free;
		ingress = tcp_skb_bpf_ingress(skb);
		if ((!ingress && sock_writeable(sk_other)) ||
		    (ingress &&
		     atomic_read(&sk_other->sk_rmem_alloc) <=
		     sk_other->sk_rcvbuf)) {
			if (!ingress)
				skb_set_owner_w(skb, sk_other);
			skb_queue_tail(&psock_other->ingress_skb, skb);
			schedule_work(&psock_other->work);
			break;
		}
		/* fall-through */
	case __SK_DROP:
		/* fall-through */
	default:
out_free:
		kfree_skb(skb);
	}
}

static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb)
{
	struct sk_psock *psock = sk_psock_from_strp(strp);
	struct bpf_prog *prog;
	int ret = __SK_DROP;

	rcu_read_lock();
	prog = READ_ONCE(psock->progs.skb_verdict);
	if (likely(prog)) {
		skb_orphan(skb);
		tcp_skb_bpf_redirect_clear(skb);
		ret = sk_psock_bpf_run(psock, prog, skb);
		ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb));
	}
	rcu_read_unlock();
	sk_psock_verdict_apply(psock, skb, ret);
}

static int sk_psock_strp_read_done(struct strparser *strp, int err)
{
	return err;
}

static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb)
{
	struct sk_psock *psock = sk_psock_from_strp(strp);
	struct bpf_prog *prog;
	int ret = skb->len;

	rcu_read_lock();
	prog = READ_ONCE(psock->progs.skb_parser);
	if (likely(prog))
		ret = sk_psock_bpf_run(psock, prog, skb);
	rcu_read_unlock();
	return ret;
}

/* Called with socket lock held. */
static void sk_psock_strp_data_ready(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock)) {
		write_lock_bh(&sk->sk_callback_lock);
		strp_data_ready(&psock->parser.strp);
		write_unlock_bh(&sk->sk_callback_lock);
	}
	rcu_read_unlock();
}

static void sk_psock_write_space(struct sock *sk)
{
	struct sk_psock *psock;
	void (*write_space)(struct sock *sk) = NULL;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock)) {
		if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))
			schedule_work(&psock->work);
		write_space = psock->saved_write_space;
	}
	rcu_read_unlock();
	/* The psock may already have been dropped; only call the saved
	 * callback if we actually found one.
	 */
	if (write_space)
		write_space(sk);
}

int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock)
{
	static const struct strp_callbacks cb = {
		.rcv_msg	= sk_psock_strp_read,
		.read_sock_done	= sk_psock_strp_read_done,
		.parse_msg	= sk_psock_strp_parse,
	};

	psock->parser.enabled = false;
	return strp_init(&psock->parser.strp, sk, &cb);
}

void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (parser->enabled)
		return;

	parser->saved_data_ready = sk->sk_data_ready;
	sk->sk_data_ready = sk_psock_strp_data_ready;
	sk->sk_write_space = sk_psock_write_space;
	parser->enabled = true;
}

void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_parser *parser = &psock->parser;

	if (!parser->enabled)
		return;

	sk->sk_data_ready = parser->saved_data_ready;
	parser->saved_data_ready = NULL;
	strp_stop(&parser->strp);
	parser->enabled = false;
}