// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>
#include <net/tls.h>

static bool tcp_bpf_stream_read(const struct sock *sk)
{
	struct sk_psock *psock;
	bool empty = true;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock))
		empty = list_empty(&psock->ingress_msg);
	rcu_read_unlock();
	return !empty;
}

static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
			     int flags, long timeo, int *err)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret = 0;

	if (!timeo)
		return ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
		      struct msghdr *msg, int len, int flags)
{
	struct iov_iter *iter = &msg->msg_iter;
	int peek = flags & MSG_PEEK;
	int i, ret, copied = 0;
	struct sk_msg *msg_rx;

	msg_rx = list_first_entry_or_null(&psock->ingress_msg,
					  struct sk_msg, list);

	while (copied != len) {
		struct scatterlist *sge;

		if (unlikely(!msg_rx))
			break;

		i = msg_rx->sg.start;
		do {
			struct page *page;
			int copy;

			sge = sk_msg_elem(msg_rx, i);
			copy = sge->length;
			page = sg_page(sge);
			if (copied + copy > len)
				copy = len - copied;
			ret = copy_page_to_iter(page, sge->offset, copy, iter);
			if (ret != copy) {
				msg_rx->sg.start = i;
				return -EFAULT;
			}

			copied += copy;
			if (likely(!peek)) {
				sge->offset += copy;
				sge->length -= copy;
				sk_mem_uncharge(sk, copy);
				msg_rx->sg.size -= copy;

				if (!sge->length) {
					sk_msg_iter_var_next(i);
					if (!msg_rx->skb)
						put_page(page);
				}
			} else {
				sk_msg_iter_var_next(i);
			}

			if (copied == len)
				break;
		} while (i != msg_rx->sg.end);

		if (unlikely(peek)) {
			msg_rx = list_next_entry(msg_rx, list);
			continue;
		}

		msg_rx->sg.start = i;
		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
			list_del(&msg_rx->list);
			if (msg_rx->skb)
				consume_skb(msg_rx->skb);
			kfree(msg_rx);
		}
		msg_rx = list_first_entry_or_null(&psock->ingress_msg,
						  struct sk_msg, list);
	}

	return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);

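/* recvmsg() replacement installed via sk_prot: serve data queued on the
 * psock ingress list first and fall back to tcp_recvmsg() when no psock
 * is attached or only the regular TCP receive queue holds data.
 */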
int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
		    int nonblock, int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue) &&
	    sk_psock_queue_empty(psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	lock_sock(sk);
msg_bytes_ready:
	copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		int data, err = 0;
		long timeo;

		timeo = sock_rcvtimeo(sk, nonblock);
		data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
		if (data) {
			if (!sk_psock_queue_empty(psock))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
		}
		if (err) {
			ret = err;
			goto out;
		}
		copied = -EAGAIN;
	}
	ret = copied;
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}

static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes)
				break;
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		msg->sg.size -= apply_bytes;
		sk_psock_queue_msg(psock, tmp);
		sk_psock_data_ready(sk, psock);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}

static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		bool has_tx_ulp;

		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		has_tx_ulp = tls_sw_has_ctx_tx(sk);
		if (has_tx_ulp) {
			flags |= MSG_SENDPAGE_NOPOLICY;
			ret = kernel_sendpage_locked(sk,
						     page, off, size, flags);
		} else {
			ret = do_tcp_sendpages(sk, page, off, size, flags);
		}

		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}

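/* Redirect target for an SK_MSG verdict: data flagged for ingress is queued
 * on the destination psock via bpf_tcp_ingress(), everything else is pushed
 * out through the destination socket's transmit path.
 */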
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
			  u32 bytes, int flags)
{
	bool ingress = sk_msg_to_ingress(msg);
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock)) {
		sk_msg_free(sk, msg);
		return 0;
	}
	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = sk_msg_full(msg);
	struct sock *sk_redir;
	u32 tosend, delta = 0;
	int ret;

more_data:
	if (psock->eval == __SK_NONE) {
		/* Track delta in msg size to add/subtract it on SK_DROP from
		 * returned to user copied size. This ensures user doesn't
		 * get a positive return code with msg_cut_data and SK_DROP
		 * verdict.
		 */
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		delta -= msg->sg.size;
	}

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, tosend);
		release_sock(sk);
		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= (tosend + delta);
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length)
			goto more_data;
	}
	return ret;
}

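/* sendmsg() replacement: copy user data into an sk_msg (or a pending cork
 * buffer), honour any cork bytes requested by the BPF program, then hand
 * the message to tcp_bpf_send_verdict() for the PASS/REDIRECT/DROP decision.
 */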
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;
	int flags;

	/* Don't let internal do_tcp_sendpages() flags through */
	flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
	flags |= MSG_NO_SHARED_FRAGS;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy,
				   msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

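/* sendpage() replacement: add the page to an sk_msg (or to a pending cork
 * buffer) and run it through the same verdict path as tcp_bpf_sendmsg().
 */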
static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch case where ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	while ((link = sk_psock_link_pop(psock))) {
		sk_psock_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

static void tcp_bpf_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}

static void tcp_bpf_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}

enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE] = *base;
	prot[TCP_BPF_BASE].unhash = tcp_bpf_unhash;
	prot[TCP_BPF_BASE].close = tcp_bpf_close;
	prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read;

	prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;
}

static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
{
	if (sk->sk_family == AF_INET6 &&
	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}

static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
core_initcall(tcp_bpf_v4_build_proto);

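/* Point sk->sk_prot at the tcp_bpf proto variant matching the socket family
 * and the attached programs: TX callbacks are only installed when a
 * msg_parser program is present.
 */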
static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
}

static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	/* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
	 * or added requiring sk_prot hook updates. We keep original saved
	 * hooks in this case.
	 */
	sk->sk_prot = &tcp_bpf_prots[family][config];
}

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg == tcp_recvmsg &&
	       ops->sendmsg == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

void tcp_bpf_reinit(struct sock *sk)
{
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	tcp_bpf_reinit_sk_prot(sk, psock);
	rcu_read_unlock();
}

int tcp_bpf_init(struct sock *sk)
{
	struct proto *ops = READ_ONCE(sk->sk_prot);
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock || psock->sk_proto ||
		     tcp_bpf_assert_proto_ops(ops))) {
		rcu_read_unlock();
		return -EINVAL;
	}
	tcp_bpf_check_v6_needs_rebuild(sk, ops);
	tcp_bpf_update_sk_prot(sk, psock);
	rcu_read_unlock();
	return 0;
}