// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>
#include <linux/util_macros.h>

#include <net/inet_common.h>
#include <net/tls.h>
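
/* Redirect-to-ingress path: move up to apply_bytes worth of scatterlist
 * elements from msg onto the ingress queue of psock's socket, charging the
 * receiving socket's memory accounting and waking any reader. If no element
 * could be charged, -ENOMEM is returned.
 */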
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
                           struct sk_msg *msg, u32 apply_bytes, int flags)
{
        bool apply = apply_bytes;
        struct scatterlist *sge;
        u32 size, copied = 0;
        struct sk_msg *tmp;
        int i, ret = 0;

        tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
        if (unlikely(!tmp))
                return -ENOMEM;

        lock_sock(sk);
        tmp->sg.start = msg->sg.start;
        i = msg->sg.start;
        do {
                sge = sk_msg_elem(msg, i);
                size = (apply && apply_bytes < sge->length) ?
                        apply_bytes : sge->length;
                if (!sk_wmem_schedule(sk, size)) {
                        if (!copied)
                                ret = -ENOMEM;
                        break;
                }

                sk_mem_charge(sk, size);
                sk_msg_xfer(tmp, msg, i, size);
                copied += size;
                if (sge->length)
                        get_page(sk_msg_page(tmp, i));
                sk_msg_iter_var_next(i);
                tmp->sg.end = i;
                if (apply) {
                        apply_bytes -= size;
                        if (!apply_bytes) {
                                if (sge->length)
                                        sk_msg_iter_var_prev(i);
                                break;
                        }
                }
        } while (i != msg->sg.end);

        if (!ret) {
                msg->sg.start = i;
                sk_psock_queue_msg(psock, tmp);
                sk_psock_data_ready(sk, psock);
        } else {
                sk_msg_free(sk, tmp);
                kfree(tmp);
        }

        release_sock(sk);
        return ret;
}
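
/* Egress path: push the pages of msg out on sk, at most apply_bytes at a
 * time, using kernel_sendpage_locked() when a TLS TX ULP is installed and
 * do_tcp_sendpages() otherwise. Fully sent elements are released from the
 * message; sent bytes are uncharged when uncharge is set. Called with the
 * socket lock held.
 */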
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
                        int flags, bool uncharge)
{
        bool apply = apply_bytes;
        struct scatterlist *sge;
        struct page *page;
        int size, ret = 0;
        u32 off;

        while (1) {
                bool has_tx_ulp;

                sge = sk_msg_elem(msg, msg->sg.start);
                size = (apply && apply_bytes < sge->length) ?
                        apply_bytes : sge->length;
                off = sge->offset;
                page = sg_page(sge);

                tcp_rate_check_app_limited(sk);
retry:
                has_tx_ulp = tls_sw_has_ctx_tx(sk);
                if (has_tx_ulp) {
                        flags |= MSG_SENDPAGE_NOPOLICY;
                        ret = kernel_sendpage_locked(sk,
                                                     page, off, size, flags);
                } else {
                        ret = do_tcp_sendpages(sk, page, off, size, flags);
                }

                if (ret <= 0)
                        return ret;
                if (apply)
                        apply_bytes -= ret;
                msg->sg.size -= ret;
                sge->offset += ret;
                sge->length -= ret;
                if (uncharge)
                        sk_mem_uncharge(sk, ret);
                if (ret != size) {
                        size -= ret;
                        off += ret;
                        goto retry;
                }
                if (!sge->length) {
                        put_page(page);
                        sk_msg_iter_next(msg, start);
                        sg_init_table(sge, 1);
                        if (msg->sg.start == msg->sg.end)
                                break;
                }
                if (apply && !apply_bytes)
                        break;
        }

        return 0;
}

static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
                               u32 apply_bytes, int flags, bool uncharge)
{
        int ret;

        lock_sock(sk);
        ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
        release_sock(sk);
        return ret;
}

int tcp_bpf_sendmsg_redir(struct sock *sk, bool ingress,
                          struct sk_msg *msg, u32 bytes, int flags)
{
        struct sk_psock *psock = sk_psock_get(sk);
        int ret;

        if (unlikely(!psock))
                return -EPIPE;

        ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
                        tcp_bpf_push_locked(sk, msg, bytes, flags, false);
        sk_psock_put(sk, psock);
        return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

#ifdef CONFIG_BPF_SYSCALL
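/* Wait up to timeo for data to arrive on either the psock ingress queue or
 * the socket's receive queue. Returns nonzero when data is available or the
 * receive side has been shut down, and zero otherwise.
 */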
static int tcp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
                             long timeo)
{
        DEFINE_WAIT_FUNC(wait, woken_wake_function);
        int ret = 0;

        if (sk->sk_shutdown & RCV_SHUTDOWN)
                return 1;

        if (!timeo)
                return ret;

        add_wait_queue(sk_sleep(sk), &wait);
        sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
        ret = sk_wait_event(sk, &timeo,
                            !list_empty(&psock->ingress_msg) ||
                            !skb_queue_empty(&sk->sk_receive_queue), &wait);
        sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
        remove_wait_queue(sk_sleep(sk), &wait);
        return ret;
}

static int tcp_bpf_recvmsg_parser(struct sock *sk,
                                  struct msghdr *msg,
                                  size_t len,
                                  int flags,
                                  int *addr_len)
{
        struct sk_psock *psock;
        int copied;

        if (unlikely(flags & MSG_ERRQUEUE))
                return inet_recv_error(sk, msg, len, addr_len);

        psock = sk_psock_get(sk);
        if (unlikely(!psock))
                return tcp_recvmsg(sk, msg, len, flags, addr_len);

        lock_sock(sk);
msg_bytes_ready:
        copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
        if (!copied) {
                long timeo;
                int data;

                if (sock_flag(sk, SOCK_DONE))
                        goto out;

                if (sk->sk_err) {
                        copied = sock_error(sk);
                        goto out;
                }

                if (sk->sk_shutdown & RCV_SHUTDOWN)
                        goto out;

                if (sk->sk_state == TCP_CLOSE) {
                        copied = -ENOTCONN;
                        goto out;
                }

                timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
                if (!timeo) {
                        copied = -EAGAIN;
                        goto out;
                }

                if (signal_pending(current)) {
                        copied = sock_intr_errno(timeo);
                        goto out;
                }

                data = tcp_msg_wait_data(sk, psock, timeo);
                if (data && !sk_psock_queue_empty(psock))
                        goto msg_bytes_ready;
                copied = -EAGAIN;
        }
out:
        release_sock(sk);
        sk_psock_put(sk, psock);
        return copied;
}

static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
                           int flags, int *addr_len)
{
        struct sk_psock *psock;
        int copied, ret;

        if (unlikely(flags & MSG_ERRQUEUE))
                return inet_recv_error(sk, msg, len, addr_len);

        psock = sk_psock_get(sk);
        if (unlikely(!psock))
                return tcp_recvmsg(sk, msg, len, flags, addr_len);
        if (!skb_queue_empty(&sk->sk_receive_queue) &&
            sk_psock_queue_empty(psock)) {
                sk_psock_put(sk, psock);
                return tcp_recvmsg(sk, msg, len, flags, addr_len);
        }
        lock_sock(sk);
msg_bytes_ready:
        copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
        if (!copied) {
                long timeo;
                int data;

                timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
                data = tcp_msg_wait_data(sk, psock, timeo);
                if (data) {
                        if (!sk_psock_queue_empty(psock))
                                goto msg_bytes_ready;
                        release_sock(sk);
                        sk_psock_put(sk, psock);
                        return tcp_recvmsg(sk, msg, len, flags, addr_len);
                }
                copied = -EAGAIN;
        }
        ret = copied;
        release_sock(sk);
        sk_psock_put(sk, psock);
        return ret;
}
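
/* Apply the msg verdict program's decision to msg: __SK_PASS pushes the data
 * out on sk, __SK_REDIRECT hands it to the target socket (egress or ingress),
 * and __SK_DROP frees it and returns -EACCES. Corking and apply_bytes state
 * on the psock are honoured. Called with the socket lock held; the lock is
 * dropped around the redirect send.
 */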
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
                                struct sk_msg *msg, int *copied, int flags)
{
        bool cork = false, enospc = sk_msg_full(msg), redir_ingress;
        struct sock *sk_redir;
        u32 tosend, origsize, sent, delta = 0;
        u32 eval;
        int ret;

more_data:
        if (psock->eval == __SK_NONE) {
                /* Track delta in msg size to add/subtract it on SK_DROP from
                 * returned to user copied size. This ensures user doesn't
                 * get a positive return code with msg_cut_data and SK_DROP
                 * verdict.
                 */
                delta = msg->sg.size;
                psock->eval = sk_psock_msg_verdict(sk, psock, msg);
                delta -= msg->sg.size;
        }

        if (msg->cork_bytes &&
            msg->cork_bytes > msg->sg.size && !enospc) {
                psock->cork_bytes = msg->cork_bytes - msg->sg.size;
                if (!psock->cork) {
                        psock->cork = kzalloc(sizeof(*psock->cork),
                                              GFP_ATOMIC | __GFP_NOWARN);
                        if (!psock->cork)
                                return -ENOMEM;
                }
                memcpy(psock->cork, msg, sizeof(*msg));
                return 0;
        }

        tosend = msg->sg.size;
        if (psock->apply_bytes && psock->apply_bytes < tosend)
                tosend = psock->apply_bytes;
        eval = __SK_NONE;

        switch (psock->eval) {
        case __SK_PASS:
                ret = tcp_bpf_push(sk, msg, tosend, flags, true);
                if (unlikely(ret)) {
                        *copied -= sk_msg_free(sk, msg);
                        break;
                }
                sk_msg_apply_bytes(psock, tosend);
                break;
        case __SK_REDIRECT:
                redir_ingress = psock->redir_ingress;
                sk_redir = psock->sk_redir;
                sk_msg_apply_bytes(psock, tosend);
                if (!psock->apply_bytes) {
                        /* Clean up before releasing the sock lock. */
                        eval = psock->eval;
                        psock->eval = __SK_NONE;
                        psock->sk_redir = NULL;
                }
                if (psock->cork) {
                        cork = true;
                        psock->cork = NULL;
                }
                sk_msg_return(sk, msg, tosend);
                release_sock(sk);

                origsize = msg->sg.size;
                ret = tcp_bpf_sendmsg_redir(sk_redir, redir_ingress,
                                            msg, tosend, flags);
                sent = origsize - msg->sg.size;

                if (eval == __SK_REDIRECT)
                        sock_put(sk_redir);

                lock_sock(sk);
                if (unlikely(ret < 0)) {
                        int free = sk_msg_free_nocharge(sk, msg);

                        if (!cork)
                                *copied -= free;
                }
                if (cork) {
                        sk_msg_free(sk, msg);
                        kfree(msg);
                        msg = NULL;
                        ret = 0;
                }
                break;
        case __SK_DROP:
        default:
                sk_msg_free_partial(sk, msg, tosend);
                sk_msg_apply_bytes(psock, tosend);
                *copied -= (tosend + delta);
                return -EACCES;
        }

        if (likely(!ret)) {
                if (!psock->apply_bytes) {
                        psock->eval = __SK_NONE;
                        if (psock->sk_redir) {
                                sock_put(psock->sk_redir);
                                psock->sk_redir = NULL;
                        }
                }
                if (msg &&
                    msg->sg.data[msg->sg.start].page_link &&
                    msg->sg.data[msg->sg.start].length) {
                        if (eval == __SK_REDIRECT)
                                sk_mem_charge(sk, tosend - sent);
                        goto more_data;
                }
        }
        return ret;
}

static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
        struct sk_msg tmp, *msg_tx = NULL;
        int copied = 0, err = 0;
        struct sk_psock *psock;
        long timeo;
        int flags;

        /* Don't let internal do_tcp_sendpages() flags through */
        flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
        flags |= MSG_NO_SHARED_FRAGS;

        psock = sk_psock_get(sk);
        if (unlikely(!psock))
                return tcp_sendmsg(sk, msg, size);

        lock_sock(sk);
        timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
        while (msg_data_left(msg)) {
                bool enospc = false;
                u32 copy, osize;

                if (sk->sk_err) {
                        err = -sk->sk_err;
                        goto out_err;
                }

                copy = msg_data_left(msg);
                if (!sk_stream_memory_free(sk))
                        goto wait_for_sndbuf;
                if (psock->cork) {
                        msg_tx = psock->cork;
                } else {
                        msg_tx = &tmp;
                        sk_msg_init(msg_tx);
                }

                osize = msg_tx->sg.size;
                err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy,
                                   msg_tx->sg.end - 1);
                if (err) {
                        if (err != -ENOSPC)
                                goto wait_for_memory;
                        enospc = true;
                        copy = msg_tx->sg.size - osize;
                }

                err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
                                               copy);
                if (err < 0) {
                        sk_msg_trim(sk, msg_tx, osize);
                        goto out_err;
                }

                copied += copy;
                if (psock->cork_bytes) {
                        if (size > psock->cork_bytes)
                                psock->cork_bytes = 0;
                        else
                                psock->cork_bytes -= size;
                        if (psock->cork_bytes && !enospc)
                                goto out_err;
                        /* All cork bytes are accounted, rerun the prog. */
                        psock->eval = __SK_NONE;
                        psock->cork_bytes = 0;
                }

                err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
                if (unlikely(err < 0))
                        goto out_err;
                continue;
wait_for_sndbuf:
                set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
                err = sk_stream_wait_memory(sk, &timeo);
                if (err) {
                        if (msg_tx && msg_tx != psock->cork)
                                sk_msg_free(sk, msg_tx);
                        goto out_err;
                }
        }
out_err:
        if (err < 0)
                err = sk_stream_error(sk, msg->msg_flags, err);
        release_sock(sk);
        sk_psock_put(sk, psock);
        return copied ? copied : err;
}

static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
                            size_t size, int flags)
{
        struct sk_msg tmp, *msg = NULL;
        int err = 0, copied = 0;
        struct sk_psock *psock;
        bool enospc = false;

        psock = sk_psock_get(sk);
        if (unlikely(!psock))
                return tcp_sendpage(sk, page, offset, size, flags);

        lock_sock(sk);
        if (psock->cork) {
                msg = psock->cork;
        } else {
                msg = &tmp;
                sk_msg_init(msg);
        }

        /* Catch case where ring is full and sendpage is stalled. */
        if (unlikely(sk_msg_full(msg)))
                goto out_err;

        sk_msg_page_add(msg, page, size, offset);
        sk_mem_charge(sk, size);
        copied = size;
        if (sk_msg_full(msg))
                enospc = true;
        if (psock->cork_bytes) {
                if (size > psock->cork_bytes)
                        psock->cork_bytes = 0;
                else
                        psock->cork_bytes -= size;
                if (psock->cork_bytes && !enospc)
                        goto out_err;
                /* All cork bytes are accounted, rerun the prog. */
                psock->eval = __SK_NONE;
                psock->cork_bytes = 0;
        }

        err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
        release_sock(sk);
        sk_psock_put(sk, psock);
        return copied ? copied : err;
}

enum {
        TCP_BPF_IPV4,
        TCP_BPF_IPV6,
        TCP_BPF_NUM_PROTS,
};

enum {
        TCP_BPF_BASE,
        TCP_BPF_TX,
        TCP_BPF_RX,
        TCP_BPF_TXRX,
        TCP_BPF_NUM_CFGS,
};
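
/* tcp_bpf_prots[] holds a modified copy of the TCP proto ops per address
 * family and per configuration (base, TX, RX, TX/RX). The IPv4 variants are
 * built once at late_initcall time from tcp_prot; the IPv6 variants are built
 * lazily, under tcpv6_prot_lock, from the proto the socket currently uses.
 */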
static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
                                   struct proto *base)
{
        prot[TCP_BPF_BASE] = *base;
        prot[TCP_BPF_BASE].destroy = sock_map_destroy;
        prot[TCP_BPF_BASE].close = sock_map_close;
        prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
        prot[TCP_BPF_BASE].sock_is_readable = sk_msg_is_readable;

        prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
        prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
        prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;

        prot[TCP_BPF_RX] = prot[TCP_BPF_BASE];
        prot[TCP_BPF_RX].recvmsg = tcp_bpf_recvmsg_parser;

        prot[TCP_BPF_TXRX] = prot[TCP_BPF_TX];
        prot[TCP_BPF_TXRX].recvmsg = tcp_bpf_recvmsg_parser;
}

static void tcp_bpf_check_v6_needs_rebuild(struct proto *ops)
{
        if (unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
                spin_lock_bh(&tcpv6_prot_lock);
                if (likely(ops != tcpv6_prot_saved)) {
                        tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
                        smp_store_release(&tcpv6_prot_saved, ops);
                }
                spin_unlock_bh(&tcpv6_prot_lock);
        }
}

static int __init tcp_bpf_v4_build_proto(void)
{
        tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
        return 0;
}
late_initcall(tcp_bpf_v4_build_proto);

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
        /* In order to avoid retpoline, we make assumptions when we call
         * into ops if e.g. a psock is not present. Make sure they are
         * indeed valid assumptions.
         */
        return ops->recvmsg == tcp_recvmsg &&
               ops->sendmsg == tcp_sendmsg &&
               ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
        int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
        int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

        if (psock->progs.stream_verdict || psock->progs.skb_verdict) {
                config = (config == TCP_BPF_TX) ? TCP_BPF_TXRX : TCP_BPF_RX;
        }

        if (restore) {
                if (inet_csk_has_ulp(sk)) {
                        /* TLS does not have an unhash proto in SW cases,
                         * but we need to ensure we stop using the sock_map
                         * unhash routine because the associated psock is being
                         * removed. So use the original unhash handler.
                         */
                        WRITE_ONCE(sk->sk_prot->unhash, psock->saved_unhash);
                        tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
                } else {
                        sk->sk_write_space = psock->saved_write_space;
                        /* Pairs with lockless read in sk_clone_lock() */
                        sock_replace_proto(sk, psock->sk_proto);
                }
                return 0;
        }

        if (sk->sk_family == AF_INET6) {
                if (tcp_bpf_assert_proto_ops(psock->sk_proto))
                        return -EINVAL;

                tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
        }

        /* Pairs with lockless read in sk_clone_lock() */
        sock_replace_proto(sk, &tcp_bpf_prots[family][config]);
        return 0;
}
EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);

/* If a child got cloned from a listening socket that had tcp_bpf
 * protocol callbacks installed, we need to restore the callbacks to
 * the default ones because the child does not inherit the psock state
 * that tcp_bpf callbacks expect.
 */
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
{
        struct proto *prot = newsk->sk_prot;

        if (is_insidevar(prot, tcp_bpf_prots))
                newsk->sk_prot = sk->sk_prot_creator;
}
#endif /* CONFIG_BPF_SYSCALL */