xref: /openbmc/linux/net/mptcp/protocol.c (revision cec37a6e41aae7bf3df9a3da783380a4d9325fd8)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Multipath TCP
3  *
4  * Copyright (c) 2017 - 2019, Intel Corporation.
5  */
6 
7 #define pr_fmt(fmt) "MPTCP: " fmt
8 
9 #include <linux/kernel.h>
10 #include <linux/module.h>
11 #include <linux/netdevice.h>
12 #include <net/sock.h>
13 #include <net/inet_common.h>
14 #include <net/inet_hashtables.h>
15 #include <net/protocol.h>
16 #include <net/tcp.h>
17 #include <net/mptcp.h>
18 #include "protocol.h"
19 
20 #define MPTCP_SAME_STATE TCP_MAX_STATES
21 
22 /* If msk has an initial subflow socket, and the MP_CAPABLE handshake has not
23  * completed yet or has failed, return the subflow socket.
24  * Otherwise return NULL.
25  */
26 static struct socket *__mptcp_nmpc_socket(const struct mptcp_sock *msk)
27 {
28 	if (!msk->subflow || mptcp_subflow_ctx(msk->subflow->sk)->fourth_ack)
29 		return NULL;
30 
31 	return msk->subflow;
32 }
33 
34 /* if msk has a single subflow, and the mp_capable handshake is failed,
35  * return it.
36  * Otherwise returns NULL
37  */
38 static struct socket *__mptcp_tcp_fallback(const struct mptcp_sock *msk)
39 {
40 	struct socket *ssock = __mptcp_nmpc_socket(msk);
41 
42 	sock_owned_by_me((const struct sock *)msk);
43 
44 	if (!ssock || sk_is_mptcp(ssock->sk))
45 		return NULL;
46 
47 	return ssock;
48 }
49 
50 static bool __mptcp_can_create_subflow(const struct mptcp_sock *msk)
51 {
52 	return ((struct sock *)msk)->sk_state == TCP_CLOSE;
53 }
54 
/* Return the (possibly newly created) initial subflow socket of @msk,
 * optionally moving the MPTCP socket to @state.
 *
 * Reuses the existing pre-handshake subflow if there is one; otherwise
 * creates a fresh in-kernel subflow socket, links its context into
 * msk->conn_list and flags it to request MP_CAPABLE on connect.
 * Returns an ERR_PTR() on failure.
 * NOTE(review): callers hold the msk socket lock — confirm against callers.
 */
static struct socket *__mptcp_socket_create(struct mptcp_sock *msk, int state)
{
	struct mptcp_subflow_context *subflow;
	struct sock *sk = (struct sock *)msk;
	struct socket *ssock;
	int err;

	/* MP_CAPABLE handshake still pending/failed: reuse that subflow */
	ssock = __mptcp_nmpc_socket(msk);
	if (ssock)
		goto set_state;

	/* only a closed MPTCP socket may grow a new initial subflow */
	if (!__mptcp_can_create_subflow(msk))
		return ERR_PTR(-EINVAL);

	err = mptcp_subflow_create_socket(sk, &ssock);
	if (err)
		return ERR_PTR(err);

	msk->subflow = ssock;
	subflow = mptcp_subflow_ctx(ssock->sk);
	list_add(&subflow->node, &msk->conn_list);
	/* have the TCP handshake negotiate MP_CAPABLE for this subflow */
	subflow->request_mptcp = 1;

set_state:
	if (state != MPTCP_SAME_STATE)
		inet_sk_state_store(sk, state);
	return ssock;
}
83 
/* Pick the subflow to transmit/receive on: currently simply the first
 * entry on the connection list, or NULL when no subflow exists.
 * Requires the msk socket lock (asserted via sock_owned_by_me()).
 */
static struct sock *mptcp_subflow_get(const struct mptcp_sock *msk)
{
	struct mptcp_subflow_context *subflow;

	sock_owned_by_me((const struct sock *)msk);

	/* loop body returns unconditionally: this yields the first subflow */
	mptcp_for_each_subflow(msk, subflow) {
		return mptcp_subflow_tcp_sock(subflow);
	}

	return NULL;
}
96 
/* Send @msg on MPTCP socket @sk.
 *
 * When the MP_CAPABLE handshake failed, the message is passed straight
 * through to the plain-TCP fallback subflow; otherwise it is forwarded
 * to the single active subflow.  Returns the sock_sendmsg() result,
 * -ENOTCONN when no subflow is available, or -EOPNOTSUPP for flags
 * this implementation does not support yet.
 */
static int mptcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t len)
{
	struct mptcp_sock *msk = mptcp_sk(sk);
	struct socket *ssock;
	struct sock *ssk;
	int ret;

	/* only a small subset of send flags is handled so far */
	if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
		return -EOPNOTSUPP;

	lock_sock(sk);
	ssock = __mptcp_tcp_fallback(msk);
	if (ssock) {
		pr_debug("fallback passthrough");
		ret = sock_sendmsg(ssock, msg);
		release_sock(sk);
		return ret;
	}

	ssk = mptcp_subflow_get(msk);
	if (!ssk) {
		release_sock(sk);
		return -ENOTCONN;
	}

	ret = sock_sendmsg(ssk->sk_socket, msg);

	release_sock(sk);
	return ret;
}
127 
/* Receive into @msg from MPTCP socket @sk.
 *
 * Mirrors mptcp_sendmsg(): reads from the plain-TCP fallback subflow
 * when the MP_CAPABLE handshake failed, otherwise from the single
 * active subflow.  Returns bytes copied, -ENOTCONN when no subflow is
 * available, or -EOPNOTSUPP for unsupported flags.
 */
static int mptcp_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			 int nonblock, int flags, int *addr_len)
{
	struct mptcp_sock *msk = mptcp_sk(sk);
	struct socket *ssock;
	struct sock *ssk;
	int copied = 0;

	/* only a small subset of recv flags is handled so far */
	if (msg->msg_flags & ~(MSG_WAITALL | MSG_DONTWAIT))
		return -EOPNOTSUPP;

	lock_sock(sk);
	ssock = __mptcp_tcp_fallback(msk);
	if (ssock) {
		pr_debug("fallback-read subflow=%p",
			 mptcp_subflow_ctx(ssock->sk));
		copied = sock_recvmsg(ssock, msg, flags);
		release_sock(sk);
		return copied;
	}

	ssk = mptcp_subflow_get(msk);
	if (!ssk) {
		release_sock(sk);
		return -ENOTCONN;
	}

	copied = sock_recvmsg(ssk->sk_socket, msg, flags);

	release_sock(sk);

	return copied;
}
161 
162 /* subflow sockets can be either outgoing (connect) or incoming
163  * (accept).
164  *
165  * Outgoing subflows use in-kernel sockets.
166  * Incoming subflows do not have their own 'struct socket' allocated,
167  * so we need to use tcp_close() after detaching them from the mptcp
168  * parent socket.
169  */
static void __mptcp_close_ssk(struct sock *sk, struct sock *ssk,
			      struct mptcp_subflow_context *subflow,
			      long timeout)
{
	struct socket *sock = READ_ONCE(ssk->sk_socket);

	/* unlink from msk->conn_list before releasing the subflow */
	list_del(&subflow->node);

	if (sock && sock != sk->sk_socket) {
		/* outgoing subflow: owns its own 'struct socket' */
		sock_release(sock);
	} else {
		/* incoming subflow: no dedicated 'struct socket', close
		 * the tcp sock directly
		 */
		tcp_close(ssk, timeout);
	}
}
186 
187 static int mptcp_init_sock(struct sock *sk)
188 {
189 	struct mptcp_sock *msk = mptcp_sk(sk);
190 
191 	INIT_LIST_HEAD(&msk->conn_list);
192 
193 	return 0;
194 }
195 
/* Close the MPTCP socket: tear down every subflow, then release @sk. */
static void mptcp_close(struct sock *sk, long timeout)
{
	struct mptcp_subflow_context *subflow, *tmp;
	struct mptcp_sock *msk = mptcp_sk(sk);

	inet_sk_state_store(sk, TCP_CLOSE);

	lock_sock(sk);

	/* _safe variant required: __mptcp_close_ssk() list_del()s each entry */
	list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);

		__mptcp_close_ssk(sk, ssk, subflow, timeout);
	}

	release_sock(sk);
	sk_common_release(sk);
}
214 
215 static int mptcp_get_port(struct sock *sk, unsigned short snum)
216 {
217 	struct mptcp_sock *msk = mptcp_sk(sk);
218 	struct socket *ssock;
219 
220 	ssock = __mptcp_nmpc_socket(msk);
221 	pr_debug("msk=%p, subflow=%p", msk, ssock);
222 	if (WARN_ON_ONCE(!ssock))
223 		return -EINVAL;
224 
225 	return inet_csk_get_port(ssock->sk, snum);
226 }
227 
/* Called when subflow @ssk finishes its connect: if MP_CAPABLE was
 * successfully negotiated, copy the handshake keys from the subflow
 * context into the parent MPTCP socket.
 */
void mptcp_finish_connect(struct sock *ssk)
{
	struct mptcp_subflow_context *subflow;
	struct mptcp_sock *msk;
	struct sock *sk;

	subflow = mptcp_subflow_ctx(ssk);

	/* handshake fell back to plain TCP: nothing to propagate */
	if (!subflow->mp_capable)
		return;

	sk = subflow->conn;
	msk = mptcp_sk(sk);

	/* the socket is not connected yet, no msk/subflow ops can access/race
	 * accessing the field below
	 */
	WRITE_ONCE(msk->remote_key, subflow->remote_key);
	WRITE_ONCE(msk->local_key, subflow->local_key);
}
248 
/* IPv4 MPTCP protocol definition; hashing/accept/shutdown are borrowed
 * from plain TCP/inet, the data-path ops are the MPTCP wrappers above.
 */
static struct proto mptcp_prot = {
	.name		= "MPTCP",
	.owner		= THIS_MODULE,
	.init		= mptcp_init_sock,
	.close		= mptcp_close,
	.accept		= inet_csk_accept,
	.shutdown	= tcp_shutdown,
	.sendmsg	= mptcp_sendmsg,
	.recvmsg	= mptcp_recvmsg,
	.hash		= inet_hash,
	.unhash		= inet_unhash,
	.get_port	= mptcp_get_port,
	.obj_size	= sizeof(struct mptcp_sock),
	.no_autobind	= true,
};
264 
265 static int mptcp_bind(struct socket *sock, struct sockaddr *uaddr, int addr_len)
266 {
267 	struct mptcp_sock *msk = mptcp_sk(sock->sk);
268 	struct socket *ssock;
269 	int err = -ENOTSUPP;
270 
271 	if (uaddr->sa_family != AF_INET) // @@ allow only IPv4 for now
272 		return err;
273 
274 	lock_sock(sock->sk);
275 	ssock = __mptcp_socket_create(msk, MPTCP_SAME_STATE);
276 	if (IS_ERR(ssock)) {
277 		err = PTR_ERR(ssock);
278 		goto unlock;
279 	}
280 
281 	err = ssock->ops->bind(ssock, uaddr, addr_len);
282 
283 unlock:
284 	release_sock(sock->sk);
285 	return err;
286 }
287 
/* Connect the MPTCP socket: create the initial subflow (moving the msk
 * to TCP_SYN_SENT), issue the connect on it, and mirror the subflow's
 * resulting state back onto the MPTCP socket.
 */
static int mptcp_stream_connect(struct socket *sock, struct sockaddr *uaddr,
				int addr_len, int flags)
{
	struct mptcp_sock *msk = mptcp_sk(sock->sk);
	struct socket *ssock;
	int err;

	lock_sock(sock->sk);
	ssock = __mptcp_socket_create(msk, TCP_SYN_SENT);
	if (IS_ERR(ssock)) {
		err = PTR_ERR(ssock);
		goto unlock;
	}

	err = ssock->ops->connect(ssock, uaddr, addr_len, flags);
	/* keep the msk state in sync with the subflow, even on error */
	inet_sk_state_store(sock->sk, inet_sk_state_load(ssock->sk));

unlock:
	release_sock(sock->sk);
	return err;
}
309 
310 static __poll_t mptcp_poll(struct file *file, struct socket *sock,
311 			   struct poll_table_struct *wait)
312 {
313 	__poll_t mask = 0;
314 
315 	return mask;
316 }
317 
318 static struct proto_ops mptcp_stream_ops;
319 
/* Registration entry mapping IPPROTO_MPTCP stream sockets to mptcp_prot. */
static struct inet_protosw mptcp_protosw = {
	.type		= SOCK_STREAM,
	.protocol	= IPPROTO_MPTCP,
	.prot		= &mptcp_prot,
	.ops		= &mptcp_stream_ops,
	.flags		= INET_PROTOSW_ICSK,
};
327 
/* Boot-time registration of MPTCP over IPv4: clone the inet stream ops,
 * override the MPTCP-specific handlers, initialize the subflow layer and
 * register the protocol.  Panics on failure, as is customary for core
 * protocol registration at init time.
 */
void __init mptcp_init(void)
{
	/* share TCP's established/bind hash tables */
	mptcp_prot.h.hashinfo = tcp_prot.h.hashinfo;
	mptcp_stream_ops = inet_stream_ops;
	mptcp_stream_ops.bind = mptcp_bind;
	mptcp_stream_ops.connect = mptcp_stream_connect;
	mptcp_stream_ops.poll = mptcp_poll;

	mptcp_subflow_init();

	if (proto_register(&mptcp_prot, 1) != 0)
		panic("Failed to register MPTCP proto.\n");

	inet_register_protosw(&mptcp_protosw);
}
343 
344 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
345 static struct proto_ops mptcp_v6_stream_ops;
346 static struct proto mptcp_v6_prot;
347 
/* Registration entry mapping IPv6 IPPROTO_MPTCP stream sockets to
 * mptcp_v6_prot.
 */
static struct inet_protosw mptcp_v6_protosw = {
	.type		= SOCK_STREAM,
	.protocol	= IPPROTO_MPTCP,
	.prot		= &mptcp_v6_prot,
	.ops		= &mptcp_v6_stream_ops,
	.flags		= INET_PROTOSW_ICSK,
};
355 
/* Register the IPv6 flavour of MPTCP: clone mptcp_prot, enlarging
 * obj_size so an ipv6_pinfo fits behind the mptcp_sock, then register
 * the proto and protosw.  Returns 0 on success or a negative error,
 * unregistering the proto again if protosw registration fails.
 */
int mptcpv6_init(void)
{
	int err;

	mptcp_v6_prot = mptcp_prot;
	strcpy(mptcp_v6_prot.name, "MPTCPv6");
	/* force allocation of a fresh kmem cache for the larger obj_size */
	mptcp_v6_prot.slab = NULL;
	mptcp_v6_prot.obj_size = sizeof(struct mptcp_sock) +
				 sizeof(struct ipv6_pinfo);

	err = proto_register(&mptcp_v6_prot, 1);
	if (err)
		return err;

	mptcp_v6_stream_ops = inet6_stream_ops;
	mptcp_v6_stream_ops.bind = mptcp_bind;
	mptcp_v6_stream_ops.connect = mptcp_stream_connect;
	mptcp_v6_stream_ops.poll = mptcp_poll;

	err = inet6_register_protosw(&mptcp_v6_protosw);
	if (err)
		proto_unregister(&mptcp_v6_prot);

	return err;
}
381 #endif
382