// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>
#include <linux/util_macros.h>

#include <net/inet_common.h>
#include <net/tls.h>

static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes) {
				if (sge->length)
					sk_msg_iter_var_prev(i);
				break;
			}
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		sk_psock_queue_msg(psock, tmp);
		sk_psock_data_ready(sk, psock);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}

static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		bool has_tx_ulp;

		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		has_tx_ulp = tls_sw_has_ctx_tx(sk);
		if (has_tx_ulp) {
			flags |= MSG_SENDPAGE_NOPOLICY;
			ret = kernel_sendpage_locked(sk,
						     page, off, size, flags);
		} else {
			ret = do_tcp_sendpages(sk, page, off, size, flags);
		}

		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}

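/* Redirect @bytes from @msg to @sk: ingress redirects queue the data on
 * @sk's psock via bpf_tcp_ingress(), egress redirects push it out through
 * tcp_bpf_push_locked(). Returns -EPIPE if @sk has no psock attached.
 */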
int tcp_bpf_sendmsg_redir(struct sock *sk, bool ingress,
			  struct sk_msg *msg, u32 bytes, int flags)
{
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock))
		return -EPIPE;

	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

#ifdef CONFIG_BPF_SYSCALL
static int tcp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
			     long timeo)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret = 0;

	if (sk->sk_shutdown & RCV_SHUTDOWN)
		return 1;

	if (!timeo)
		return ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty_lockless(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

static bool is_next_msg_fin(struct sk_psock *psock)
{
	struct scatterlist *sge;
	struct sk_msg *msg_rx;
	int i;

	msg_rx = sk_psock_peek_msg(psock);
	i = msg_rx->sg.start;
	sge = sk_msg_elem(msg_rx, i);
	if (!sge->length) {
		struct sk_buff *skb = msg_rx->skb;

		if (skb && TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN)
			return true;
	}
	return false;
}

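/* recvmsg handler installed when a stream or skb verdict program is
 * attached: data queued on the psock's ingress queue is copied to user
 * space via sk_msg_recvmsg(). Falls back to tcp_recvmsg() when no psock
 * is attached.
 */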
static int tcp_bpf_recvmsg_parser(struct sock *sk,
				  struct msghdr *msg,
				  size_t len,
				  int flags,
				  int *addr_len)
{
	struct sk_psock *psock;
	int copied;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	if (!len)
		return 0;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, flags, addr_len);

	lock_sock(sk);

	/* We may have received data on the sk_receive_queue pre-accept and
	 * then we can not use read_skb in this context because we haven't
	 * assigned a sk_socket yet so have no link to the ops. The work-around
	 * is to check the sk_receive_queue and in these cases read skbs off
	 * queue again. The read_skb hook is not running at this point because
	 * of lock_sock so we avoid having multiple runners in read_skb.
	 */
	if (unlikely(!skb_queue_empty(&sk->sk_receive_queue))) {
		tcp_data_ready(sk);
		/* This handles the ENOMEM errors if we both receive data
		 * pre accept and are already under memory pressure. At least
		 * let user know to retry.
		 */
		if (unlikely(!skb_queue_empty(&sk->sk_receive_queue))) {
			copied = -EAGAIN;
			goto out;
		}
	}

msg_bytes_ready:
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	/* The typical case for EFAULT is the socket was gracefully
	 * shutdown with a FIN pkt. So check here the other case is
	 * some error on copy_page_to_iter which would be unexpected.
	 * On fin return correct return code to zero.
	 */
	if (copied == -EFAULT) {
		bool is_fin = is_next_msg_fin(psock);

		if (is_fin) {
			copied = 0;
			goto out;
		}
	}
	if (!copied) {
		long timeo;
		int data;

		if (sock_flag(sk, SOCK_DONE))
			goto out;

		if (sk->sk_err) {
			copied = sock_error(sk);
			goto out;
		}

		if (sk->sk_shutdown & RCV_SHUTDOWN)
			goto out;

		if (sk->sk_state == TCP_CLOSE) {
			copied = -ENOTCONN;
			goto out;
		}

		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
		if (!timeo) {
			copied = -EAGAIN;
			goto out;
		}

		if (signal_pending(current)) {
			copied = sock_intr_errno(timeo);
			goto out;
		}

		data = tcp_msg_wait_data(sk, psock, timeo);
		if (data && !sk_psock_queue_empty(psock))
			goto msg_bytes_ready;
		copied = -EAGAIN;
	}
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied;
}

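/* recvmsg handler for the base and TX proto configurations. Consumes data
 * queued on the psock via sk_msg_recvmsg() and falls back to tcp_recvmsg()
 * when no psock is attached or only the regular receive queue has data.
 */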
static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			   int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	if (!len)
		return 0;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, flags, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue) &&
	    sk_psock_queue_empty(psock)) {
		sk_psock_put(sk, psock);
		return tcp_recvmsg(sk, msg, len, flags, addr_len);
	}
	lock_sock(sk);
msg_bytes_ready:
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		long timeo;
		int data;

		timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
		data = tcp_msg_wait_data(sk, psock, timeo);
		if (data) {
			if (!sk_psock_queue_empty(psock))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, flags, addr_len);
		}
		copied = -EAGAIN;
	}
	ret = copied;
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}

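/* Apply the msg verdict program's decision to @msg: __SK_PASS pushes the
 * data out on @sk, __SK_REDIRECT hands it to the chosen socket via
 * tcp_bpf_sendmsg_redir(), and __SK_DROP (or an unknown verdict) frees it
 * and returns -EACCES. Corking and apply_bytes state on the psock are
 * taken into account before the verdict is acted on.
 */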
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = sk_msg_full(msg), redir_ingress;
	struct sock *sk_redir;
	u32 tosend, origsize, sent, delta = 0;
	u32 eval;
	int ret;

more_data:
	if (psock->eval == __SK_NONE) {
		/* Track delta in msg size to add/subtract it on SK_DROP from
		 * returned to user copied size. This ensures user doesn't
		 * get a positive return code with msg_cut_data and SK_DROP
		 * verdict.
		 */
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		delta -= msg->sg.size;
	}

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;
	eval = __SK_NONE;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		redir_ingress = psock->redir_ingress;
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (!psock->apply_bytes) {
			/* Clean up before releasing the sock lock. */
			eval = psock->eval;
			psock->eval = __SK_NONE;
			psock->sk_redir = NULL;
		}
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, tosend);
		release_sock(sk);

		origsize = msg->sg.size;
		ret = tcp_bpf_sendmsg_redir(sk_redir, redir_ingress,
					    msg, tosend, flags);
		sent = origsize - msg->sg.size;

		if (eval == __SK_REDIRECT)
			sock_put(sk_redir);

		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= (tosend + delta);
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length) {
			if (eval == __SK_REDIRECT)
				sk_mem_charge(sk, tosend - sent);
			goto more_data;
		}
	}
	return ret;
}

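/* sendmsg handler installed when a msg_parser program is attached. User
 * data is copied into an sk_msg (honouring any cork state on the psock)
 * and handed to tcp_bpf_send_verdict(), which applies the BPF verdict.
 * Falls back to tcp_sendmsg() when no psock is attached.
 */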
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;
	int flags;

	/* Don't let internal do_tcp_sendpages() flags through */
	flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
	flags |= MSG_NO_SHARED_FRAGS;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy,
				   msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch case where ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_RX,
	TCP_BPF_TXRX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE] = *base;
	prot[TCP_BPF_BASE].destroy = sock_map_destroy;
	prot[TCP_BPF_BASE].close = sock_map_close;
	prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].sock_is_readable = sk_msg_is_readable;

	prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;

	prot[TCP_BPF_RX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_RX].recvmsg = tcp_bpf_recvmsg_parser;

	prot[TCP_BPF_TXRX] = prot[TCP_BPF_TX];
	prot[TCP_BPF_TXRX].recvmsg = tcp_bpf_recvmsg_parser;
}

static void tcp_bpf_check_v6_needs_rebuild(struct proto *ops)
{
	if (unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}

static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
late_initcall(tcp_bpf_v4_build_proto);

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg == tcp_recvmsg &&
	       ops->sendmsg == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

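/* Switch @sk to the tcp_bpf proto variant matching the programs attached
 * to @psock (TX when a msg_parser is present, RX/TXRX when a stream or
 * skb verdict program is present), or restore the saved proto and
 * write_space callback when @restore is set.
 */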
int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	if (psock->progs.stream_verdict || psock->progs.skb_verdict) {
		config = (config == TCP_BPF_TX) ? TCP_BPF_TXRX : TCP_BPF_RX;
	}

	if (restore) {
		if (inet_csk_has_ulp(sk)) {
			/* TLS does not have an unhash proto in SW cases,
			 * but we need to ensure we stop using the sock_map
			 * unhash routine because the associated psock is being
			 * removed. So use the original unhash handler.
			 */
			WRITE_ONCE(sk->sk_prot->unhash, psock->saved_unhash);
			tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
		} else {
			sk->sk_write_space = psock->saved_write_space;
			/* Pairs with lockless read in sk_clone_lock() */
			sock_replace_proto(sk, psock->sk_proto);
		}
		return 0;
	}

	if (sk->sk_family == AF_INET6) {
		if (tcp_bpf_assert_proto_ops(psock->sk_proto))
			return -EINVAL;

		tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
	}

	/* Pairs with lockless read in sk_clone_lock() */
	sock_replace_proto(sk, &tcp_bpf_prots[family][config]);
	return 0;
}
EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);

/* If a child got cloned from a listening socket that had tcp_bpf
 * protocol callbacks installed, we need to restore the callbacks to
 * the default ones because the child does not inherit the psock state
 * that tcp_bpf callbacks expect.
 */
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
{
	struct proto *prot = newsk->sk_prot;

	if (is_insidevar(prot, tcp_bpf_prots))
		newsk->sk_prot = sk->sk_prot_creator;
}
#endif /* CONFIG_BPF_SYSCALL */