// SPDX-License-Identifier: GPL-2.0
/* Multipath TCP
 *
 * Copyright (c) 2017 - 2019, Intel Corporation.
 */

#define pr_fmt(fmt) "MPTCP: " fmt

#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/netdevice.h>
#include <net/sock.h>
#include <net/inet_common.h>
#include <net/inet_hashtables.h>
#include <net/protocol.h>
#include <net/tcp.h>
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
#include <net/transp_v6.h>
#endif
#include <net/mptcp.h>
#include "protocol.h"

/* Sentinel state value passed to __mptcp_socket_create() when the caller
 * does not want the msk's sk_state to be changed.
 */
#define MPTCP_SAME_STATE TCP_MAX_STATES

/* If msk has an initial subflow socket, and the MP_CAPABLE handshake has not
 * completed yet or has failed, return the subflow socket.
 * Otherwise return NULL.
 */
static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
{
	if (!msk->subflow || mptcp_subflow_ctx(msk->subflow->sk)->fourth_ack)
		return NULL;

	return msk->subflow;
}

/* if msk has a single subflow, and the mp_capable handshake is failed,
 * return it.
 * Otherwise returns NULL
 *
 * Caller must hold the msk socket lock (checked via sock_owned_by_me()).
 */
static struct socket *__mptcp_tcp_fallback(const struct mptcp_sock *msk)
{
	struct socket *ssock = __mptcp_nmpc_socket(msk);

	sock_owned_by_me((const struct sock *)msk);

	/* an sk_is_mptcp() subflow is still doing/has done the MPTCP
	 * handshake, so there is nothing to fall back to
	 */
	if (!ssock || sk_is_mptcp(ssock->sk))
		return NULL;

	return ssock;
}

/* New subflow sockets can only be created while the msk is still CLOSED */
static bool __mptcp_can_create_subflow(const struct mptcp_sock *msk)
{
	return ((struct sock *)msk)->sk_state == TCP_CLOSE;
}

/* Return the initial (not yet MP_CAPABLE-established) subflow socket,
 * creating it on first use, and optionally move the msk to @state.
 * Returns an ERR_PTR() on failure.
 */
static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state)
{
	struct mptcp_subflow_context *subflow;
	struct sock *sk = (struct sock *)msk;
	struct socket *ssock;
	int err;

	/* reuse the existing initial subflow, if any */
	ssock = __mptcp_nmpc_socket(msk);
	if (ssock)
		goto set_state;

	if (!__mptcp_can_create_subflow(msk))
		return ERR_PTR(-EINVAL);

	err = mptcp_subflow_create_socket(sk, &ssock);
	if (err)
		return ERR_PTR(err);

	msk->subflow = ssock;
	subflow = mptcp_subflow_ctx(ssock->sk);
	list_add(&subflow->node, &msk->conn_list);
	/* request the MP_CAPABLE option on this new subflow */
	subflow->request_mptcp = 1;

set_state:
	if (state != MPTCP_SAME_STATE)
		inet_sk_state_store(sk, state);
	return ssock;
}

/* Return the first subflow tcp sock on the msk conn_list, or NULL if the
 * list is empty. Caller must hold the msk socket lock.
 */
static struct sock *mptcp_subflow_get(const struct mptcp_sock *msk)
{
	struct mptcp_subflow_context *subflow;

	sock_owned_by_me((const struct sock *)msk);

	mptcp_for_each_subflow(msk, subflow) {
		/* only a single subflow is supported here so far */
		return mptcp_subflow_tcp_sock(subflow);
	}

	return NULL;
}

/* Send data either through the TCP fallback socket, when the MP_CAPABLE
 * handshake failed, or through the (single) MPTCP subflow.
 */
static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
{
	struct mptcp_sock *msk = mptcp_sk(sk);
	struct socket *ssock;
	struct sock *ssk;
	int ret;

	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
		return -EOPNOTSUPP;

	lock_sock(sk);
	ssock = __mptcp_tcp_fallback(msk);
	if (ssock) {
		pr_debug("fallback passthrough");
		ret = sock_sendmsg(ssock, msg);
		release_sock(sk);
		return ret;
	}

	ssk = mptcp_subflow_get(msk);
	if (!ssk) {
		release_sock(sk);
		return -ENOTCONN;
	}

	ret = sock_sendmsg(ssk->sk_socket, msg);

	release_sock(sk);
	return ret;
}

/* Receive path counterpart of mptcp_sendmsg(): fallback passthrough or
 * single-subflow read. Note: @len, @nonblock and @addr_len are not used
 * here; only @flags is forwarded to sock_recvmsg().
 */
static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			 int nonblock, int flags, int *addr_len)
{
	struct mptcp_sock *msk = mptcp_sk(sk);
	struct socket *ssock;
	struct sock *ssk;
	int copied = 0;

	if (msg->msg_flags & ~(MSG_WAITALL | MSG_DONTWAIT))
		return -EOPNOTSUPP;

	lock_sock(sk);
	ssock = __mptcp_tcp_fallback(msk);
	if (ssock) {
		pr_debug("fallback-read subflow=%p",
			 mptcp_subflow_ctx(ssock->sk));
		copied = sock_recvmsg(ssock, msg, flags);
		release_sock(sk);
		return copied;
	}

	ssk = mptcp_subflow_get(msk);
	if (!ssk) {
		release_sock(sk);
		return -ENOTCONN;
	}

	copied = sock_recvmsg(ssk->sk_socket, msg, flags);

	release_sock(sk);

	return copied;
}

/* subflow sockets can be either outgoing (connect) or incoming
 * (accept).
 *
 * Outgoing subflows use in-kernel sockets.
 * Incoming subflows do not have their own 'struct socket' allocated,
 * so we need to use tcp_close() after detaching them from the mptcp
 * parent socket.
 */
static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
			      struct mptcp_subflow_context *subflow,
			      long timeout)
{
	struct socket *sock = READ_ONCE(ssk->sk_socket);

	list_del(&subflow->node);

	if (sock && sock != sk->sk_socket) {
		/* outgoing subflow */
		sock_release(sock);
	} else {
		/* incoming subflow */
		tcp_close(ssk, timeout);
	}
}

/* Initialize per-msk state; currently only the subflow list */
static int mptcp_init_sock(struct sock *sk)
{
	struct mptcp_sock *msk = mptcp_sk(sk);

	INIT_LIST_HEAD(&msk->conn_list);

	return 0;
}

/* Close the msk: tear down every subflow, then release the sock itself */
static void mptcp_close(struct sock *sk, long timeout)
{
	struct mptcp_subflow_context *subflow, *tmp;
	struct mptcp_sock *msk = mptcp_sk(sk);

	inet_sk_state_store(sk, TCP_CLOSE);

	lock_sock(sk);

	list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);

		__mptcp_close_ssk(sk, ssk, subflow, timeout);
	}

	release_sock(sk);
	sk_common_release(sk);
}

/* Mirror the subflow's addresses and ports into the msk, so that
 * getsockname()/getpeername() on the msk reflect the underlying flow.
 */
static void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk)
{
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	const struct ipv6_pinfo *ssk6 = inet6_sk(ssk);
	struct ipv6_pinfo *msk6 = inet6_sk(msk);

	msk->sk_v6_daddr = ssk->sk_v6_daddr;
	msk->sk_v6_rcv_saddr = ssk->sk_v6_rcv_saddr;

	if (msk6 && ssk6) {
		msk6->saddr = ssk6->saddr;
		msk6->flow_label = ssk6->flow_label;
	}
#endif

	inet_sk(msk)->inet_num = inet_sk(ssk)->inet_num;
	inet_sk(msk)->inet_dport = inet_sk(ssk)->inet_dport;
	inet_sk(msk)->inet_sport = inet_sk(ssk)->inet_sport;
	inet_sk(msk)->inet_daddr = inet_sk(ssk)->inet_daddr;
	inet_sk(msk)->inet_saddr = inet_sk(ssk)->inet_saddr;
	inet_sk(msk)->inet_rcv_saddr = inet_sk(ssk)->inet_rcv_saddr;
}

/* Accept a connection on the listener subflow. If the peer negotiated
 * MPTCP, clone the listening msk and return the clone; otherwise return
 * the plain tcp sock as-is (callers then fall back to TCP, see
 * mptcp_v4_getname()).
 */
static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
				 bool kern)
{
	struct mptcp_sock *msk = mptcp_sk(sk);
	struct socket *listener;
	struct sock *newsk;

	listener = __mptcp_nmpc_socket(msk);
	if (WARN_ON_ONCE(!listener)) {
		*err = -EINVAL;
		return NULL;
	}

	pr_debug("msk=%p, listener=%p", msk, mptcp_subflow_ctx(listener->sk));
	newsk = inet_csk_accept(listener->sk, flags, err, kern);
	if (!newsk)
		return NULL;

	pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk));

	if (sk_is_mptcp(newsk)) {
		struct mptcp_subflow_context *subflow;
		struct sock *new_mptcp_sock;
		struct sock *ssk = newsk;

		subflow = mptcp_subflow_ctx(newsk);
		lock_sock(sk);

		/* sk_clone_lock() returns the new sock locked (BH),
		 * hence the bh_unlock_sock() below
		 */
		local_bh_disable();
		new_mptcp_sock = sk_clone_lock(sk, GFP_ATOMIC);
		if (!new_mptcp_sock) {
			*err = -ENOBUFS;
			local_bh_enable();
			release_sock(sk);
			/* incoming subflows have no struct socket, so
			 * tcp_close() is used directly (see
			 * __mptcp_close_ssk())
			 */
			tcp_close(newsk, 0);
			return NULL;
		}

		mptcp_init_sock(new_mptcp_sock);

		msk = mptcp_sk(new_mptcp_sock);
		msk->remote_key = subflow->remote_key;
		msk->local_key = subflow->local_key;
		/* the clone must not share the listener's initial subflow */
		msk->subflow = NULL;

		newsk = new_mptcp_sock;
		mptcp_copy_inaddrs(newsk, ssk);
		list_add(&subflow->node, &msk->conn_list);

		/* will be fully established at mptcp_stream_accept()
		 * completion.
		 */
		inet_sk_state_store(new_mptcp_sock, TCP_SYN_RECV);
		bh_unlock_sock(new_mptcp_sock);
		local_bh_enable();
		release_sock(sk);
	}

	return newsk;
}

/* Delegate local port selection/binding to the initial subflow */
static int mptcp_get_port(struct sock *sk, unsigned short snum)
{
	struct mptcp_sock *msk = mptcp_sk(sk);
	struct socket *ssock;

	ssock = __mptcp_nmpc_socket(msk);
	pr_debug("msk=%p, subflow=%p", msk, ssock);
	if (WARN_ON_ONCE(!ssock))
		return -EINVAL;

	return inet_csk_get_port(ssock->sk, snum);
}

/* Called when the subflow handshake completed; propagate the negotiated
 * keys to the msk when MP_CAPABLE succeeded.
 */
void mptcp_finish_connect(struct sock *ssk)
{
	struct mptcp_subflow_context *subflow;
	struct mptcp_sock *msk;
	struct sock *sk;

	subflow = mptcp_subflow_ctx(ssk);

	if (!subflow->mp_capable)
		return;

	sk = subflow->conn;
	msk = mptcp_sk(sk);

	/* the socket is not connected yet, no msk/subflow ops can access/race
	 * accessing the field below
	 */
	WRITE_ONCE(msk->remote_key, subflow->remote_key);
	WRITE_ONCE(msk->local_key, subflow->local_key);
}

/* Attach @sk (a subflow sock) to @parent's wait queue and socket, so that
 * wakeups from the tcp stack reach the mptcp-level socket.
 */
static void mptcp_sock_graft(struct sock *sk, struct socket *parent)
{
	write_lock_bh(&sk->sk_callback_lock);
	rcu_assign_pointer(sk->sk_wq, &parent->wq);
	sk_set_socket(sk, parent);
	sk->sk_uid = SOCK_INODE(parent)->i_uid;
	write_unlock_bh(&sk->sk_callback_lock);
}

static struct proto mptcp_prot = {
	.name		= "MPTCP",
	.owner		= THIS_MODULE,
	.init		= mptcp_init_sock,
	.close		= mptcp_close,
	.accept		= mptcp_accept,
	.shutdown	= tcp_shutdown,
	.sendmsg	= mptcp_sendmsg,
	.recvmsg	= mptcp_recvmsg,
	.hash		= inet_hash,
	.unhash		= inet_unhash,
	.get_port	= mptcp_get_port,
	.obj_size	= sizeof(struct mptcp_sock),
	.no_autobind	= true,
};

/* bind() is forwarded to the initial subflow, created on demand; on
 * success the bound addresses are mirrored back into the msk
 */
static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
{
	struct mptcp_sock *msk = mptcp_sk(sock->sk);
	struct socket *ssock;
	int err;

	lock_sock(sock->sk);
	ssock = __mptcp_socket_create(msk, MPTCP_SAME_STATE);
	if (IS_ERR(ssock)) {
		err = PTR_ERR(ssock);
		goto unlock;
	}

	err = ssock->ops->bind(ssock, uaddr, addr_len);
	if (!err)
		mptcp_copy_inaddrs(sock->sk, ssock->sk);

unlock:
	release_sock(sock->sk);
	return err;
}

/* connect() via the initial subflow; the msk state and addresses track
 * the subflow's, regardless of the connect() outcome
 */
static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
				int addr_len, int flags)
{
	struct mptcp_sock *msk = mptcp_sk(sock->sk);
	struct socket *ssock;
	int err;

	lock_sock(sock->sk);
	ssock = __mptcp_socket_create(msk, TCP_SYN_SENT);
	if (IS_ERR(ssock)) {
		err = PTR_ERR(ssock);
		goto unlock;
	}

#ifdef CONFIG_TCP_MD5SIG
	/* no MPTCP if MD5SIG is enabled on this socket or we may run out of
	 * TCP option space.
	 */
	if (rcu_access_pointer(tcp_sk(ssock->sk)->md5sig_info))
		mptcp_subflow_ctx(ssock->sk)->request_mptcp = 0;
#endif

	err = ssock->ops->connect(ssock, uaddr, addr_len, flags);
	inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
	mptcp_copy_inaddrs(sock->sk, ssock->sk);

unlock:
	release_sock(sock->sk);
	return err;
}

static int mptcp_v4_getname(struct socket *sock, struct sockaddr *uaddr,
			    int peer)
{
	if (sock->sk->sk_prot == &tcp_prot) {
		/* we are being invoked from __sys_accept4, after
		 * mptcp_accept() has just accepted a non-mp-capable
		 * flow: sk is a tcp_sk, not an mptcp one.
		 *
		 * Hand the socket over to tcp so all further socket ops
		 * bypass mptcp.
		 */
		sock->ops = &inet_stream_ops;
	}

	return inet_getname(sock, uaddr, peer);
}

#if IS_ENABLED(CONFIG_MPTCP_IPV6)
static int mptcp_v6_getname(struct socket *sock, struct sockaddr *uaddr,
			    int peer)
{
	if (sock->sk->sk_prot == &tcpv6_prot) {
		/* we are being invoked from __sys_accept4 after
		 * mptcp_accept() has accepted a non-mp-capable
		 * subflow: sk is a tcp_sk, not mptcp.
		 *
		 * Hand the socket over to tcp so all further
		 * socket ops bypass mptcp.
		 */
		sock->ops = &inet6_stream_ops;
	}

	return inet6_getname(sock, uaddr, peer);
}
#endif

/* listen() via the initial subflow; the msk state tracks the subflow's */
static int mptcp_listen(struct socket *sock, int backlog)
{
	struct mptcp_sock *msk = mptcp_sk(sock->sk);
	struct socket *ssock;
	int err;

	pr_debug("msk=%p", msk);

	lock_sock(sock->sk);
	ssock = __mptcp_socket_create(msk, TCP_LISTEN);
	if (IS_ERR(ssock)) {
		err = PTR_ERR(ssock);
		goto unlock;
	}

	err = ssock->ops->listen(ssock, backlog);
	inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
	if (!err)
		mptcp_copy_inaddrs(sock->sk, ssock->sk);

unlock:
	release_sock(sock->sk);
	return err;
}

/* True if @p is the plain TCP proto (v4 or, when enabled, v6) */
static bool is_tcp_proto(const struct proto *p)
{
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	return p == &tcp_prot || p == &tcpv6_prot;
#else
	return p == &tcp_prot;
#endif
}

static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
			       int flags, bool kern)
{
	struct mptcp_sock *msk = mptcp_sk(sock->sk);
	struct socket *ssock;
	int err;

	pr_debug("msk=%p", msk);

	lock_sock(sock->sk);
	if (sock->sk->sk_state != TCP_LISTEN)
		goto unlock_fail;

	ssock = __mptcp_nmpc_socket(msk);
	if (!ssock)
		goto unlock_fail;

	/* keep the listener subflow alive across release_sock() */
	sock_hold(ssock->sk);
	release_sock(sock->sk);

	err = ssock->ops->accept(sock, newsock, flags, kern);
	if (err == 0 && !is_tcp_proto(newsock->sk->sk_prot)) {
		/* note: this shadows the outer msk; here it is the
		 * NEW msk cloned by mptcp_accept()
		 */
		struct mptcp_sock *msk = mptcp_sk(newsock->sk);
		struct mptcp_subflow_context *subflow;

		/* set ssk->sk_socket of accept()ed flows to mptcp socket.
		 * This is needed so NOSPACE flag can be set from tcp stack.
		 */
		list_for_each_entry(subflow, &msk->conn_list, node) {
			struct sock *ssk = mptcp_subflow_tcp_sock(subflow);

			if (!ssk->sk_socket)
				mptcp_sock_graft(ssk, newsock);
		}

		inet_sk_state_store(newsock->sk, TCP_ESTABLISHED);
	}

	sock_put(ssock->sk);
	return err;

unlock_fail:
	release_sock(sock->sk);
	return -EINVAL;
}

/* Placeholder poll implementation: never reports any event yet */
static __poll_t mptcp_poll(struct file *file, struct socket *sock,
			   struct poll_table_struct *wait)
{
	__poll_t mask = 0;

	return mask;
}

static struct proto_ops mptcp_stream_ops;

static struct inet_protosw mptcp_protosw = {
	.type		= SOCK_STREAM,
	.protocol	= IPPROTO_MPTCP,
	.prot		= &mptcp_prot,
	.ops		= &mptcp_stream_ops,
	.flags		= INET_PROTOSW_ICSK,
};

/* Register the MPTCP (v4) protocol; panics on failure since this runs
 * at boot and MPTCP cannot work without its proto registered
 */
void __init mptcp_init(void)
{
	/* share the TCP hash tables */
	mptcp_prot.h.hashinfo = tcp_prot.h.hashinfo;
	/* start from the plain inet stream ops, override mptcp entry points */
	mptcp_stream_ops = inet_stream_ops;
	mptcp_stream_ops.bind = mptcp_bind;
	mptcp_stream_ops.connect = mptcp_stream_connect;
	mptcp_stream_ops.poll = mptcp_poll;
	mptcp_stream_ops.accept = mptcp_stream_accept;
	mptcp_stream_ops.getname = mptcp_v4_getname;
	mptcp_stream_ops.listen = mptcp_listen;

	mptcp_subflow_init();

	if (proto_register(&mptcp_prot, 1) != 0)
		panic("Failed to register MPTCP proto.\n");

	inet_register_protosw(&mptcp_protosw);
}

#if IS_ENABLED(CONFIG_MPTCP_IPV6)
static struct proto_ops mptcp_v6_stream_ops;
static struct proto mptcp_v6_prot;

static struct inet_protosw mptcp_v6_protosw = {
	.type		= SOCK_STREAM,
	.protocol	= IPPROTO_MPTCP,
	.prot		= &mptcp_v6_prot,
	.ops		= &mptcp_v6_stream_ops,
	.flags		= INET_PROTOSW_ICSK,
};

/* Register the MPTCP v6 protocol; unlike mptcp_init() this can fail and
 * returns an error code
 */
int mptcpv6_init(void)
{
	int err;

	/* clone the v4 proto, with v6-sized per-sock storage */
	mptcp_v6_prot = mptcp_prot;
	strcpy(mptcp_v6_prot.name, "MPTCPv6");
	mptcp_v6_prot.slab = NULL;
	mptcp_v6_prot.obj_size = sizeof(struct mptcp_sock) +
				 sizeof(struct ipv6_pinfo);

	err = proto_register(&mptcp_v6_prot, 1);
	if (err)
		return err;

	mptcp_v6_stream_ops = inet6_stream_ops;
	mptcp_v6_stream_ops.bind = mptcp_bind;
	mptcp_v6_stream_ops.connect = mptcp_stream_connect;
	mptcp_v6_stream_ops.poll = mptcp_poll;
	mptcp_v6_stream_ops.accept = mptcp_stream_accept;
	mptcp_v6_stream_ops.getname = mptcp_v6_getname;
	mptcp_v6_stream_ops.listen = mptcp_listen;

	err = inet6_register_protosw(&mptcp_v6_protosw);
	if (err)
		proto_unregister(&mptcp_v6_prot);

	return err;
}
#endif