1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */ 3 4 #include <linux/skmsg.h> 5 #include <linux/skbuff.h> 6 #include <linux/scatterlist.h> 7 8 #include <net/sock.h> 9 #include <net/tcp.h> 10 11 static bool sk_msg_try_coalesce_ok(struct sk_msg *msg, int elem_first_coalesce) 12 { 13 if (msg->sg.end > msg->sg.start && 14 elem_first_coalesce < msg->sg.end) 15 return true; 16 17 if (msg->sg.end < msg->sg.start && 18 (elem_first_coalesce > msg->sg.start || 19 elem_first_coalesce < msg->sg.end)) 20 return true; 21 22 return false; 23 } 24 25 int sk_msg_alloc(struct sock *sk, struct sk_msg *msg, int len, 26 int elem_first_coalesce) 27 { 28 struct page_frag *pfrag = sk_page_frag(sk); 29 int ret = 0; 30 31 len -= msg->sg.size; 32 while (len > 0) { 33 struct scatterlist *sge; 34 u32 orig_offset; 35 int use, i; 36 37 if (!sk_page_frag_refill(sk, pfrag)) 38 return -ENOMEM; 39 40 orig_offset = pfrag->offset; 41 use = min_t(int, len, pfrag->size - orig_offset); 42 if (!sk_wmem_schedule(sk, use)) 43 return -ENOMEM; 44 45 i = msg->sg.end; 46 sk_msg_iter_var_prev(i); 47 sge = &msg->sg.data[i]; 48 49 if (sk_msg_try_coalesce_ok(msg, elem_first_coalesce) && 50 sg_page(sge) == pfrag->page && 51 sge->offset + sge->length == orig_offset) { 52 sge->length += use; 53 } else { 54 if (sk_msg_full(msg)) { 55 ret = -ENOSPC; 56 break; 57 } 58 59 sge = &msg->sg.data[msg->sg.end]; 60 sg_unmark_end(sge); 61 sg_set_page(sge, pfrag->page, use, orig_offset); 62 get_page(pfrag->page); 63 sk_msg_iter_next(msg, end); 64 } 65 66 sk_mem_charge(sk, use); 67 msg->sg.size += use; 68 pfrag->offset += use; 69 len -= use; 70 } 71 72 return ret; 73 } 74 EXPORT_SYMBOL_GPL(sk_msg_alloc); 75 76 int sk_msg_clone(struct sock *sk, struct sk_msg *dst, struct sk_msg *src, 77 u32 off, u32 len) 78 { 79 int i = src->sg.start; 80 struct scatterlist *sge = sk_msg_elem(src, i); 81 u32 sge_len, sge_off; 82 83 if (sk_msg_full(dst)) 84 return -ENOSPC; 85 86 while (off) { 87 if (sge->length > off) 88 break; 89 off -= sge->length; 90 sk_msg_iter_var_next(i); 91 if (i == src->sg.end && off) 92 return -ENOSPC; 93 sge = sk_msg_elem(src, i); 94 } 95 96 while (len) { 97 sge_len = sge->length - off; 98 sge_off = sge->offset + off; 99 if (sge_len > len) 100 sge_len = len; 101 off = 0; 102 len -= sge_len; 103 sk_msg_page_add(dst, sg_page(sge), sge_len, sge_off); 104 sk_mem_charge(sk, sge_len); 105 sk_msg_iter_var_next(i); 106 if (i == src->sg.end && len) 107 return -ENOSPC; 108 sge = sk_msg_elem(src, i); 109 } 110 111 return 0; 112 } 113 EXPORT_SYMBOL_GPL(sk_msg_clone); 114 115 void sk_msg_return_zero(struct sock *sk, struct sk_msg *msg, int bytes) 116 { 117 int i = msg->sg.start; 118 119 do { 120 struct scatterlist *sge = sk_msg_elem(msg, i); 121 122 if (bytes < sge->length) { 123 sge->length -= bytes; 124 sge->offset += bytes; 125 sk_mem_uncharge(sk, bytes); 126 break; 127 } 128 129 sk_mem_uncharge(sk, sge->length); 130 bytes -= sge->length; 131 sge->length = 0; 132 sge->offset = 0; 133 sk_msg_iter_var_next(i); 134 } while (bytes && i != msg->sg.end); 135 msg->sg.start = i; 136 } 137 EXPORT_SYMBOL_GPL(sk_msg_return_zero); 138 139 void sk_msg_return(struct sock *sk, struct sk_msg *msg, int bytes) 140 { 141 int i = msg->sg.start; 142 143 do { 144 struct scatterlist *sge = &msg->sg.data[i]; 145 int uncharge = (bytes < sge->length) ? bytes : sge->length; 146 147 sk_mem_uncharge(sk, uncharge); 148 bytes -= uncharge; 149 sk_msg_iter_var_next(i); 150 } while (i != msg->sg.end); 151 } 152 EXPORT_SYMBOL_GPL(sk_msg_return); 153 154 static int sk_msg_free_elem(struct sock *sk, struct sk_msg *msg, u32 i, 155 bool charge) 156 { 157 struct scatterlist *sge = sk_msg_elem(msg, i); 158 u32 len = sge->length; 159 160 if (charge) 161 sk_mem_uncharge(sk, len); 162 if (!msg->skb) 163 put_page(sg_page(sge)); 164 memset(sge, 0, sizeof(*sge)); 165 return len; 166 } 167 168 static int __sk_msg_free(struct sock *sk, struct sk_msg *msg, u32 i, 169 bool charge) 170 { 171 struct scatterlist *sge = sk_msg_elem(msg, i); 172 int freed = 0; 173 174 while (msg->sg.size) { 175 msg->sg.size -= sge->length; 176 freed += sk_msg_free_elem(sk, msg, i, charge); 177 sk_msg_iter_var_next(i); 178 sk_msg_check_to_free(msg, i, msg->sg.size); 179 sge = sk_msg_elem(msg, i); 180 } 181 if (msg->skb) 182 consume_skb(msg->skb); 183 sk_msg_init(msg); 184 return freed; 185 } 186 187 int sk_msg_free_nocharge(struct sock *sk, struct sk_msg *msg) 188 { 189 return __sk_msg_free(sk, msg, msg->sg.start, false); 190 } 191 EXPORT_SYMBOL_GPL(sk_msg_free_nocharge); 192 193 int sk_msg_free(struct sock *sk, struct sk_msg *msg) 194 { 195 return __sk_msg_free(sk, msg, msg->sg.start, true); 196 } 197 EXPORT_SYMBOL_GPL(sk_msg_free); 198 199 static void __sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, 200 u32 bytes, bool charge) 201 { 202 struct scatterlist *sge; 203 u32 i = msg->sg.start; 204 205 while (bytes) { 206 sge = sk_msg_elem(msg, i); 207 if (!sge->length) 208 break; 209 if (bytes < sge->length) { 210 if (charge) 211 sk_mem_uncharge(sk, bytes); 212 sge->length -= bytes; 213 sge->offset += bytes; 214 msg->sg.size -= bytes; 215 break; 216 } 217 218 msg->sg.size -= sge->length; 219 bytes -= sge->length; 220 sk_msg_free_elem(sk, msg, i, charge); 221 sk_msg_iter_var_next(i); 222 sk_msg_check_to_free(msg, i, bytes); 223 } 224 msg->sg.start = i; 225 } 226 227 void sk_msg_free_partial(struct sock *sk, struct sk_msg *msg, u32 bytes) 228 { 229 __sk_msg_free_partial(sk, msg, bytes, true); 230 } 231 EXPORT_SYMBOL_GPL(sk_msg_free_partial); 232 233 void sk_msg_free_partial_nocharge(struct sock *sk, struct sk_msg *msg, 234 u32 bytes) 235 { 236 __sk_msg_free_partial(sk, msg, bytes, false); 237 } 238 239 void sk_msg_trim(struct sock *sk, struct sk_msg *msg, int len) 240 { 241 int trim = msg->sg.size - len; 242 u32 i = msg->sg.end; 243 244 if (trim <= 0) { 245 WARN_ON(trim < 0); 246 return; 247 } 248 249 sk_msg_iter_var_prev(i); 250 msg->sg.size = len; 251 while (msg->sg.data[i].length && 252 trim >= msg->sg.data[i].length) { 253 trim -= msg->sg.data[i].length; 254 sk_msg_free_elem(sk, msg, i, true); 255 sk_msg_iter_var_prev(i); 256 if (!trim) 257 goto out; 258 } 259 260 msg->sg.data[i].length -= trim; 261 sk_mem_uncharge(sk, trim); 262 out: 263 /* If we trim data before curr pointer update copybreak and current 264 * so that any future copy operations start at new copy location. 265 * However trimed data that has not yet been used in a copy op 266 * does not require an update. 267 */ 268 if (msg->sg.curr >= i) { 269 msg->sg.curr = i; 270 msg->sg.copybreak = msg->sg.data[i].length; 271 } 272 sk_msg_iter_var_next(i); 273 msg->sg.end = i; 274 } 275 EXPORT_SYMBOL_GPL(sk_msg_trim); 276 277 int sk_msg_zerocopy_from_iter(struct sock *sk, struct iov_iter *from, 278 struct sk_msg *msg, u32 bytes) 279 { 280 int i, maxpages, ret = 0, num_elems = sk_msg_elem_used(msg); 281 const int to_max_pages = MAX_MSG_FRAGS; 282 struct page *pages[MAX_MSG_FRAGS]; 283 ssize_t orig, copied, use, offset; 284 285 orig = msg->sg.size; 286 while (bytes > 0) { 287 i = 0; 288 maxpages = to_max_pages - num_elems; 289 if (maxpages == 0) { 290 ret = -EFAULT; 291 goto out; 292 } 293 294 copied = iov_iter_get_pages(from, pages, bytes, maxpages, 295 &offset); 296 if (copied <= 0) { 297 ret = -EFAULT; 298 goto out; 299 } 300 301 iov_iter_advance(from, copied); 302 bytes -= copied; 303 msg->sg.size += copied; 304 305 while (copied) { 306 use = min_t(int, copied, PAGE_SIZE - offset); 307 sg_set_page(&msg->sg.data[msg->sg.end], 308 pages[i], use, offset); 309 sg_unmark_end(&msg->sg.data[msg->sg.end]); 310 sk_mem_charge(sk, use); 311 312 offset = 0; 313 copied -= use; 314 sk_msg_iter_next(msg, end); 315 num_elems++; 316 i++; 317 } 318 /* When zerocopy is mixed with sk_msg_*copy* operations we 319 * may have a copybreak set in this case clear and prefer 320 * zerocopy remainder when possible. 321 */ 322 msg->sg.copybreak = 0; 323 msg->sg.curr = msg->sg.end; 324 } 325 out: 326 /* Revert iov_iter updates, msg will need to use 'trim' later if it 327 * also needs to be cleared. 328 */ 329 if (ret) 330 iov_iter_revert(from, msg->sg.size - orig); 331 return ret; 332 } 333 EXPORT_SYMBOL_GPL(sk_msg_zerocopy_from_iter); 334 335 int sk_msg_memcopy_from_iter(struct sock *sk, struct iov_iter *from, 336 struct sk_msg *msg, u32 bytes) 337 { 338 int ret = -ENOSPC, i = msg->sg.curr; 339 struct scatterlist *sge; 340 u32 copy, buf_size; 341 void *to; 342 343 do { 344 sge = sk_msg_elem(msg, i); 345 /* This is possible if a trim operation shrunk the buffer */ 346 if (msg->sg.copybreak >= sge->length) { 347 msg->sg.copybreak = 0; 348 sk_msg_iter_var_next(i); 349 if (i == msg->sg.end) 350 break; 351 sge = sk_msg_elem(msg, i); 352 } 353 354 buf_size = sge->length - msg->sg.copybreak; 355 copy = (buf_size > bytes) ? bytes : buf_size; 356 to = sg_virt(sge) + msg->sg.copybreak; 357 msg->sg.copybreak += copy; 358 if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY) 359 ret = copy_from_iter_nocache(to, copy, from); 360 else 361 ret = copy_from_iter(to, copy, from); 362 if (ret != copy) { 363 ret = -EFAULT; 364 goto out; 365 } 366 bytes -= copy; 367 if (!bytes) 368 break; 369 msg->sg.copybreak = 0; 370 sk_msg_iter_var_next(i); 371 } while (i != msg->sg.end); 372 out: 373 msg->sg.curr = i; 374 return ret; 375 } 376 EXPORT_SYMBOL_GPL(sk_msg_memcopy_from_iter); 377 378 static int sk_psock_skb_ingress(struct sk_psock *psock, struct sk_buff *skb) 379 { 380 struct sock *sk = psock->sk; 381 int copied = 0, num_sge; 382 struct sk_msg *msg; 383 384 msg = kzalloc(sizeof(*msg), __GFP_NOWARN | GFP_ATOMIC); 385 if (unlikely(!msg)) 386 return -EAGAIN; 387 if (!sk_rmem_schedule(sk, skb, skb->len)) { 388 kfree(msg); 389 return -EAGAIN; 390 } 391 392 sk_msg_init(msg); 393 num_sge = skb_to_sgvec(skb, msg->sg.data, 0, skb->len); 394 if (unlikely(num_sge < 0)) { 395 kfree(msg); 396 return num_sge; 397 } 398 399 sk_mem_charge(sk, skb->len); 400 copied = skb->len; 401 msg->sg.start = 0; 402 msg->sg.end = num_sge == MAX_MSG_FRAGS ? 0 : num_sge; 403 msg->skb = skb; 404 405 sk_psock_queue_msg(psock, msg); 406 sk->sk_data_ready(sk); 407 return copied; 408 } 409 410 static int sk_psock_handle_skb(struct sk_psock *psock, struct sk_buff *skb, 411 u32 off, u32 len, bool ingress) 412 { 413 if (ingress) 414 return sk_psock_skb_ingress(psock, skb); 415 else 416 return skb_send_sock_locked(psock->sk, skb, off, len); 417 } 418 419 static void sk_psock_backlog(struct work_struct *work) 420 { 421 struct sk_psock *psock = container_of(work, struct sk_psock, work); 422 struct sk_psock_work_state *state = &psock->work_state; 423 struct sk_buff *skb; 424 bool ingress; 425 u32 len, off; 426 int ret; 427 428 /* Lock sock to avoid losing sk_socket during loop. */ 429 lock_sock(psock->sk); 430 if (state->skb) { 431 skb = state->skb; 432 len = state->len; 433 off = state->off; 434 state->skb = NULL; 435 goto start; 436 } 437 438 while ((skb = skb_dequeue(&psock->ingress_skb))) { 439 len = skb->len; 440 off = 0; 441 start: 442 ingress = tcp_skb_bpf_ingress(skb); 443 do { 444 ret = -EIO; 445 if (likely(psock->sk->sk_socket)) 446 ret = sk_psock_handle_skb(psock, skb, off, 447 len, ingress); 448 if (ret <= 0) { 449 if (ret == -EAGAIN) { 450 state->skb = skb; 451 state->len = len; 452 state->off = off; 453 goto end; 454 } 455 /* Hard errors break pipe and stop xmit. */ 456 sk_psock_report_error(psock, ret ? -ret : EPIPE); 457 sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED); 458 kfree_skb(skb); 459 goto end; 460 } 461 off += ret; 462 len -= ret; 463 } while (len); 464 465 if (!ingress) 466 kfree_skb(skb); 467 } 468 end: 469 release_sock(psock->sk); 470 } 471 472 struct sk_psock *sk_psock_init(struct sock *sk, int node) 473 { 474 struct sk_psock *psock = kzalloc_node(sizeof(*psock), 475 GFP_ATOMIC | __GFP_NOWARN, 476 node); 477 if (!psock) 478 return NULL; 479 480 psock->sk = sk; 481 psock->eval = __SK_NONE; 482 483 INIT_LIST_HEAD(&psock->link); 484 spin_lock_init(&psock->link_lock); 485 486 INIT_WORK(&psock->work, sk_psock_backlog); 487 INIT_LIST_HEAD(&psock->ingress_msg); 488 skb_queue_head_init(&psock->ingress_skb); 489 490 sk_psock_set_state(psock, SK_PSOCK_TX_ENABLED); 491 refcount_set(&psock->refcnt, 1); 492 493 rcu_assign_sk_user_data(sk, psock); 494 sock_hold(sk); 495 496 return psock; 497 } 498 EXPORT_SYMBOL_GPL(sk_psock_init); 499 500 struct sk_psock_link *sk_psock_link_pop(struct sk_psock *psock) 501 { 502 struct sk_psock_link *link; 503 504 spin_lock_bh(&psock->link_lock); 505 link = list_first_entry_or_null(&psock->link, struct sk_psock_link, 506 list); 507 if (link) 508 list_del(&link->list); 509 spin_unlock_bh(&psock->link_lock); 510 return link; 511 } 512 513 void __sk_psock_purge_ingress_msg(struct sk_psock *psock) 514 { 515 struct sk_msg *msg, *tmp; 516 517 list_for_each_entry_safe(msg, tmp, &psock->ingress_msg, list) { 518 list_del(&msg->list); 519 sk_msg_free(psock->sk, msg); 520 kfree(msg); 521 } 522 } 523 524 static void sk_psock_zap_ingress(struct sk_psock *psock) 525 { 526 __skb_queue_purge(&psock->ingress_skb); 527 __sk_psock_purge_ingress_msg(psock); 528 } 529 530 static void sk_psock_link_destroy(struct sk_psock *psock) 531 { 532 struct sk_psock_link *link, *tmp; 533 534 list_for_each_entry_safe(link, tmp, &psock->link, list) { 535 list_del(&link->list); 536 sk_psock_free_link(link); 537 } 538 } 539 540 static void sk_psock_destroy_deferred(struct work_struct *gc) 541 { 542 struct sk_psock *psock = container_of(gc, struct sk_psock, gc); 543 544 /* No sk_callback_lock since already detached. */ 545 if (psock->parser.enabled) 546 strp_done(&psock->parser.strp); 547 548 cancel_work_sync(&psock->work); 549 550 psock_progs_drop(&psock->progs); 551 552 sk_psock_link_destroy(psock); 553 sk_psock_cork_free(psock); 554 sk_psock_zap_ingress(psock); 555 556 if (psock->sk_redir) 557 sock_put(psock->sk_redir); 558 sock_put(psock->sk); 559 kfree(psock); 560 } 561 562 void sk_psock_destroy(struct rcu_head *rcu) 563 { 564 struct sk_psock *psock = container_of(rcu, struct sk_psock, rcu); 565 566 INIT_WORK(&psock->gc, sk_psock_destroy_deferred); 567 schedule_work(&psock->gc); 568 } 569 EXPORT_SYMBOL_GPL(sk_psock_destroy); 570 571 void sk_psock_drop(struct sock *sk, struct sk_psock *psock) 572 { 573 rcu_assign_sk_user_data(sk, NULL); 574 sk_psock_cork_free(psock); 575 sk_psock_restore_proto(sk, psock); 576 577 write_lock_bh(&sk->sk_callback_lock); 578 if (psock->progs.skb_parser) 579 sk_psock_stop_strp(sk, psock); 580 write_unlock_bh(&sk->sk_callback_lock); 581 sk_psock_clear_state(psock, SK_PSOCK_TX_ENABLED); 582 583 call_rcu_sched(&psock->rcu, sk_psock_destroy); 584 } 585 EXPORT_SYMBOL_GPL(sk_psock_drop); 586 587 static int sk_psock_map_verd(int verdict, bool redir) 588 { 589 switch (verdict) { 590 case SK_PASS: 591 return redir ? __SK_REDIRECT : __SK_PASS; 592 case SK_DROP: 593 default: 594 break; 595 } 596 597 return __SK_DROP; 598 } 599 600 int sk_psock_msg_verdict(struct sock *sk, struct sk_psock *psock, 601 struct sk_msg *msg) 602 { 603 struct bpf_prog *prog; 604 int ret; 605 606 preempt_disable(); 607 rcu_read_lock(); 608 prog = READ_ONCE(psock->progs.msg_parser); 609 if (unlikely(!prog)) { 610 ret = __SK_PASS; 611 goto out; 612 } 613 614 sk_msg_compute_data_pointers(msg); 615 msg->sk = sk; 616 ret = BPF_PROG_RUN(prog, msg); 617 ret = sk_psock_map_verd(ret, msg->sk_redir); 618 psock->apply_bytes = msg->apply_bytes; 619 if (ret == __SK_REDIRECT) { 620 if (psock->sk_redir) 621 sock_put(psock->sk_redir); 622 psock->sk_redir = msg->sk_redir; 623 if (!psock->sk_redir) { 624 ret = __SK_DROP; 625 goto out; 626 } 627 sock_hold(psock->sk_redir); 628 } 629 out: 630 rcu_read_unlock(); 631 preempt_enable(); 632 return ret; 633 } 634 EXPORT_SYMBOL_GPL(sk_psock_msg_verdict); 635 636 static int sk_psock_bpf_run(struct sk_psock *psock, struct bpf_prog *prog, 637 struct sk_buff *skb) 638 { 639 int ret; 640 641 skb->sk = psock->sk; 642 bpf_compute_data_end_sk_skb(skb); 643 preempt_disable(); 644 ret = BPF_PROG_RUN(prog, skb); 645 preempt_enable(); 646 /* strparser clones the skb before handing it to a upper layer, 647 * meaning skb_orphan has been called. We NULL sk on the way out 648 * to ensure we don't trigger a BUG_ON() in skb/sk operations 649 * later and because we are not charging the memory of this skb 650 * to any socket yet. 651 */ 652 skb->sk = NULL; 653 return ret; 654 } 655 656 static struct sk_psock *sk_psock_from_strp(struct strparser *strp) 657 { 658 struct sk_psock_parser *parser; 659 660 parser = container_of(strp, struct sk_psock_parser, strp); 661 return container_of(parser, struct sk_psock, parser); 662 } 663 664 static void sk_psock_verdict_apply(struct sk_psock *psock, 665 struct sk_buff *skb, int verdict) 666 { 667 struct sk_psock *psock_other; 668 struct sock *sk_other; 669 bool ingress; 670 671 switch (verdict) { 672 case __SK_REDIRECT: 673 sk_other = tcp_skb_bpf_redirect_fetch(skb); 674 if (unlikely(!sk_other)) 675 goto out_free; 676 psock_other = sk_psock(sk_other); 677 if (!psock_other || sock_flag(sk_other, SOCK_DEAD) || 678 !sk_psock_test_state(psock_other, SK_PSOCK_TX_ENABLED)) 679 goto out_free; 680 ingress = tcp_skb_bpf_ingress(skb); 681 if ((!ingress && sock_writeable(sk_other)) || 682 (ingress && 683 atomic_read(&sk_other->sk_rmem_alloc) <= 684 sk_other->sk_rcvbuf)) { 685 if (!ingress) 686 skb_set_owner_w(skb, sk_other); 687 skb_queue_tail(&psock_other->ingress_skb, skb); 688 schedule_work(&psock_other->work); 689 break; 690 } 691 /* fall-through */ 692 case __SK_DROP: 693 /* fall-through */ 694 default: 695 out_free: 696 kfree_skb(skb); 697 } 698 } 699 700 static void sk_psock_strp_read(struct strparser *strp, struct sk_buff *skb) 701 { 702 struct sk_psock *psock = sk_psock_from_strp(strp); 703 struct bpf_prog *prog; 704 int ret = __SK_DROP; 705 706 rcu_read_lock(); 707 prog = READ_ONCE(psock->progs.skb_verdict); 708 if (likely(prog)) { 709 skb_orphan(skb); 710 tcp_skb_bpf_redirect_clear(skb); 711 ret = sk_psock_bpf_run(psock, prog, skb); 712 ret = sk_psock_map_verd(ret, tcp_skb_bpf_redirect_fetch(skb)); 713 } 714 rcu_read_unlock(); 715 sk_psock_verdict_apply(psock, skb, ret); 716 } 717 718 static int sk_psock_strp_read_done(struct strparser *strp, int err) 719 { 720 return err; 721 } 722 723 static int sk_psock_strp_parse(struct strparser *strp, struct sk_buff *skb) 724 { 725 struct sk_psock *psock = sk_psock_from_strp(strp); 726 struct bpf_prog *prog; 727 int ret = skb->len; 728 729 rcu_read_lock(); 730 prog = READ_ONCE(psock->progs.skb_parser); 731 if (likely(prog)) 732 ret = sk_psock_bpf_run(psock, prog, skb); 733 rcu_read_unlock(); 734 return ret; 735 } 736 737 /* Called with socket lock held. */ 738 static void sk_psock_data_ready(struct sock *sk) 739 { 740 struct sk_psock *psock; 741 742 rcu_read_lock(); 743 psock = sk_psock(sk); 744 if (likely(psock)) { 745 write_lock_bh(&sk->sk_callback_lock); 746 strp_data_ready(&psock->parser.strp); 747 write_unlock_bh(&sk->sk_callback_lock); 748 } 749 rcu_read_unlock(); 750 } 751 752 static void sk_psock_write_space(struct sock *sk) 753 { 754 struct sk_psock *psock; 755 void (*write_space)(struct sock *sk); 756 757 rcu_read_lock(); 758 psock = sk_psock(sk); 759 if (likely(psock && sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED))) 760 schedule_work(&psock->work); 761 write_space = psock->saved_write_space; 762 rcu_read_unlock(); 763 write_space(sk); 764 } 765 766 int sk_psock_init_strp(struct sock *sk, struct sk_psock *psock) 767 { 768 static const struct strp_callbacks cb = { 769 .rcv_msg = sk_psock_strp_read, 770 .read_sock_done = sk_psock_strp_read_done, 771 .parse_msg = sk_psock_strp_parse, 772 }; 773 774 psock->parser.enabled = false; 775 return strp_init(&psock->parser.strp, sk, &cb); 776 } 777 778 void sk_psock_start_strp(struct sock *sk, struct sk_psock *psock) 779 { 780 struct sk_psock_parser *parser = &psock->parser; 781 782 if (parser->enabled) 783 return; 784 785 parser->saved_data_ready = sk->sk_data_ready; 786 sk->sk_data_ready = sk_psock_data_ready; 787 sk->sk_write_space = sk_psock_write_space; 788 parser->enabled = true; 789 } 790 791 void sk_psock_stop_strp(struct sock *sk, struct sk_psock *psock) 792 { 793 struct sk_psock_parser *parser = &psock->parser; 794 795 if (!parser->enabled) 796 return; 797 798 sk->sk_data_ready = parser->saved_data_ready; 799 parser->saved_data_ready = NULL; 800 strp_stop(&parser->strp); 801 parser->enabled = false; 802 } 803