xref: /openbmc/linux/net/mptcp/protocol.c (revision cf7da0d66cc1a2a19fc5930bb746ffbb2d4cd1be)
// SPDX-License-Identifier: GPL-2.0
/* Multipath TCP
 *
 * Copyright (c) 2017 - 2019, Intel Corporation.
 */

#define pr_fmt(fmt) "MPTCP: " fmt

#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/netdevice.h>
#include <net/sock.h>
#include <net/inet_common.h>
#include <net/inet_hashtables.h>
#include <net/protocol.h>
#include <net/tcp.h>
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
#include <net/transp_v6.h>
#endif
#include <net/mptcp.h>
#include "protocol.h"

#define MPTCP_SAME_STATE TCP_MAX_STATES

/* If msk has an initial subflow socket, and the MP_CAPABLE handshake has not
 * completed yet or has failed, return the subflow socket.
 * Otherwise return NULL.
 */
static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
{
	if (!msk->subflow || mptcp_subflow_ctx(msk->subflow->sk)->fourth_ack)
		return NULL;

	return msk->subflow;
}

/* If msk has a single subflow, and the MP_CAPABLE handshake has failed,
 * return it.
 * Otherwise return NULL.
 */
static struct socket *__mptcp_tcp_fallback(const struct mptcp_sock *msk)
{
	struct socket *ssock = __mptcp_nmpc_socket(msk);

	sock_owned_by_me((const struct sock *)msk);

	if (!ssock || sk_is_mptcp(ssock->sk))
		return NULL;

	return ssock;
}

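/* A new subflow socket can be created only while the msk is still closed. */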
static bool __mptcp_can_create_subflow(const struct mptcp_sock *msk)
{
	return ((struct sock *)msk)->sk_state == TCP_CLOSE;
}

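/* Return the initial subflow socket, creating it on demand, and optionally
 * move the msk to the given state; MPTCP_SAME_STATE leaves the state untouched.
 */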
static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state)
{
	struct mptcp_subflow_context *subflow;
	struct sock *sk = (struct sock *)msk;
	struct socket *ssock;
	int err;

	ssock = __mptcp_nmpc_socket(msk);
	if (ssock)
		goto set_state;

	if (!__mptcp_can_create_subflow(msk))
		return ERR_PTR(-EINVAL);

	err = mptcp_subflow_create_socket(sk, &ssock);
	if (err)
		return ERR_PTR(err);

	msk->subflow = ssock;
	subflow = mptcp_subflow_ctx(ssock->sk);
	list_add(&subflow->node, &msk->conn_list);
	subflow->request_mptcp = 1;

set_state:
	if (state != MPTCP_SAME_STATE)
		inet_sk_state_store(sk, state);
	return ssock;
}

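/* Pick the subflow to transmit on or receive from: for now, simply the first
 * entry on the conn_list.
 */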
static struct sock *mptcp_subflow_get(const struct mptcp_sock *msk)
{
	struct mptcp_subflow_context *subflow;

	sock_owned_by_me((const struct sock *)msk);

	mptcp_for_each_subflow(msk, subflow) {
		return mptcp_subflow_tcp_sock(subflow);
	}

	return NULL;
}

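/* Forward sendmsg() to the single subflow, or straight to the TCP socket when
 * the MP_CAPABLE handshake did not succeed.
 */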
static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
{
	struct mptcp_sock *msk = mptcp_sk(sk);
	struct socket *ssock;
	struct sock *ssk;
	int ret;

	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
		return -EOPNOTSUPP;

	lock_sock(sk);
	ssock = __mptcp_tcp_fallback(msk);
	if (ssock) {
		pr_debug("fallback passthrough");
		ret = sock_sendmsg(ssock, msg);
		release_sock(sk);
		return ret;
	}

	ssk = mptcp_subflow_get(msk);
	if (!ssk) {
		release_sock(sk);
		return -ENOTCONN;
	}

	ret = sock_sendmsg(ssk->sk_socket, msg);

	release_sock(sk);
	return ret;
}

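/* As above: receive from the single subflow, falling back to the plain TCP
 * socket when MPTCP was not negotiated.
 */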
static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			 int nonblock, int flags, int *addr_len)
{
	struct mptcp_sock *msk = mptcp_sk(sk);
	struct socket *ssock;
	struct sock *ssk;
	int copied = 0;

	if (msg->msg_flags & ~(MSG_WAITALL | MSG_DONTWAIT))
		return -EOPNOTSUPP;

	lock_sock(sk);
	ssock = __mptcp_tcp_fallback(msk);
	if (ssock) {
		pr_debug("fallback-read subflow=%p",
			 mptcp_subflow_ctx(ssock->sk));
		copied = sock_recvmsg(ssock, msg, flags);
		release_sock(sk);
		return copied;
	}

	ssk = mptcp_subflow_get(msk);
	if (!ssk) {
		release_sock(sk);
		return -ENOTCONN;
	}

	copied = sock_recvmsg(ssk->sk_socket, msg, flags);

	release_sock(sk);

	return copied;
}

/* subflow sockets can be either outgoing (connect) or incoming
 * (accept).
 *
 * Outgoing subflows use in-kernel sockets.
 * Incoming subflows do not have their own 'struct socket' allocated,
 * so we need to use tcp_close() after detaching them from the mptcp
 * parent socket.
 */
static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
			      struct mptcp_subflow_context *subflow,
			      long timeout)
{
	struct socket *sock = READ_ONCE(ssk->sk_socket);

	list_del(&subflow->node);

	if (sock && sock != sk->sk_socket) {
		/* outgoing subflow */
		sock_release(sock);
	} else {
		/* incoming subflow */
		tcp_close(ssk, timeout);
	}
}

static int mptcp_init_sock(struct sock *sk)
{
	struct mptcp_sock *msk = mptcp_sk(sk);

	INIT_LIST_HEAD(&msk->conn_list);

	return 0;
}

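/* Tear down every subflow, then release the msk itself. */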
static void mptcp_close(struct sock *sk, long timeout)
{
	struct mptcp_subflow_context *subflow, *tmp;
	struct mptcp_sock *msk = mptcp_sk(sk);

	inet_sk_state_store(sk, TCP_CLOSE);

	lock_sock(sk);

	list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);

		__mptcp_close_ssk(sk, ssk, subflow, timeout);
	}

	release_sock(sk);
	sk_common_release(sk);
}

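/* Copy the local and remote addresses and ports from the subflow socket to
 * the mptcp-level socket.
 */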
static void mptcp_copy_inaddrs(struct sock *msk, const struct sock *ssk)
{
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	const struct ipv6_pinfo *ssk6 = inet6_sk(ssk);
	struct ipv6_pinfo *msk6 = inet6_sk(msk);

	msk->sk_v6_daddr = ssk->sk_v6_daddr;
	msk->sk_v6_rcv_saddr = ssk->sk_v6_rcv_saddr;

	if (msk6 && ssk6) {
		msk6->saddr = ssk6->saddr;
		msk6->flow_label = ssk6->flow_label;
	}
#endif

	inet_sk(msk)->inet_num = inet_sk(ssk)->inet_num;
	inet_sk(msk)->inet_dport = inet_sk(ssk)->inet_dport;
	inet_sk(msk)->inet_sport = inet_sk(ssk)->inet_sport;
	inet_sk(msk)->inet_daddr = inet_sk(ssk)->inet_daddr;
	inet_sk(msk)->inet_saddr = inet_sk(ssk)->inet_saddr;
	inet_sk(msk)->inet_rcv_saddr = inet_sk(ssk)->inet_rcv_saddr;
}

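/* Accept a connection on the initial subflow's listener. If the peer
 * negotiated MPTCP, clone the listening msk into a new mptcp socket owning the
 * subflow; otherwise hand back the plain TCP socket.
 */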
static struct sock *mptcp_accept(struct sock *sk, int flags, int *err,
				 bool kern)
{
	struct mptcp_sock *msk = mptcp_sk(sk);
	struct socket *listener;
	struct sock *newsk;

	listener = __mptcp_nmpc_socket(msk);
	if (WARN_ON_ONCE(!listener)) {
		*err = -EINVAL;
		return NULL;
	}

	pr_debug("msk=%p, listener=%p", msk, mptcp_subflow_ctx(listener->sk));
	newsk = inet_csk_accept(listener->sk, flags, err, kern);
	if (!newsk)
		return NULL;

	pr_debug("msk=%p, subflow is mptcp=%d", msk, sk_is_mptcp(newsk));

	if (sk_is_mptcp(newsk)) {
		struct mptcp_subflow_context *subflow;
		struct sock *new_mptcp_sock;
		struct sock *ssk = newsk;

		subflow = mptcp_subflow_ctx(newsk);
		lock_sock(sk);

		local_bh_disable();
		new_mptcp_sock = sk_clone_lock(sk, GFP_ATOMIC);
		if (!new_mptcp_sock) {
			*err = -ENOBUFS;
			local_bh_enable();
			release_sock(sk);
			tcp_close(newsk, 0);
			return NULL;
		}

		mptcp_init_sock(new_mptcp_sock);

		msk = mptcp_sk(new_mptcp_sock);
		msk->remote_key = subflow->remote_key;
		msk->local_key = subflow->local_key;
		msk->subflow = NULL;

		newsk = new_mptcp_sock;
		mptcp_copy_inaddrs(newsk, ssk);
		list_add(&subflow->node, &msk->conn_list);

		/* will be fully established at mptcp_stream_accept()
		 * completion.
		 */
		inet_sk_state_store(new_mptcp_sock, TCP_SYN_RECV);
		bh_unlock_sock(new_mptcp_sock);
		local_bh_enable();
		release_sock(sk);
	}

	return newsk;
}

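/* Port allocation is delegated to the initial subflow socket. */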
static int mptcp_get_port(struct sock *sk, unsigned short snum)
{
	struct mptcp_sock *msk = mptcp_sk(sk);
	struct socket *ssock;

	ssock = __mptcp_nmpc_socket(msk);
	pr_debug("msk=%p, subflow=%p", msk, ssock);
	if (WARN_ON_ONCE(!ssock))
		return -EINVAL;

	return inet_csk_get_port(ssock->sk, snum);
}

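/* Propagate the keys negotiated via the MP_CAPABLE handshake from the subflow
 * context to the msk.
 */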
void mptcp_finish_connect(struct sock *ssk)
{
	struct mptcp_subflow_context *subflow;
	struct mptcp_sock *msk;
	struct sock *sk;

	subflow = mptcp_subflow_ctx(ssk);

	if (!subflow->mp_capable)
		return;

	sk = subflow->conn;
	msk = mptcp_sk(sk);

	/* The socket is not connected yet, so no msk/subflow operation can race
	 * with us while writing the fields below.
	 */
	WRITE_ONCE(msk->remote_key, subflow->remote_key);
	WRITE_ONCE(msk->local_key, subflow->local_key);
}

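/* Attach the subflow sock to the mptcp-level struct socket: share its wait
 * queue and take over its owner and uid.
 */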
static void mptcp_sock_graft(struct sock *sk, struct socket *parent)
{
	write_lock_bh(&sk->sk_callback_lock);
	rcu_assign_pointer(sk->sk_wq, &parent->wq);
	sk_set_socket(sk, parent);
	sk->sk_uid = SOCK_INODE(parent)->i_uid;
	write_unlock_bh(&sk->sk_callback_lock);
}

static struct proto mptcp_prot = {
	.name		= "MPTCP",
	.owner		= THIS_MODULE,
	.init		= mptcp_init_sock,
	.close		= mptcp_close,
	.accept		= mptcp_accept,
	.shutdown	= tcp_shutdown,
	.sendmsg	= mptcp_sendmsg,
	.recvmsg	= mptcp_recvmsg,
	.hash		= inet_hash,
	.unhash		= inet_unhash,
	.get_port	= mptcp_get_port,
	.obj_size	= sizeof(struct mptcp_sock),
	.no_autobind	= true,
};

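/* bind() is delegated to the initial subflow socket, created on demand. */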
static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
{
	struct mptcp_sock *msk = mptcp_sk(sock->sk);
	struct socket *ssock;
	int err;

	lock_sock(sock->sk);
	ssock = __mptcp_socket_create(msk, MPTCP_SAME_STATE);
	if (IS_ERR(ssock)) {
		err = PTR_ERR(ssock);
		goto unlock;
	}

	err = ssock->ops->bind(ssock, uaddr, addr_len);
	if (!err)
		mptcp_copy_inaddrs(sock->sk, ssock->sk);

unlock:
	release_sock(sock->sk);
	return err;
}

static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
				int addr_len, int flags)
{
	struct mptcp_sock *msk = mptcp_sk(sock->sk);
	struct socket *ssock;
	int err;

	lock_sock(sock->sk);
	ssock = __mptcp_socket_create(msk, TCP_SYN_SENT);
	if (IS_ERR(ssock)) {
		err = PTR_ERR(ssock);
		goto unlock;
	}

#ifdef CONFIG_TCP_MD5SIG
	/* Don't request MPTCP if MD5SIG is enabled on this socket: combining
	 * the two may run out of TCP option space.
	 */
	if (rcu_access_pointer(tcp_sk(ssock->sk)->md5sig_info))
		mptcp_subflow_ctx(ssock->sk)->request_mptcp = 0;
#endif

	err = ssock->ops->connect(ssock, uaddr, addr_len, flags);
	inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
	mptcp_copy_inaddrs(sock->sk, ssock->sk);

unlock:
	release_sock(sock->sk);
	return err;
}

static int mptcp_v4_getname(struct socket *sock, struct sockaddr *uaddr,
			    int peer)
{
	if (sock->sk->sk_prot == &tcp_prot) {
		/* we are being invoked from __sys_accept4, after
		 * mptcp_accept() has just accepted a non-mp-capable
		 * flow: sk is a tcp_sk, not an mptcp one.
		 *
		 * Hand the socket over to tcp so all further socket ops
		 * bypass mptcp.
		 */
		sock->ops = &inet_stream_ops;
	}

	return inet_getname(sock, uaddr, peer);
}

#if IS_ENABLED(CONFIG_MPTCP_IPV6)
static int mptcp_v6_getname(struct socket *sock, struct sockaddr *uaddr,
			    int peer)
{
	if (sock->sk->sk_prot == &tcpv6_prot) {
		/* we are being invoked from __sys_accept4 after
		 * mptcp_accept() has accepted a non-mp-capable
		 * subflow: sk is a tcp_sk, not mptcp.
		 *
		 * Hand the socket over to tcp so all further
		 * socket ops bypass mptcp.
		 */
		sock->ops = &inet6_stream_ops;
	}

	return inet6_getname(sock, uaddr, peer);
}
#endif

static int mptcp_listen(struct socket *sock, int backlog)
{
	struct mptcp_sock *msk = mptcp_sk(sock->sk);
	struct socket *ssock;
	int err;

	pr_debug("msk=%p", msk);

	lock_sock(sock->sk);
	ssock = __mptcp_socket_create(msk, TCP_LISTEN);
	if (IS_ERR(ssock)) {
		err = PTR_ERR(ssock);
		goto unlock;
	}

	err = ssock->ops->listen(ssock, backlog);
	inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));
	if (!err)
		mptcp_copy_inaddrs(sock->sk, ssock->sk);

unlock:
	release_sock(sock->sk);
	return err;
}

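/* Tell whether a socket ended up using the plain TCP proto, i.e. the accepted
 * connection did not negotiate MPTCP.
 */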
static bool is_tcp_proto(const struct proto *p)
{
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	return p == &tcp_prot || p == &tcpv6_prot;
#else
	return p == &tcp_prot;
#endif
}

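/* Accept on the initial subflow's listener; when the new connection is
 * MP_CAPABLE, graft its subflows onto the freshly created mptcp socket and
 * mark it established.
 */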
static int mptcp_stream_accept(struct socket *sock, struct socket *newsock,
			       int flags, bool kern)
{
	struct mptcp_sock *msk = mptcp_sk(sock->sk);
	struct socket *ssock;
	int err;

	pr_debug("msk=%p", msk);

	lock_sock(sock->sk);
	if (sock->sk->sk_state != TCP_LISTEN)
		goto unlock_fail;

	ssock = __mptcp_nmpc_socket(msk);
	if (!ssock)
		goto unlock_fail;

	sock_hold(ssock->sk);
	release_sock(sock->sk);

	err = ssock->ops->accept(sock, newsock, flags, kern);
	if (err == 0 && !is_tcp_proto(newsock->sk->sk_prot)) {
		struct mptcp_sock *msk = mptcp_sk(newsock->sk);
		struct mptcp_subflow_context *subflow;

		/* Set ssk->sk_socket of accept()ed flows to the mptcp socket.
		 * This is needed so the NOSPACE flag can be set from the tcp stack.
		 */
		list_for_each_entry(subflow, &msk->conn_list, node) {
			struct sock *ssk = mptcp_subflow_tcp_sock(subflow);

			if (!ssk->sk_socket)
				mptcp_sock_graft(ssk, newsock);
		}

		inet_sk_state_store(newsock->sk, TCP_ESTABLISHED);
	}

	sock_put(ssock->sk);
	return err;

unlock_fail:
	release_sock(sock->sk);
	return -EINVAL;
}

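/* Poll support is not implemented yet: always report no events. */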
static __poll_t mptcp_poll(struct file *file, struct socket *sock,
			   struct poll_table_struct *wait)
{
	__poll_t mask = 0;

	return mask;
}

static struct proto_ops mptcp_stream_ops;

static struct inet_protosw mptcp_protosw = {
	.type		= SOCK_STREAM,
	.protocol	= IPPROTO_MPTCP,
	.prot		= &mptcp_prot,
	.ops		= &mptcp_stream_ops,
	.flags		= INET_PROTOSW_ICSK,
};

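/* Register the MPTCP protocol: reuse the TCP hash tables and build the stream
 * ops from inet_stream_ops, overriding the MPTCP-specific entry points.
 */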
void __init mptcp_init(void)
{
	mptcp_prot.h.hashinfo = tcp_prot.h.hashinfo;
	mptcp_stream_ops = inet_stream_ops;
	mptcp_stream_ops.bind = mptcp_bind;
	mptcp_stream_ops.connect = mptcp_stream_connect;
	mptcp_stream_ops.poll = mptcp_poll;
	mptcp_stream_ops.accept = mptcp_stream_accept;
	mptcp_stream_ops.getname = mptcp_v4_getname;
	mptcp_stream_ops.listen = mptcp_listen;

	mptcp_subflow_init();

	if (proto_register(&mptcp_prot, 1) != 0)
		panic("Failed to register MPTCP proto.\n");

	inet_register_protosw(&mptcp_protosw);
}

#if IS_ENABLED(CONFIG_MPTCP_IPV6)
static struct proto_ops mptcp_v6_stream_ops;
static struct proto mptcp_v6_prot;

static struct inet_protosw mptcp_v6_protosw = {
	.type		= SOCK_STREAM,
	.protocol	= IPPROTO_MPTCP,
	.prot		= &mptcp_v6_prot,
	.ops		= &mptcp_v6_stream_ops,
	.flags		= INET_PROTOSW_ICSK,
};

int mptcpv6_init(void)
{
	int err;

	mptcp_v6_prot = mptcp_prot;
	strcpy(mptcp_v6_prot.name, "MPTCPv6");
	mptcp_v6_prot.slab = NULL;
	mptcp_v6_prot.obj_size = sizeof(struct mptcp_sock) +
				 sizeof(struct ipv6_pinfo);

	err = proto_register(&mptcp_v6_prot, 1);
	if (err)
		return err;

	mptcp_v6_stream_ops = inet6_stream_ops;
	mptcp_v6_stream_ops.bind = mptcp_bind;
	mptcp_v6_stream_ops.connect = mptcp_stream_connect;
	mptcp_v6_stream_ops.poll = mptcp_poll;
	mptcp_v6_stream_ops.accept = mptcp_stream_accept;
	mptcp_v6_stream_ops.getname = mptcp_v6_getname;
	mptcp_v6_stream_ops.listen = mptcp_listen;

	err = inet6_register_protosw(&mptcp_v6_protosw);
	if (err)
		proto_unregister(&mptcp_v6_prot);

	return err;
}
#endif