// SPDX-License-Identifier: GPL-2.0
/* Multipath TCP
 *
 * Copyright (c) 2017 - 2019, Intel Corporation.
 */

#define pr_fmt(fmt) "MPTCP: " fmt

#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/netdevice.h>
#include <linux/sched/signal.h>
#include <linux/atomic.h>
#include <net/sock.h>
#include <net/inet_common.h>
#include <net/inet_hashtables.h>
#include <net/protocol.h>
#include <net/tcp.h>
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
#include <net/transp_v6.h>
#endif
#include <net/mptcp.h>
#include "protocol.h"

#define MPTCP_SAME_STATE TCP_MAX_STATES

/* If msk has an initial subflow socket, and the MP_CAPABLE handshake has not
 * completed yet or has failed, return the subflow socket.
 * Otherwise return NULL.
 */
static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
{
	if (!msk->subflow || READ_ONCE(msk->can_ack))
		return NULL;

	return msk->subflow;
}

static bool __mptcp_needs_tcp_fallback(const struct mptcp_sock *msk)
{
	return msk->first && !sk_is_mptcp(msk->first);
}

static struct socket *__mptcp_tcp_fallback(struct mptcp_sock *msk)
{
	sock_owned_by_me((const struct sock *)msk);

	if (likely(!__mptcp_needs_tcp_fallback(msk)))
		return NULL;

	if (msk->subflow) {
		release_sock((struct sock *)msk);
		return msk->subflow;
	}

	return NULL;
}

static bool __mptcp_can_create_subflow(const struct mptcp_sock *msk)
{
	return !msk->first;
}

static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state)
{
	struct mptcp_subflow_context *subflow;
	struct sock *sk = (struct sock *)msk;
	struct socket *ssock;
	int err;

	ssock = __mptcp_nmpc_socket(msk);
	if (ssock)
		goto set_state;

	if (!__mptcp_can_create_subflow(msk))
		return ERR_PTR(-EINVAL);

	err = mptcp_subflow_create_socket(sk, &ssock);
	if (err)
		return ERR_PTR(err);

	msk->first = ssock->sk;
	msk->subflow = ssock;
	subflow = mptcp_subflow_ctx(ssock->sk);
	list_add(&subflow->node, &msk->conn_list);
	subflow->request_mptcp = 1;

set_state:
	if (state != MPTCP_SAME_STATE)
		inet_sk_state_store(sk, state);
	return ssock;
}

static struct sock *mptcp_subflow_get(const struct mptcp_sock *msk)
{
	struct mptcp_subflow_context *subflow;

	sock_owned_by_me((const struct sock *)msk);

	mptcp_for_each_subflow(msk, subflow) {
		return mptcp_subflow_tcp_sock(subflow);
	}

	return NULL;
}

static bool mptcp_ext_cache_refill(struct mptcp_sock *msk)
{
	if (!msk->cached_ext)
		msk->cached_ext = __skb_ext_alloc();

	return !!msk->cached_ext;
}

static struct sock *mptcp_subflow_recv_lookup(const struct mptcp_sock *msk)
{
	struct mptcp_subflow_context *subflow;
	struct sock *sk = (struct sock *)msk;

	sock_owned_by_me(sk);

	mptcp_for_each_subflow(msk, subflow) {
		if (subflow->data_avail)
			return mptcp_subflow_tcp_sock(subflow);
	}

	return NULL;
}

static inline bool mptcp_skb_can_collapse_to(const struct mptcp_sock *msk,
					     const struct sk_buff *skb,
					     const struct mptcp_ext *mpext)
{
	if (!tcp_skb_can_collapse_to(skb))
		return false;

	/* can collapse only if MPTCP level sequence is in order */
	return mpext && mpext->data_seq + mpext->data_len == msk->write_seq;
}

static int mptcp_sendmsg_frag(struct sock *sk, struct sock *ssk,
			      struct msghdr *msg, long *timeo, int *pmss_now,
			      int *ps_goal)
{
	int mss_now, avail_size, size_goal, ret;
	struct mptcp_sock *msk = mptcp_sk(sk);
	struct mptcp_ext *mpext = NULL;
	struct sk_buff *skb, *tail;
	bool can_collapse = false;
	struct page_frag *pfrag;
	size_t psize;

	/* use the mptcp page cache so that we can easily move the data
	 * from one substream to another, but do per subflow memory accounting
	 */
	pfrag = sk_page_frag(sk);
	while (!sk_page_frag_refill(ssk, pfrag) ||
	       !mptcp_ext_cache_refill(msk)) {
		ret = sk_stream_wait_memory(ssk, timeo);
		if (ret)
			return ret;
		if (unlikely(__mptcp_needs_tcp_fallback(msk)))
			return 0;
	}

	/* compute copy limit */
	mss_now = tcp_send_mss(ssk, &size_goal, msg->msg_flags);
	*pmss_now = mss_now;
	*ps_goal = size_goal;
	avail_size = size_goal;
	skb = tcp_write_queue_tail(ssk);
	if (skb) {
		mpext = skb_ext_find(skb, SKB_EXT_MPTCP);

		/* Limit the write to the size available in the
		 * current skb, if any, so that we create at most a new skb.
		 * Explicitly tells TCP internals to avoid collapsing on later
		 * queue management operation, to avoid breaking the ext <->
		 * SSN association set here
		 */
		can_collapse = (size_goal - skb->len > 0) &&
			       mptcp_skb_can_collapse_to(msk, skb, mpext);
		if (!can_collapse)
			TCP_SKB_CB(skb)->eor = 1;
		else
			avail_size = size_goal - skb->len;
	}
	psize = min_t(size_t, pfrag->size - pfrag->offset, avail_size);

	/* Copy to page */
	pr_debug("left=%zu", msg_data_left(msg));
	psize = copy_page_from_iter(pfrag->page, pfrag->offset,
				    min_t(size_t, msg_data_left(msg), psize),
				    &msg->msg_iter);
	pr_debug("left=%zu", msg_data_left(msg));
	if (!psize)
		return -EINVAL;

	/* tell the TCP stack to delay the push so that we can safely
	 * access the skb after the sendpages call
	 */
	ret = do_tcp_sendpages(ssk, pfrag->page, pfrag->offset, psize,
			       msg->msg_flags | MSG_SENDPAGE_NOTLAST);
	if (ret <= 0)
		return ret;
	if (unlikely(ret < psize))
		iov_iter_revert(&msg->msg_iter, psize - ret);

	/* if the tail skb extension is still the cached one, collapsing
	 * really happened. Note: we can't check for 'same skb' as the sk_buff
	 * hdr on tail can be transmitted, freed and re-allocated by the
	 * do_tcp_sendpages() call
	 */
	tail = tcp_write_queue_tail(ssk);
	if (mpext && tail && mpext == skb_ext_find(tail, SKB_EXT_MPTCP)) {
		WARN_ON_ONCE(!can_collapse);
		mpext->data_len += ret;
		goto out;
	}

	skb = tcp_write_queue_tail(ssk);
	mpext = __skb_ext_set(skb, SKB_EXT_MPTCP, msk->cached_ext);
	msk->cached_ext = NULL;

	memset(mpext, 0, sizeof(*mpext));
	mpext->data_seq = msk->write_seq;
	mpext->subflow_seq = mptcp_subflow_ctx(ssk)->rel_write_seq;
	mpext->data_len = ret;
	mpext->use_map = 1;
	mpext->dsn64 = 1;

	pr_debug("data_seq=%llu subflow_seq=%u data_len=%u dsn64=%d",
		 mpext->data_seq, mpext->subflow_seq, mpext->data_len,
		 mpext->dsn64);

out:
	pfrag->offset += ret;
	msk->write_seq += ret;
	mptcp_subflow_ctx(ssk)->rel_write_seq += ret;

	return ret;
}

static void ssk_check_wmem(struct mptcp_sock *msk, struct sock *ssk)
{
	struct socket *sock;

	if (likely(sk_stream_is_writeable(ssk)))
		return;

	sock = READ_ONCE(ssk->sk_socket);

	if (sock) {
		clear_bit(MPTCP_SEND_SPACE, &msk->flags);
		smp_mb__after_atomic();
		/* set NOSPACE only after clearing SEND_SPACE flag */
		set_bit(SOCK_NOSPACE, &sock->flags);
	}
}

static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
{
	int mss_now = 0, size_goal = 0, ret = 0;
	struct mptcp_sock *msk = mptcp_sk(sk);
	struct socket *ssock;
	size_t copied = 0;
	struct sock *ssk;
	long timeo;

	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
		return -EOPNOTSUPP;

	lock_sock(sk);
	ssock = __mptcp_tcp_fallback(msk);
	if (unlikely(ssock)) {
fallback:
		pr_debug("fallback passthrough");
		ret = sock_sendmsg(ssock, msg);
		return ret >= 0 ? ret + copied : (copied ? copied : ret);
	}

	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);

	ssk = mptcp_subflow_get(msk);
	if (!ssk) {
		release_sock(sk);
		return -ENOTCONN;
	}

	pr_debug("conn_list->subflow=%p", ssk);

	lock_sock(ssk);
	while (msg_data_left(msg)) {
		ret = mptcp_sendmsg_frag(sk, ssk, msg, &timeo, &mss_now,
					 &size_goal);
		if (ret < 0)
			break;
		if (ret == 0 && unlikely(__mptcp_needs_tcp_fallback(msk))) {
			release_sock(ssk);
			ssock = __mptcp_tcp_fallback(msk);
			goto fallback;
		}

		copied += ret;
	}

	if (copied) {
		ret = copied;
		tcp_push(ssk, msg->msg_flags, mss_now, tcp_sk(ssk)->nonagle,
			 size_goal);
	}

	ssk_check_wmem(msk, ssk);
	release_sock(ssk);
	release_sock(sk);
	return ret;
}

int mptcp_read_actor(read_descriptor_t *desc, struct sk_buff *skb,
		     unsigned int offset, size_t len)
{
	struct mptcp_read_arg *arg = desc->arg.data;
	size_t copy_len;

	copy_len = min(desc->count, len);

	if (likely(arg->msg)) {
		int err;

		err = skb_copy_datagram_msg(skb, offset, arg->msg, copy_len);
		if (err) {
			pr_debug("error path");
			desc->error = err;
			return err;
		}
	} else {
		pr_debug("Flushing skb payload");
	}

	desc->count -= copy_len;

	pr_debug("consumed %zu bytes, %zu left", copy_len, desc->count);
	return copy_len;
}

static void mptcp_wait_data(struct sock *sk, long *timeo)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	struct mptcp_sock *msk = mptcp_sk(sk);

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);

	sk_wait_event(sk, timeo,
		      test_and_clear_bit(MPTCP_DATA_READY, &msk->flags), &wait);

	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
}

static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			 int nonblock, int flags, int *addr_len)
{
	struct mptcp_sock *msk = mptcp_sk(sk);
	struct mptcp_subflow_context *subflow;
	bool more_data_avail = false;
	struct mptcp_read_arg arg;
	read_descriptor_t desc;
	bool wait_data = false;
	struct socket *ssock;
	struct tcp_sock *tp;
	bool done = false;
	struct sock *ssk;
	int copied = 0;
	int target;
	long timeo;

	if (msg->msg_flags & ~(MSG_WAITALL | MSG_DONTWAIT))
		return -EOPNOTSUPP;

	lock_sock(sk);
	ssock = __mptcp_tcp_fallback(msk);
	if (unlikely(ssock)) {
fallback:
		pr_debug("fallback-read subflow=%p",
			 mptcp_subflow_ctx(ssock->sk));
		copied = sock_recvmsg(ssock, msg, flags);
		return copied;
	}

	arg.msg = msg;
	desc.arg.data = &arg;
	desc.error = 0;

	timeo = sock_rcvtimeo(sk, nonblock);

	len = min_t(size_t, len, INT_MAX);
	target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);

	while (!done) {
		u32 map_remaining;
		int bytes_read;

		ssk = mptcp_subflow_recv_lookup(msk);
		pr_debug("msk=%p ssk=%p", msk, ssk);
		if (!ssk)
			goto wait_for_data;

		subflow = mptcp_subflow_ctx(ssk);
		tp = tcp_sk(ssk);

		lock_sock(ssk);
		do {
			/* try to read as much data as available */
			map_remaining = subflow->map_data_len -
					mptcp_subflow_get_map_offset(subflow);
			desc.count = min_t(size_t, len - copied, map_remaining);
			pr_debug("reading %zu bytes, copied %d", desc.count,
				 copied);
			bytes_read = tcp_read_sock(ssk, &desc,
						   mptcp_read_actor);
			if (bytes_read < 0) {
				if (!copied)
					copied = bytes_read;
				done = true;
				goto next;
			}

			pr_debug("msk ack_seq=%llx -> %llx", msk->ack_seq,
				 msk->ack_seq + bytes_read);
			msk->ack_seq += bytes_read;
			copied += bytes_read;
			if (copied >= len) {
				done = true;
				goto next;
			}
			if (tp->urg_data && tp->urg_seq == tp->copied_seq) {
				pr_err("Urgent data present, cannot proceed");
				done = true;
				goto next;
			}
next:
			more_data_avail = mptcp_subflow_data_available(ssk);
		} while (more_data_avail && !done);
		release_sock(ssk);
		continue;

wait_for_data:
		more_data_avail = false;

		/* only the master socket status is relevant here. The exit
		 * conditions mirror closely tcp_recvmsg()
		 */
		if (copied >= target)
			break;

		if (copied) {
			if (sk->sk_err ||
			    sk->sk_state == TCP_CLOSE ||
			    (sk->sk_shutdown & RCV_SHUTDOWN) ||
			    !timeo ||
			    signal_pending(current))
				break;
		} else {
			if (sk->sk_err) {
				copied = sock_error(sk);
				break;
			}

			if (sk->sk_shutdown & RCV_SHUTDOWN)
				break;

			if (sk->sk_state == TCP_CLOSE) {
				copied = -ENOTCONN;
				break;
			}

			if (!timeo) {
				copied = -EAGAIN;
				break;
			}

			if (signal_pending(current)) {
				copied = sock_intr_errno(timeo);
				break;
			}
		}

		pr_debug("block timeout %ld", timeo);
		wait_data = true;
		mptcp_wait_data(sk, &timeo);
		if (unlikely(__mptcp_tcp_fallback(msk)))
			goto fallback;
	}

	if (more_data_avail) {
		if (!test_bit(MPTCP_DATA_READY, &msk->flags))
			set_bit(MPTCP_DATA_READY, &msk->flags);
	} else if (!wait_data) {
		clear_bit(MPTCP_DATA_READY, &msk->flags);

		/* .. race-breaker: ssk might get new data after last
		 * data_available() returns false.
		 */
		ssk = mptcp_subflow_recv_lookup(msk);
		if (unlikely(ssk))
			set_bit(MPTCP_DATA_READY, &msk->flags);
	}

	release_sock(sk);
	return copied;
}

/* subflow sockets can be either outgoing (connect) or incoming
 * (accept).
 *
 * Outgoing subflows use in-kernel sockets.
 * Incoming subflows do not have their own 'struct socket' allocated,
 * so we need to use tcp_close() after detaching them from the mptcp
 * parent socket.
 */
static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
			      struct mptcp_subflow_context *subflow,
			      long timeout)
{
	struct socket *sock = READ_ONCE(ssk->sk_socket);

	list_del(&subflow->node);

	if (sock && sock != sk->sk_socket) {
		/* outgoing subflow */
		sock_release(sock);
	} else {
		/* incoming subflow */
		tcp_close(ssk, timeout);
	}
}

static int __mptcp_init_sock(struct sock *sk)
{
	struct mptcp_sock *msk = mptcp_sk(sk);

	INIT_LIST_HEAD(&msk->conn_list);
	__set_bit(MPTCP_SEND_SPACE, &msk->flags);

	msk->first = NULL;

	return 0;
}

static int mptcp_init_sock(struct sock *sk)
{
	if (!mptcp_is_enabled(sock_net(sk)))
		return -ENOPROTOOPT;

	return __mptcp_init_sock(sk);
}

static void mptcp_subflow_shutdown(struct sock *ssk, int how)
{
	lock_sock(ssk);

	switch (ssk->sk_state) {
	case TCP_LISTEN:
		if (!(how & RCV_SHUTDOWN))
			break;
		/* fall through */
	case TCP_SYN_SENT:
		tcp_disconnect(ssk, O_NONBLOCK);
		break;
	default:
		ssk->sk_shutdown |= how;
		tcp_shutdown(ssk, how);
		break;
	}

	/* Wake up anyone sleeping in poll. */
	ssk->sk_state_change(ssk);
	release_sock(ssk);
}

/* Called with msk lock held, releases such lock before returning */
static void mptcp_close(struct sock *sk, long timeout)
{
	struct mptcp_subflow_context *subflow, *tmp;
	struct mptcp_sock *msk = mptcp_sk(sk);
	LIST_HEAD(conn_list);

	lock_sock(sk);

	mptcp_token_destroy(msk->token);
	inet_sk_state_store(sk, TCP_CLOSE);

	list_splice_init(&msk->conn_list, &conn_list);

	release_sock(sk);

	list_for_each_entry_safe(subflow, tmp, &conn_list, node) {
		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);

		__mptcp_close_ssk(sk, ssk, subflow, timeout);
	}

	sk_common_release(sk);
}

static void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk)
{
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	const struct ipv6_pinfo *ssk6 = inet6_sk(ssk);
	struct ipv6_pinfo *msk6 = inet6_sk(msk);

	msk->sk_v6_daddr = ssk->sk_v6_daddr;
	msk->sk_v6_rcv_saddr = ssk->sk_v6_rcv_saddr;

	if (msk6 && ssk6) {
		msk6->saddr = ssk6->saddr;
		msk6->flow_label = ssk6->flow_label;
	}
#endif

	inet_sk(msk)->inet_num = inet_sk(ssk)->inet_num;
	inet_sk(msk)->inet_dport = inet_sk(ssk)->inet_dport;
	inet_sk(msk)->inet_sport = inet_sk(ssk)->inet_sport;
	inet_sk(msk)->inet_daddr = inet_sk(ssk)->inet_daddr;
	inet_sk(msk)->inet_saddr = inet_sk(ssk)->inet_saddr;
	inet_sk(msk)->inet_rcv_saddr = inet_sk(ssk)->inet_rcv_saddr;
}

static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
				 bool kern)
{
	struct mptcp_sock *msk = mptcp_sk(sk);
	struct socket *listener;
	struct sock *newsk;

	listener = __mptcp_nmpc_socket(msk);
	if (WARN_ON_ONCE(!listener)) {
		*err = -EINVAL;
		return NULL;
	}

	pr_debug("msk=%p, listener=%p", msk, mptcp_subflow_ctx(listener->sk));
	newsk = inet_csk_accept(listener->sk, flags, err, kern);
	if (!newsk)
		return NULL;

	pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk));

	if (sk_is_mptcp(newsk)) {
		struct mptcp_subflow_context *subflow;
		struct sock *new_mptcp_sock;
		struct sock *ssk = newsk;
		u64 ack_seq;

		subflow = mptcp_subflow_ctx(newsk);
		lock_sock(sk);

		local_bh_disable();
		new_mptcp_sock = sk_clone_lock(sk, GFP_ATOMIC);
		if (!new_mptcp_sock) {
			*err = -ENOBUFS;
			local_bh_enable();
			release_sock(sk);
			mptcp_subflow_shutdown(newsk, SHUT_RDWR + 1);
			tcp_close(newsk, 0);
			return NULL;
		}

		__mptcp_init_sock(new_mptcp_sock);

		msk = mptcp_sk(new_mptcp_sock);
		msk->local_key = subflow->local_key;
		msk->token = subflow->token;
		msk->subflow = NULL;
		msk->first = newsk;

		mptcp_token_update_accept(newsk, new_mptcp_sock);

		msk->write_seq = subflow->idsn + 1;
		if (subflow->can_ack) {
			msk->can_ack = true;
			msk->remote_key = subflow->remote_key;
			mptcp_crypto_key_sha(msk->remote_key, NULL, &ack_seq);
			ack_seq++;
			msk->ack_seq = ack_seq;
		}
		newsk = new_mptcp_sock;
		mptcp_copy_inaddrs(newsk, ssk);
		list_add(&subflow->node, &msk->conn_list);

		/* will be fully established at mptcp_stream_accept()
		 * completion.
		 */
		inet_sk_state_store(new_mptcp_sock, TCP_SYN_RECV);
		bh_unlock_sock(new_mptcp_sock);
		local_bh_enable();
		release_sock(sk);

		/* the subflow can already receive packet, avoid racing with
		 * the receive path and process the pending ones
		 */
		lock_sock(ssk);
		subflow->rel_write_seq = 1;
		subflow->tcp_sock = ssk;
		subflow->conn = new_mptcp_sock;
		if (unlikely(!skb_queue_empty(&ssk->sk_receive_queue)))
			mptcp_subflow_data_available(ssk);
		release_sock(ssk);
	}

	return newsk;
}

static void mptcp_destroy(struct sock *sk)
{
	struct mptcp_sock *msk = mptcp_sk(sk);

	if (msk->cached_ext)
		__skb_ext_put(msk->cached_ext);
}

static int mptcp_setsockopt(struct sock *sk, int level, int optname,
			    char __user *optval, unsigned int optlen)
{
	struct mptcp_sock *msk = mptcp_sk(sk);
	int ret = -EOPNOTSUPP;
	struct socket *ssock;
	struct sock *ssk;

	pr_debug("msk=%p", msk);

	/* @@ the meaning of setsockopt() when the socket is connected and
	 * there are multiple subflows is not defined.
	 */
	lock_sock(sk);
	ssock = __mptcp_socket_create(msk, MPTCP_SAME_STATE);
	if (IS_ERR(ssock)) {
		release_sock(sk);
		return ret;
	}

	ssk = ssock->sk;
	sock_hold(ssk);
	release_sock(sk);

	ret = tcp_setsockopt(ssk, level, optname, optval, optlen);
	sock_put(ssk);

	return ret;
}

static int mptcp_getsockopt(struct sock *sk, int level, int optname,
			    char __user *optval, int __user *option)
{
	struct mptcp_sock *msk = mptcp_sk(sk);
	int ret = -EOPNOTSUPP;
	struct socket *ssock;
	struct sock *ssk;

	pr_debug("msk=%p", msk);

	/* @@ the meaning of getsockopt() when the socket is connected and
	 * there are multiple subflows is not defined.
	 */
	lock_sock(sk);
	ssock = __mptcp_socket_create(msk, MPTCP_SAME_STATE);
	if (IS_ERR(ssock)) {
		release_sock(sk);
		return ret;
	}

	ssk = ssock->sk;
	sock_hold(ssk);
	release_sock(sk);

	ret = tcp_getsockopt(ssk, level, optname, optval, option);
	sock_put(ssk);

	return ret;
}

static int mptcp_get_port(struct sock *sk, unsigned short snum)
{
	struct mptcp_sock *msk = mptcp_sk(sk);
	struct socket *ssock;

	ssock = __mptcp_nmpc_socket(msk);
	pr_debug("msk=%p, subflow=%p", msk, ssock);
	if (WARN_ON_ONCE(!ssock))
		return -EINVAL;

	return inet_csk_get_port(ssock->sk, snum);
}

void mptcp_finish_connect(struct sock *ssk)
{
	struct mptcp_subflow_context *subflow;
	struct mptcp_sock *msk;
	struct sock *sk;
	u64 ack_seq;

	subflow = mptcp_subflow_ctx(ssk);

	if (!subflow->mp_capable)
		return;

	sk = subflow->conn;
	msk = mptcp_sk(sk);

	pr_debug("msk=%p, token=%u", sk, subflow->token);

	mptcp_crypto_key_sha(subflow->remote_key, NULL, &ack_seq);
	ack_seq++;
	subflow->map_seq = ack_seq;
	subflow->map_subflow_seq = 1;
	subflow->rel_write_seq = 1;

	/* the socket is not connected yet, no msk/subflow ops can access/race
	 * accessing the field below
	 */
	WRITE_ONCE(msk->remote_key, subflow->remote_key);
	WRITE_ONCE(msk->local_key, subflow->local_key);
	WRITE_ONCE(msk->token, subflow->token);
	WRITE_ONCE(msk->write_seq, subflow->idsn + 1);
	WRITE_ONCE(msk->ack_seq, ack_seq);
	WRITE_ONCE(msk->can_ack, 1);
}

static void mptcp_sock_graft(struct sock *sk, struct socket *parent)
{
	write_lock_bh(&sk->sk_callback_lock);
	rcu_assign_pointer(sk->sk_wq, &parent->wq);
	sk_set_socket(sk, parent);
	sk->sk_uid = SOCK_INODE(parent)->i_uid;
	write_unlock_bh(&sk->sk_callback_lock);
}

static bool mptcp_memory_free(const struct sock *sk, int wake)
{
	struct mptcp_sock *msk = mptcp_sk(sk);

	return wake ? test_bit(MPTCP_SEND_SPACE, &msk->flags) : true;
}

static struct proto mptcp_prot = {
	.name		= "MPTCP",
	.owner		= THIS_MODULE,
	.init		= mptcp_init_sock,
	.close		= mptcp_close,
	.accept		= mptcp_accept,
	.setsockopt	= mptcp_setsockopt,
	.getsockopt	= mptcp_getsockopt,
	.shutdown	= tcp_shutdown,
	.destroy	= mptcp_destroy,
	.sendmsg	= mptcp_sendmsg,
	.recvmsg	= mptcp_recvmsg,
	.hash		= inet_hash,
	.unhash		= inet_unhash,
	.get_port	= mptcp_get_port,
	.stream_memory_free	= mptcp_memory_free,
	.obj_size	= sizeof(struct mptcp_sock),
	.no_autobind	= true,
};

static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
{
	struct mptcp_sock *msk = mptcp_sk(sock->sk);
	struct socket *ssock;
	int err;

	lock_sock(sock->sk);
	ssock = __mptcp_socket_create(msk, MPTCP_SAME_STATE);
	if (IS_ERR(ssock)) {
		err = PTR_ERR(ssock);
		goto unlock;
	}

	err = ssock->ops->bind(ssock, uaddr, addr_len);
	if (!err)
		mptcp_copy_inaddrs(sock->sk, ssock->sk);

unlock:
	release_sock(sock->sk);
	return err;
}

static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
				int addr_len, int flags)
{
	struct mptcp_sock *msk = mptcp_sk(sock->sk);
	struct socket *ssock;
	int err;

	lock_sock(sock->sk);
	ssock = __mptcp_socket_create(msk, TCP_SYN_SENT);
	if (IS_ERR(ssock)) {
		err = PTR_ERR(ssock);
		goto unlock;
	}

#ifdef CONFIG_TCP_MD5SIG
	/* no MPTCP if MD5SIG is enabled on this socket or we may run out of
	 * TCP option space.
	 */
	if (rcu_access_pointer(tcp_sk(ssock->sk)->md5sig_info))
		mptcp_subflow_ctx(ssock->sk)->request_mptcp = 0;
#endif

	err = ssock->ops->connect(ssock, uaddr, addr_len, flags);
	inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
	mptcp_copy_inaddrs(sock->sk, ssock->sk);

unlock:
	release_sock(sock->sk);
	return err;
}

static int mptcp_v4_getname(struct socket *sock, struct sockaddr *uaddr,
			    int peer)
{
	if (sock->sk->sk_prot == &tcp_prot) {
		/* we are being invoked from __sys_accept4, after
		 * mptcp_accept() has just accepted a non-mp-capable
		 * flow: sk is a tcp_sk, not an mptcp one.
		 *
		 * Hand the socket over to tcp so all further socket ops
		 * bypass mptcp.
		 */
		sock->ops = &inet_stream_ops;
	}

	return inet_getname(sock, uaddr, peer);
}

#if IS_ENABLED(CONFIG_MPTCP_IPV6)
static int mptcp_v6_getname(struct socket *sock, struct sockaddr *uaddr,
			    int peer)
{
	if (sock->sk->sk_prot == &tcpv6_prot) {
		/* we are being invoked from __sys_accept4 after
		 * mptcp_accept() has accepted a non-mp-capable
		 * subflow: sk is a tcp_sk, not mptcp.
		 *
		 * Hand the socket over to tcp so all further
		 * socket ops bypass mptcp.
		 */
		sock->ops = &inet6_stream_ops;
	}

	return inet6_getname(sock, uaddr, peer);
}
#endif

static int mptcp_listen(struct socket *sock, int backlog)
{
	struct mptcp_sock *msk = mptcp_sk(sock->sk);
	struct socket *ssock;
	int err;

	pr_debug("msk=%p", msk);

	lock_sock(sock->sk);
	ssock = __mptcp_socket_create(msk, TCP_LISTEN);
	if (IS_ERR(ssock)) {
		err = PTR_ERR(ssock);
		goto unlock;
	}

	err = ssock->ops->listen(ssock, backlog);
	inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
	if (!err)
		mptcp_copy_inaddrs(sock->sk, ssock->sk);

unlock:
	release_sock(sock->sk);
	return err;
}

static bool is_tcp_proto(const struct proto *p)
{
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	return p == &tcp_prot || p == &tcpv6_prot;
#else
	return p == &tcp_prot;
#endif
}

static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
			       int flags, bool kern)
{
	struct mptcp_sock *msk = mptcp_sk(sock->sk);
	struct socket *ssock;
	int err;

	pr_debug("msk=%p", msk);

	lock_sock(sock->sk);
	if (sock->sk->sk_state != TCP_LISTEN)
		goto unlock_fail;

	ssock = __mptcp_nmpc_socket(msk);
	if (!ssock)
		goto unlock_fail;

	sock_hold(ssock->sk);
	release_sock(sock->sk);

	err = ssock->ops->accept(sock, newsock, flags, kern);
	if (err == 0 && !is_tcp_proto(newsock->sk->sk_prot)) {
		struct mptcp_sock *msk = mptcp_sk(newsock->sk);
		struct mptcp_subflow_context *subflow;

		/* set ssk->sk_socket of accept()ed flows to mptcp socket.
		 * This is needed so NOSPACE flag can be set from tcp stack.
		 */
		list_for_each_entry(subflow, &msk->conn_list, node) {
			struct sock *ssk = mptcp_subflow_tcp_sock(subflow);

			if (!ssk->sk_socket)
				mptcp_sock_graft(ssk, newsock);
		}

		inet_sk_state_store(newsock->sk, TCP_ESTABLISHED);
	}

	sock_put(ssock->sk);
	return err;

unlock_fail:
	release_sock(sock->sk);
	return -EINVAL;
}

static __poll_t mptcp_poll(struct file *file, struct socket *sock,
			   struct poll_table_struct *wait)
{
	struct sock *sk = sock->sk;
	struct mptcp_sock *msk;
	struct socket *ssock;
	__poll_t mask = 0;

	msk = mptcp_sk(sk);
	lock_sock(sk);
	ssock = __mptcp_nmpc_socket(msk);
	if (ssock) {
		mask = ssock->ops->poll(file, ssock, wait);
		release_sock(sk);
		return mask;
	}

	release_sock(sk);
	sock_poll_wait(file, sock, wait);
	lock_sock(sk);
	ssock = __mptcp_tcp_fallback(msk);
	if (unlikely(ssock))
		return ssock->ops->poll(file, ssock, NULL);

	if (test_bit(MPTCP_DATA_READY, &msk->flags))
		mask = EPOLLIN | EPOLLRDNORM;
	if (sk_stream_is_writeable(sk) &&
	    test_bit(MPTCP_SEND_SPACE, &msk->flags))
		mask |= EPOLLOUT | EPOLLWRNORM;
	if (sk->sk_shutdown & RCV_SHUTDOWN)
		mask |= EPOLLIN | EPOLLRDNORM | EPOLLRDHUP;

	release_sock(sk);

	return mask;
}

static int mptcp_shutdown(struct socket *sock, int how)
{
	struct mptcp_sock *msk = mptcp_sk(sock->sk);
	struct mptcp_subflow_context *subflow;
	int ret = 0;

	pr_debug("sk=%p, how=%d", msk, how);

	lock_sock(sock->sk);

	if (how == SHUT_WR || how == SHUT_RDWR)
		inet_sk_state_store(sock->sk, TCP_FIN_WAIT1);

	how++;

	if ((how & ~SHUTDOWN_MASK) || !how) {
		ret = -EINVAL;
		goto out_unlock;
	}

	if (sock->state == SS_CONNECTING) {
		if ((1 << sock->sk->sk_state) &
		    (TCPF_SYN_SENT | TCPF_SYN_RECV | TCPF_CLOSE))
			sock->state = SS_DISCONNECTING;
		else
			sock->state = SS_CONNECTED;
	}

	mptcp_for_each_subflow(msk, subflow) {
		struct sock *tcp_sk = mptcp_subflow_tcp_sock(subflow);

		mptcp_subflow_shutdown(tcp_sk, how);
	}

out_unlock:
	release_sock(sock->sk);

	return ret;
}

static const struct proto_ops mptcp_stream_ops = {
	.family		   = PF_INET,
	.owner		   = THIS_MODULE,
	.release	   = inet_release,
	.bind		   = mptcp_bind,
	.connect	   = mptcp_stream_connect,
	.socketpair	   = sock_no_socketpair,
	.accept		   = mptcp_stream_accept,
	.getname	   = mptcp_v4_getname,
	.poll		   = mptcp_poll,
	.ioctl		   = inet_ioctl,
	.gettstamp	   = sock_gettstamp,
	.listen		   = mptcp_listen,
	.shutdown	   = mptcp_shutdown,
	.setsockopt	   = sock_common_setsockopt,
	.getsockopt	   = sock_common_getsockopt,
	.sendmsg	   = inet_sendmsg,
	.recvmsg	   = inet_recvmsg,
	.mmap		   = sock_no_mmap,
	.sendpage	   = inet_sendpage,
#ifdef CONFIG_COMPAT
	.compat_setsockopt = compat_sock_common_setsockopt,
	.compat_getsockopt = compat_sock_common_getsockopt,
#endif
};

static struct inet_protosw mptcp_protosw = {
	.type		= SOCK_STREAM,
	.protocol	= IPPROTO_MPTCP,
	.prot		= &mptcp_prot,
	.ops		= &mptcp_stream_ops,
	.flags		= INET_PROTOSW_ICSK,
};

void mptcp_proto_init(void)
{
	mptcp_prot.h.hashinfo = tcp_prot.h.hashinfo;

	mptcp_subflow_init();

	if (proto_register(&mptcp_prot, 1) != 0)
		panic("Failed to register MPTCP proto.\n");

	inet_register_protosw(&mptcp_protosw);
}

#if IS_ENABLED(CONFIG_MPTCP_IPV6)
static const struct proto_ops mptcp_v6_stream_ops = {
	.family		   = PF_INET6,
	.owner		   = THIS_MODULE,
	.release	   = inet6_release,
	.bind		   = mptcp_bind,
	.connect	   = mptcp_stream_connect,
	.socketpair	   = sock_no_socketpair,
	.accept		   = mptcp_stream_accept,
	.getname	   = mptcp_v6_getname,
	.poll		   = mptcp_poll,
	.ioctl		   = inet6_ioctl,
	.gettstamp	   = sock_gettstamp,
	.listen		   = mptcp_listen,
	.shutdown	   = mptcp_shutdown,
	.setsockopt	   = sock_common_setsockopt,
	.getsockopt	   = sock_common_getsockopt,
	.sendmsg	   = inet6_sendmsg,
	.recvmsg	   = inet6_recvmsg,
	.mmap		   = sock_no_mmap,
	.sendpage	   = inet_sendpage,
#ifdef CONFIG_COMPAT
	.compat_setsockopt = compat_sock_common_setsockopt,
	.compat_getsockopt = compat_sock_common_getsockopt,
#endif
};

static struct proto mptcp_v6_prot;

static void mptcp_v6_destroy(struct sock *sk)
{
	mptcp_destroy(sk);
	inet6_destroy_sock(sk);
}

static struct inet_protosw mptcp_v6_protosw = {
	.type		= SOCK_STREAM,
	.protocol	= IPPROTO_MPTCP,
	.prot		= &mptcp_v6_prot,
	.ops		= &mptcp_v6_stream_ops,
	.flags		= INET_PROTOSW_ICSK,
};

int mptcp_proto_v6_init(void)
{
	int err;

	mptcp_v6_prot = mptcp_prot;
	strcpy(mptcp_v6_prot.name, "MPTCPv6");
	mptcp_v6_prot.slab = NULL;
	mptcp_v6_prot.destroy = mptcp_v6_destroy;
	mptcp_v6_prot.obj_size = sizeof(struct mptcp_sock) +
				 sizeof(struct ipv6_pinfo);

	err = proto_register(&mptcp_v6_prot, 1);
	if (err)
		return err;

	err = inet6_register_protosw(&mptcp_v6_protosw);
	if (err)
		proto_unregister(&mptcp_v6_prot);

	return err;
}
#endif