xref: /openbmc/linux/net/mptcp/pm_userspace.c (revision 48ca54e3)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Multipath TCP
3  *
4  * Copyright (c) 2022, Intel Corporation.
5  */
6 
7 #include "protocol.h"
8 
9 void mptcp_free_local_addr_list(struct mptcp_sock *msk)
10 {
11 	struct mptcp_pm_addr_entry *entry, *tmp;
12 	struct sock *sk = (struct sock *)msk;
13 	LIST_HEAD(free_list);
14 
15 	if (!mptcp_pm_is_userspace(msk))
16 		return;
17 
18 	spin_lock_bh(&msk->pm.lock);
19 	list_splice_init(&msk->pm.userspace_pm_local_addr_list, &free_list);
20 	spin_unlock_bh(&msk->pm.lock);
21 
22 	list_for_each_entry_safe(entry, tmp, &free_list, list) {
23 		sock_kfree_s(sk, entry, sizeof(*entry));
24 	}
25 }
26 
27 int mptcp_userspace_pm_append_new_local_addr(struct mptcp_sock *msk,
28 					     struct mptcp_pm_addr_entry *entry)
29 {
30 	DECLARE_BITMAP(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1);
31 	struct mptcp_pm_addr_entry *match = NULL;
32 	struct sock *sk = (struct sock *)msk;
33 	struct mptcp_pm_addr_entry *e;
34 	bool addr_match = false;
35 	bool id_match = false;
36 	int ret = -EINVAL;
37 
38 	bitmap_zero(id_bitmap, MPTCP_PM_MAX_ADDR_ID + 1);
39 
40 	spin_lock_bh(&msk->pm.lock);
41 	list_for_each_entry(e, &msk->pm.userspace_pm_local_addr_list, list) {
42 		addr_match = mptcp_addresses_equal(&e->addr, &entry->addr, true);
43 		if (addr_match && entry->addr.id == 0)
44 			entry->addr.id = e->addr.id;
45 		id_match = (e->addr.id == entry->addr.id);
46 		if (addr_match && id_match) {
47 			match = e;
48 			break;
49 		} else if (addr_match || id_match) {
50 			break;
51 		}
52 		__set_bit(e->addr.id, id_bitmap);
53 	}
54 
55 	if (!match && !addr_match && !id_match) {
56 		/* Memory for the entry is allocated from the
57 		 * sock option buffer.
58 		 */
59 		e = sock_kmalloc(sk, sizeof(*e), GFP_ATOMIC);
60 		if (!e) {
61 			spin_unlock_bh(&msk->pm.lock);
62 			return -ENOMEM;
63 		}
64 
65 		*e = *entry;
66 		if (!e->addr.id)
67 			e->addr.id = find_next_zero_bit(id_bitmap,
68 							MPTCP_PM_MAX_ADDR_ID + 1,
69 							1);
70 		list_add_tail_rcu(&e->list, &msk->pm.userspace_pm_local_addr_list);
71 		ret = e->addr.id;
72 	} else if (match) {
73 		ret = entry->addr.id;
74 	}
75 
76 	spin_unlock_bh(&msk->pm.lock);
77 	return ret;
78 }
79 
80 int mptcp_userspace_pm_get_flags_and_ifindex_by_id(struct mptcp_sock *msk,
81 						   unsigned int id,
82 						   u8 *flags, int *ifindex)
83 {
84 	struct mptcp_pm_addr_entry *entry, *match = NULL;
85 
86 	*flags = 0;
87 	*ifindex = 0;
88 
89 	spin_lock_bh(&msk->pm.lock);
90 	list_for_each_entry(entry, &msk->pm.userspace_pm_local_addr_list, list) {
91 		if (id == entry->addr.id) {
92 			match = entry;
93 			break;
94 		}
95 	}
96 	spin_unlock_bh(&msk->pm.lock);
97 	if (match) {
98 		*flags = match->flags;
99 		*ifindex = match->ifindex;
100 	}
101 
102 	return 0;
103 }
104 
105 int mptcp_userspace_pm_get_local_id(struct mptcp_sock *msk,
106 				    struct mptcp_addr_info *skc)
107 {
108 	struct mptcp_pm_addr_entry new_entry;
109 	__be16 msk_sport =  ((struct inet_sock *)
110 			     inet_sk((struct sock *)msk))->inet_sport;
111 
112 	memset(&new_entry, 0, sizeof(struct mptcp_pm_addr_entry));
113 	new_entry.addr = *skc;
114 	new_entry.addr.id = 0;
115 	new_entry.flags = MPTCP_PM_ADDR_FLAG_IMPLICIT;
116 
117 	if (new_entry.addr.port == msk_sport)
118 		new_entry.addr.port = 0;
119 
120 	return mptcp_userspace_pm_append_new_local_addr(msk, &new_entry);
121 }
122 
123 int mptcp_nl_cmd_announce(struct sk_buff *skb, struct genl_info *info)
124 {
125 	struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
126 	struct nlattr *addr = info->attrs[MPTCP_PM_ATTR_ADDR];
127 	struct mptcp_pm_addr_entry addr_val;
128 	struct mptcp_sock *msk;
129 	int err = -EINVAL;
130 	u32 token_val;
131 
132 	if (!addr || !token) {
133 		GENL_SET_ERR_MSG(info, "missing required inputs");
134 		return err;
135 	}
136 
137 	token_val = nla_get_u32(token);
138 
139 	msk = mptcp_token_get_sock(sock_net(skb->sk), token_val);
140 	if (!msk) {
141 		NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token");
142 		return err;
143 	}
144 
145 	if (!mptcp_pm_is_userspace(msk)) {
146 		GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected");
147 		goto announce_err;
148 	}
149 
150 	err = mptcp_pm_parse_entry(addr, info, true, &addr_val);
151 	if (err < 0) {
152 		GENL_SET_ERR_MSG(info, "error parsing local address");
153 		goto announce_err;
154 	}
155 
156 	if (addr_val.addr.id == 0 || !(addr_val.flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) {
157 		GENL_SET_ERR_MSG(info, "invalid addr id or flags");
158 		goto announce_err;
159 	}
160 
161 	err = mptcp_userspace_pm_append_new_local_addr(msk, &addr_val);
162 	if (err < 0) {
163 		GENL_SET_ERR_MSG(info, "did not match address and id");
164 		goto announce_err;
165 	}
166 
167 	lock_sock((struct sock *)msk);
168 	spin_lock_bh(&msk->pm.lock);
169 
170 	if (mptcp_pm_alloc_anno_list(msk, &addr_val)) {
171 		mptcp_pm_announce_addr(msk, &addr_val.addr, false);
172 		mptcp_pm_nl_addr_send_ack(msk);
173 	}
174 
175 	spin_unlock_bh(&msk->pm.lock);
176 	release_sock((struct sock *)msk);
177 
178 	err = 0;
179  announce_err:
180 	sock_put((struct sock *)msk);
181 	return err;
182 }
183 
184 int mptcp_nl_cmd_remove(struct sk_buff *skb, struct genl_info *info)
185 {
186 	struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
187 	struct nlattr *id = info->attrs[MPTCP_PM_ATTR_LOC_ID];
188 	struct mptcp_pm_addr_entry *match = NULL;
189 	struct mptcp_pm_addr_entry *entry;
190 	struct mptcp_sock *msk;
191 	LIST_HEAD(free_list);
192 	int err = -EINVAL;
193 	u32 token_val;
194 	u8 id_val;
195 
196 	if (!id || !token) {
197 		GENL_SET_ERR_MSG(info, "missing required inputs");
198 		return err;
199 	}
200 
201 	id_val = nla_get_u8(id);
202 	token_val = nla_get_u32(token);
203 
204 	msk = mptcp_token_get_sock(sock_net(skb->sk), token_val);
205 	if (!msk) {
206 		NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token");
207 		return err;
208 	}
209 
210 	if (!mptcp_pm_is_userspace(msk)) {
211 		GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected");
212 		goto remove_err;
213 	}
214 
215 	lock_sock((struct sock *)msk);
216 
217 	list_for_each_entry(entry, &msk->pm.userspace_pm_local_addr_list, list) {
218 		if (entry->addr.id == id_val) {
219 			match = entry;
220 			break;
221 		}
222 	}
223 
224 	if (!match) {
225 		GENL_SET_ERR_MSG(info, "address with specified id not found");
226 		release_sock((struct sock *)msk);
227 		goto remove_err;
228 	}
229 
230 	list_move(&match->list, &free_list);
231 
232 	mptcp_pm_remove_addrs_and_subflows(msk, &free_list);
233 
234 	release_sock((struct sock *)msk);
235 
236 	list_for_each_entry_safe(match, entry, &free_list, list) {
237 		sock_kfree_s((struct sock *)msk, match, sizeof(*match));
238 	}
239 
240 	err = 0;
241  remove_err:
242 	sock_put((struct sock *)msk);
243 	return err;
244 }
245 
246 int mptcp_nl_cmd_sf_create(struct sk_buff *skb, struct genl_info *info)
247 {
248 	struct nlattr *raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE];
249 	struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
250 	struct nlattr *laddr = info->attrs[MPTCP_PM_ATTR_ADDR];
251 	struct mptcp_addr_info addr_r;
252 	struct mptcp_addr_info addr_l;
253 	struct mptcp_sock *msk;
254 	int err = -EINVAL;
255 	struct sock *sk;
256 	u32 token_val;
257 
258 	if (!laddr || !raddr || !token) {
259 		GENL_SET_ERR_MSG(info, "missing required inputs");
260 		return err;
261 	}
262 
263 	token_val = nla_get_u32(token);
264 
265 	msk = mptcp_token_get_sock(genl_info_net(info), token_val);
266 	if (!msk) {
267 		NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token");
268 		return err;
269 	}
270 
271 	if (!mptcp_pm_is_userspace(msk)) {
272 		GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected");
273 		goto create_err;
274 	}
275 
276 	err = mptcp_pm_parse_addr(laddr, info, &addr_l);
277 	if (err < 0) {
278 		NL_SET_ERR_MSG_ATTR(info->extack, laddr, "error parsing local addr");
279 		goto create_err;
280 	}
281 
282 	if (addr_l.id == 0) {
283 		NL_SET_ERR_MSG_ATTR(info->extack, laddr, "missing local addr id");
284 		goto create_err;
285 	}
286 
287 	err = mptcp_pm_parse_addr(raddr, info, &addr_r);
288 	if (err < 0) {
289 		NL_SET_ERR_MSG_ATTR(info->extack, raddr, "error parsing remote addr");
290 		goto create_err;
291 	}
292 
293 	sk = &msk->sk.icsk_inet.sk;
294 	lock_sock(sk);
295 
296 	err = __mptcp_subflow_connect(sk, &addr_l, &addr_r);
297 
298 	release_sock(sk);
299 
300  create_err:
301 	sock_put((struct sock *)msk);
302 	return err;
303 }
304 
305 static struct sock *mptcp_nl_find_ssk(struct mptcp_sock *msk,
306 				      const struct mptcp_addr_info *local,
307 				      const struct mptcp_addr_info *remote)
308 {
309 	struct sock *sk = &msk->sk.icsk_inet.sk;
310 	struct mptcp_subflow_context *subflow;
311 	struct sock *found = NULL;
312 
313 	if (local->family != remote->family)
314 		return NULL;
315 
316 	lock_sock(sk);
317 
318 	mptcp_for_each_subflow(msk, subflow) {
319 		const struct inet_sock *issk;
320 		struct sock *ssk;
321 
322 		ssk = mptcp_subflow_tcp_sock(subflow);
323 
324 		if (local->family != ssk->sk_family)
325 			continue;
326 
327 		issk = inet_sk(ssk);
328 
329 		switch (ssk->sk_family) {
330 		case AF_INET:
331 			if (issk->inet_saddr != local->addr.s_addr ||
332 			    issk->inet_daddr != remote->addr.s_addr)
333 				continue;
334 			break;
335 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
336 		case AF_INET6: {
337 			const struct ipv6_pinfo *pinfo = inet6_sk(ssk);
338 
339 			if (!ipv6_addr_equal(&local->addr6, &pinfo->saddr) ||
340 			    !ipv6_addr_equal(&remote->addr6, &ssk->sk_v6_daddr))
341 				continue;
342 			break;
343 		}
344 #endif
345 		default:
346 			continue;
347 		}
348 
349 		if (issk->inet_sport == local->port &&
350 		    issk->inet_dport == remote->port) {
351 			found = ssk;
352 			goto found;
353 		}
354 	}
355 
356 found:
357 	release_sock(sk);
358 
359 	return found;
360 }
361 
362 int mptcp_nl_cmd_sf_destroy(struct sk_buff *skb, struct genl_info *info)
363 {
364 	struct nlattr *raddr = info->attrs[MPTCP_PM_ATTR_ADDR_REMOTE];
365 	struct nlattr *token = info->attrs[MPTCP_PM_ATTR_TOKEN];
366 	struct nlattr *laddr = info->attrs[MPTCP_PM_ATTR_ADDR];
367 	struct mptcp_addr_info addr_l;
368 	struct mptcp_addr_info addr_r;
369 	struct mptcp_sock *msk;
370 	struct sock *sk, *ssk;
371 	int err = -EINVAL;
372 	u32 token_val;
373 
374 	if (!laddr || !raddr || !token) {
375 		GENL_SET_ERR_MSG(info, "missing required inputs");
376 		return err;
377 	}
378 
379 	token_val = nla_get_u32(token);
380 
381 	msk = mptcp_token_get_sock(genl_info_net(info), token_val);
382 	if (!msk) {
383 		NL_SET_ERR_MSG_ATTR(info->extack, token, "invalid token");
384 		return err;
385 	}
386 
387 	if (!mptcp_pm_is_userspace(msk)) {
388 		GENL_SET_ERR_MSG(info, "invalid request; userspace PM not selected");
389 		goto destroy_err;
390 	}
391 
392 	err = mptcp_pm_parse_addr(laddr, info, &addr_l);
393 	if (err < 0) {
394 		NL_SET_ERR_MSG_ATTR(info->extack, laddr, "error parsing local addr");
395 		goto destroy_err;
396 	}
397 
398 	err = mptcp_pm_parse_addr(raddr, info, &addr_r);
399 	if (err < 0) {
400 		NL_SET_ERR_MSG_ATTR(info->extack, raddr, "error parsing remote addr");
401 		goto destroy_err;
402 	}
403 
404 	if (addr_l.family != addr_r.family) {
405 		GENL_SET_ERR_MSG(info, "address families do not match");
406 		goto destroy_err;
407 	}
408 
409 	if (!addr_l.port || !addr_r.port) {
410 		GENL_SET_ERR_MSG(info, "missing local or remote port");
411 		goto destroy_err;
412 	}
413 
414 	sk = &msk->sk.icsk_inet.sk;
415 	ssk = mptcp_nl_find_ssk(msk, &addr_l, &addr_r);
416 	if (ssk) {
417 		struct mptcp_subflow_context *subflow = mptcp_subflow_ctx(ssk);
418 
419 		mptcp_subflow_shutdown(sk, ssk, RCV_SHUTDOWN | SEND_SHUTDOWN);
420 		mptcp_close_ssk(sk, ssk, subflow);
421 		err = 0;
422 	} else {
423 		err = -ESRCH;
424 	}
425 
426  destroy_err:
427 	sock_put((struct sock *)msk);
428 	return err;
429 }
430