// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>
#include <net/tls.h>

static bool tcp_bpf_stream_read(const struct sock *sk)
{
	struct sk_psock *psock;
	bool empty = true;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (likely(psock))
		empty = list_empty(&psock->ingress_msg);
	rcu_read_unlock();
	return !empty;
}

static int tcp_bpf_wait_data(struct sock *sk, struct sk_psock *psock,
			     int flags, long timeo, int *err)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

int __tcp_bpf_recvmsg(struct sock *sk, struct sk_psock *psock,
		      struct msghdr *msg, int len, int flags)
{
	struct iov_iter *iter = &msg->msg_iter;
	int peek = flags & MSG_PEEK;
	int i, ret, copied = 0;
	struct sk_msg *msg_rx;

	msg_rx = list_first_entry_or_null(&psock->ingress_msg,
					  struct sk_msg, list);

	while (copied != len) {
		struct scatterlist *sge;

		if (unlikely(!msg_rx))
			break;

		i = msg_rx->sg.start;
		do {
			struct page *page;
			int copy;

			sge = sk_msg_elem(msg_rx, i);
			copy = sge->length;
			page = sg_page(sge);
			if (copied + copy > len)
				copy = len - copied;
			ret = copy_page_to_iter(page, sge->offset, copy, iter);
			if (ret != copy) {
				msg_rx->sg.start = i;
				return -EFAULT;
			}

			copied += copy;
			if (likely(!peek)) {
				sge->offset += copy;
				sge->length -= copy;
				sk_mem_uncharge(sk, copy);
				msg_rx->sg.size -= copy;

				if (!sge->length) {
					sk_msg_iter_var_next(i);
					if (!msg_rx->skb)
						put_page(page);
				}
			} else {
				sk_msg_iter_var_next(i);
			}

			if (copied == len)
				break;
		} while (i != msg_rx->sg.end);

		if (unlikely(peek)) {
			msg_rx = list_next_entry(msg_rx, list);
			continue;
		}

		msg_rx->sg.start = i;
		if (!sge->length && msg_rx->sg.start == msg_rx->sg.end) {
			list_del(&msg_rx->list);
			if (msg_rx->skb)
				consume_skb(msg_rx->skb);
			kfree(msg_rx);
		}
		msg_rx = list_first_entry_or_null(&psock->ingress_msg,
						  struct sk_msg, list);
	}

	return copied;
}
EXPORT_SYMBOL_GPL(__tcp_bpf_recvmsg);

int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
		    int nonblock, int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	lock_sock(sk);
msg_bytes_ready:
	copied = __tcp_bpf_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		int data, err = 0;
		long timeo;

		timeo = sock_rcvtimeo(sk, nonblock);
		data = tcp_bpf_wait_data(sk, psock, flags, timeo, &err);
		if (data) {
			if (skb_queue_empty(&sk->sk_receive_queue))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
		}
		if (err) {
			ret = err;
			goto out;
		}
		copied = -EAGAIN;
	}
	ret = copied;
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}

static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes)
				break;
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		msg->sg.size -= apply_bytes;
		sk_psock_queue_msg(psock, tmp);
		sk_psock_data_ready(sk, psock);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}

static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		bool has_tx_ulp;

		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		has_tx_ulp = tls_sw_has_ctx_tx(sk);
		if (has_tx_ulp) {
			flags |= MSG_SENDPAGE_NOPOLICY;
			ret = kernel_sendpage_locked(sk,
						     page, off, size, flags);
		} else {
			ret = do_tcp_sendpages(sk, page, off, size, flags);
		}

		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}

int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
			  u32 bytes, int flags)
{
	bool ingress = sk_msg_to_ingress(msg);
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock)) {
		sk_msg_free(sk, msg);
		return 0;
	}
	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = msg->sg.start == msg->sg.end;
	struct sock *sk_redir;
	u32 tosend, delta = 0;
	int ret;

more_data:
	if (psock->eval == __SK_NONE) {
		/* Track delta in msg size to add/subtract it on SK_DROP from
		 * returned to user copied size. This ensures user doesn't
		 * get a positive return code with msg_cut_data and SK_DROP
		 * verdict.
		 */
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		if (msg->sg.size < delta)
			delta -= msg->sg.size;
		else
			delta = 0;
	}

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, tosend);
		release_sock(sk);
		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);
		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= (tosend + delta);
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length)
			goto more_data;
	}
	return ret;
}

static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy,
				   msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch case where ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

static void tcp_bpf_remove(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	sk_psock_cork_free(psock);
	__sk_psock_purge_ingress_msg(psock);
	while ((link = sk_psock_link_pop(psock))) {
		sk_psock_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

static void tcp_bpf_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}

static void tcp_bpf_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	tcp_bpf_remove(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}

enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE] = *base;
	prot[TCP_BPF_BASE].unhash = tcp_bpf_unhash;
	prot[TCP_BPF_BASE].close = tcp_bpf_close;
	prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].stream_memory_read = tcp_bpf_stream_read;

	prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;
}

static void tcp_bpf_check_v6_needs_rebuild(struct sock *sk, struct proto *ops)
{
	if (sk->sk_family == AF_INET6 &&
	    unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}

static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
core_initcall(tcp_bpf_v4_build_proto);

static void tcp_bpf_update_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	sk_psock_update_proto(sk, psock, &tcp_bpf_prots[family][config]);
}

static void tcp_bpf_reinit_sk_prot(struct sock *sk, struct sk_psock *psock)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	/* Reinit occurs when program types change e.g. TCP_BPF_TX is removed
	 * or added requiring sk_prot hook updates. We keep original saved
	 * hooks in this case.
	 */
	sk->sk_prot = &tcp_bpf_prots[family][config];
}

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg == tcp_recvmsg &&
	       ops->sendmsg == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

void tcp_bpf_reinit(struct sock *sk)
{
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	tcp_bpf_reinit_sk_prot(sk, psock);
	rcu_read_unlock();
}

int tcp_bpf_init(struct sock *sk)
{
	struct proto *ops = READ_ONCE(sk->sk_prot);
	struct sk_psock *psock;

	sock_owned_by_me(sk);

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock || psock->sk_proto ||
		     tcp_bpf_assert_proto_ops(ops))) {
		rcu_read_unlock();
		return -EINVAL;
	}
	tcp_bpf_check_v6_needs_rebuild(sk, ops);
	tcp_bpf_update_sk_prot(sk, psock);
	rcu_read_unlock();
	return 0;
}