1 // SPDX-License-Identifier: GPL-2.0 2 /* Multipath TCP 3 * 4 * Copyright (c) 2017 - 2019, Intel Corporation. 5 */ 6 7 #define pr_fmt(fmt) "MPTCP: " fmt 8 9 #include <linux/kernel.h> 10 #include <linux/module.h> 11 #include <linux/netdevice.h> 12 #include <net/sock.h> 13 #include <net/inet_common.h> 14 #include <net/inet_hashtables.h> 15 #include <net/protocol.h> 16 #include <net/tcp.h> 17 #include <net/mptcp.h> 18 #include "protocol.h" 19 20 #define MPTCP_SAME_STATE TCP_MAX_STATES 21 22 /* If msk has an initial subflow socket, and the MP_CAPABLE handshake has not 23 * completed yet or has failed, return the subflow socket. 24 * Otherwise return NULL. 25 */ 26 static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk) 27 { 28 if (!msk->subflow || mptcp_subflow_ctx(msk->subflow->sk)->fourth_ack) 29 return NULL; 30 31 return msk->subflow; 32 } 33 34 /* if msk has a single subflow, and the mp_capable handshake is failed, 35 * return it. 36 * Otherwise returns NULL 37 */ 38 static struct socket *__mptcp_tcp_fallback(const struct mptcp_sock *msk) 39 { 40 struct socket *ssock = __mptcp_nmpc_socket(msk); 41 42 sock_owned_by_me((const struct sock *)msk); 43 44 if (!ssock || sk_is_mptcp(ssock->sk)) 45 return NULL; 46 47 return ssock; 48 } 49 50 static bool __mptcp_can_create_subflow(const struct mptcp_sock *msk) 51 { 52 return ((struct sock *)msk)->sk_state == TCP_CLOSE; 53 } 54 55 static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state) 56 { 57 struct mptcp_subflow_context *subflow; 58 struct sock *sk = (struct sock *)msk; 59 struct socket *ssock; 60 int err; 61 62 ssock = __mptcp_nmpc_socket(msk); 63 if (ssock) 64 goto set_state; 65 66 if (!__mptcp_can_create_subflow(msk)) 67 return ERR_PTR(-EINVAL); 68 69 err = mptcp_subflow_create_socket(sk, &ssock); 70 if (err) 71 return ERR_PTR(err); 72 73 msk->subflow = ssock; 74 subflow = mptcp_subflow_ctx(ssock->sk); 75 list_add(&subflow->node, &msk->conn_list); 76 subflow->request_mptcp = 1; 77 78 set_state: 79 if (state != MPTCP_SAME_STATE) 80 inet_sk_state_store(sk, state); 81 return ssock; 82 } 83 84 static struct sock *mptcp_subflow_get(const struct mptcp_sock *msk) 85 { 86 struct mptcp_subflow_context *subflow; 87 88 sock_owned_by_me((const struct sock *)msk); 89 90 mptcp_for_each_subflow(msk, subflow) { 91 return mptcp_subflow_tcp_sock(subflow); 92 } 93 94 return NULL; 95 } 96 97 static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len) 98 { 99 struct mptcp_sock *msk = mptcp_sk(sk); 100 struct socket *ssock; 101 struct sock *ssk; 102 int ret; 103 104 if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL)) 105 return -EOPNOTSUPP; 106 107 lock_sock(sk); 108 ssock = __mptcp_tcp_fallback(msk); 109 if (ssock) { 110 pr_debug("fallback passthrough"); 111 ret = sock_sendmsg(ssock, msg); 112 release_sock(sk); 113 return ret; 114 } 115 116 ssk = mptcp_subflow_get(msk); 117 if (!ssk) { 118 release_sock(sk); 119 return -ENOTCONN; 120 } 121 122 ret = sock_sendmsg(ssk->sk_socket, msg); 123 124 release_sock(sk); 125 return ret; 126 } 127 128 static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, 129 int nonblock, int flags, int *addr_len) 130 { 131 struct mptcp_sock *msk = mptcp_sk(sk); 132 struct socket *ssock; 133 struct sock *ssk; 134 int copied = 0; 135 136 if (msg->msg_flags & ~(MSG_WAITALL | MSG_DONTWAIT)) 137 return -EOPNOTSUPP; 138 139 lock_sock(sk); 140 ssock = __mptcp_tcp_fallback(msk); 141 if (ssock) { 142 pr_debug("fallback-read subflow=%p", 143 mptcp_subflow_ctx(ssock->sk)); 144 copied = sock_recvmsg(ssock, msg, flags); 145 release_sock(sk); 146 return copied; 147 } 148 149 ssk = mptcp_subflow_get(msk); 150 if (!ssk) { 151 release_sock(sk); 152 return -ENOTCONN; 153 } 154 155 copied = sock_recvmsg(ssk->sk_socket, msg, flags); 156 157 release_sock(sk); 158 159 return copied; 160 } 161 162 /* subflow sockets can be either outgoing (connect) or incoming 163 * (accept). 164 * 165 * Outgoing subflows use in-kernel sockets. 166 * Incoming subflows do not have their own 'struct socket' allocated, 167 * so we need to use tcp_close() after detaching them from the mptcp 168 * parent socket. 169 */ 170 static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk, 171 struct mptcp_subflow_context *subflow, 172 long timeout) 173 { 174 struct socket *sock = READ_ONCE(ssk->sk_socket); 175 176 list_del(&subflow->node); 177 178 if (sock && sock != sk->sk_socket) { 179 /* outgoing subflow */ 180 sock_release(sock); 181 } else { 182 /* incoming subflow */ 183 tcp_close(ssk, timeout); 184 } 185 } 186 187 static int mptcp_init_sock(struct sock *sk) 188 { 189 struct mptcp_sock *msk = mptcp_sk(sk); 190 191 INIT_LIST_HEAD(&msk->conn_list); 192 193 return 0; 194 } 195 196 static void mptcp_close(struct sock *sk, long timeout) 197 { 198 struct mptcp_subflow_context *subflow, *tmp; 199 struct mptcp_sock *msk = mptcp_sk(sk); 200 201 inet_sk_state_store(sk, TCP_CLOSE); 202 203 lock_sock(sk); 204 205 list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) { 206 struct sock *ssk = mptcp_subflow_tcp_sock(subflow); 207 208 __mptcp_close_ssk(sk, ssk, subflow, timeout); 209 } 210 211 release_sock(sk); 212 sk_common_release(sk); 213 } 214 215 static int mptcp_get_port(struct sock *sk, unsigned short snum) 216 { 217 struct mptcp_sock *msk = mptcp_sk(sk); 218 struct socket *ssock; 219 220 ssock = __mptcp_nmpc_socket(msk); 221 pr_debug("msk=%p, subflow=%p", msk, ssock); 222 if (WARN_ON_ONCE(!ssock)) 223 return -EINVAL; 224 225 return inet_csk_get_port(ssock->sk, snum); 226 } 227 228 void mptcp_finish_connect(struct sock *ssk) 229 { 230 struct mptcp_subflow_context *subflow; 231 struct mptcp_sock *msk; 232 struct sock *sk; 233 234 subflow = mptcp_subflow_ctx(ssk); 235 236 if (!subflow->mp_capable) 237 return; 238 239 sk = subflow->conn; 240 msk = mptcp_sk(sk); 241 242 /* the socket is not connected yet, no msk/subflow ops can access/race 243 * accessing the field below 244 */ 245 WRITE_ONCE(msk->remote_key, subflow->remote_key); 246 WRITE_ONCE(msk->local_key, subflow->local_key); 247 } 248 249 static struct proto mptcp_prot = { 250 .name = "MPTCP", 251 .owner = THIS_MODULE, 252 .init = mptcp_init_sock, 253 .close = mptcp_close, 254 .accept = inet_csk_accept, 255 .shutdown = tcp_shutdown, 256 .sendmsg = mptcp_sendmsg, 257 .recvmsg = mptcp_recvmsg, 258 .hash = inet_hash, 259 .unhash = inet_unhash, 260 .get_port = mptcp_get_port, 261 .obj_size = sizeof(struct mptcp_sock), 262 .no_autobind = true, 263 }; 264 265 static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len) 266 { 267 struct mptcp_sock *msk = mptcp_sk(sock->sk); 268 struct socket *ssock; 269 int err = -ENOTSUPP; 270 271 if (uaddr->sa_family != AF_INET) // @@ allow only IPv4 for now 272 return err; 273 274 lock_sock(sock->sk); 275 ssock = __mptcp_socket_create(msk, MPTCP_SAME_STATE); 276 if (IS_ERR(ssock)) { 277 err = PTR_ERR(ssock); 278 goto unlock; 279 } 280 281 err = ssock->ops->bind(ssock, uaddr, addr_len); 282 283 unlock: 284 release_sock(sock->sk); 285 return err; 286 } 287 288 static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr, 289 int addr_len, int flags) 290 { 291 struct mptcp_sock *msk = mptcp_sk(sock->sk); 292 struct socket *ssock; 293 int err; 294 295 lock_sock(sock->sk); 296 ssock = __mptcp_socket_create(msk, TCP_SYN_SENT); 297 if (IS_ERR(ssock)) { 298 err = PTR_ERR(ssock); 299 goto unlock; 300 } 301 302 err = ssock->ops->connect(ssock, uaddr, addr_len, flags); 303 inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk)); 304 305 unlock: 306 release_sock(sock->sk); 307 return err; 308 } 309 310 static __poll_t mptcp_poll(struct file *file, struct socket *sock, 311 struct poll_table_struct *wait) 312 { 313 __poll_t mask = 0; 314 315 return mask; 316 } 317 318 static struct proto_ops mptcp_stream_ops; 319 320 static struct inet_protosw mptcp_protosw = { 321 .type = SOCK_STREAM, 322 .protocol = IPPROTO_MPTCP, 323 .prot = &mptcp_prot, 324 .ops = &mptcp_stream_ops, 325 .flags = INET_PROTOSW_ICSK, 326 }; 327 328 void __init mptcp_init(void) 329 { 330 mptcp_prot.h.hashinfo = tcp_prot.h.hashinfo; 331 mptcp_stream_ops = inet_stream_ops; 332 mptcp_stream_ops.bind = mptcp_bind; 333 mptcp_stream_ops.connect = mptcp_stream_connect; 334 mptcp_stream_ops.poll = mptcp_poll; 335 336 mptcp_subflow_init(); 337 338 if (proto_register(&mptcp_prot, 1) != 0) 339 panic("Failed to register MPTCP proto.\n"); 340 341 inet_register_protosw(&mptcp_protosw); 342 } 343 344 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 345 static struct proto_ops mptcp_v6_stream_ops; 346 static struct proto mptcp_v6_prot; 347 348 static struct inet_protosw mptcp_v6_protosw = { 349 .type = SOCK_STREAM, 350 .protocol = IPPROTO_MPTCP, 351 .prot = &mptcp_v6_prot, 352 .ops = &mptcp_v6_stream_ops, 353 .flags = INET_PROTOSW_ICSK, 354 }; 355 356 int mptcpv6_init(void) 357 { 358 int err; 359 360 mptcp_v6_prot = mptcp_prot; 361 strcpy(mptcp_v6_prot.name, "MPTCPv6"); 362 mptcp_v6_prot.slab = NULL; 363 mptcp_v6_prot.obj_size = sizeof(struct mptcp_sock) + 364 sizeof(struct ipv6_pinfo); 365 366 err = proto_register(&mptcp_v6_prot, 1); 367 if (err) 368 return err; 369 370 mptcp_v6_stream_ops = inet6_stream_ops; 371 mptcp_v6_stream_ops.bind = mptcp_bind; 372 mptcp_v6_stream_ops.connect = mptcp_stream_connect; 373 mptcp_v6_stream_ops.poll = mptcp_poll; 374 375 err = inet6_register_protosw(&mptcp_v6_protosw); 376 if (err) 377 proto_unregister(&mptcp_v6_prot); 378 379 return err; 380 } 381 #endif 382