xref: /openbmc/linux/net/mptcp/pm_netlink.c (revision b4e18b29)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Multipath TCP
3  *
4  * Copyright (c) 2020, Red Hat, Inc.
5  */
6 
7 #define pr_fmt(fmt) "MPTCP: " fmt
8 
9 #include <linux/inet.h>
10 #include <linux/kernel.h>
11 #include <net/tcp.h>
12 #include <net/netns/generic.h>
13 #include <net/mptcp.h>
14 #include <net/genetlink.h>
15 #include <uapi/linux/mptcp.h>
16 
17 #include "protocol.h"
18 #include "mib.h"
19 
/* forward declaration */
static struct genl_family mptcp_genl_family;

/* pernet subsystem id used to fetch the per-netns PM state */
static int pm_nl_pernet_id;
24 
/* a configured local endpoint, member of the per-netns address list */
struct mptcp_pm_addr_entry {
	struct list_head	list;	/* linked into pm_nl_pernet local_addr_list (RCU) */
	struct mptcp_addr_info	addr;	/* address, id, port and endpoint flags */
	struct rcu_head		rcu;	/* readers traverse the list under RCU */
	struct socket		*lsk;	/* listener socket, only set for port-based endpoints */
};
31 
/* an in-flight ADD_ADDR announce, tracked per msk for retransmission */
struct mptcp_pm_add_entry {
	struct list_head	list;	/* linked into msk->pm.anno_list */
	struct mptcp_addr_info	addr;	/* the announced address */
	struct timer_list	add_timer;	/* ADD_ADDR retransmission timer */
	struct mptcp_sock	*sock;	/* owning msk */
	u8			retrans_times;	/* attempts so far, capped at ADD_ADDR_RETRANS_MAX */
};
39 
/* address ids fit in a u8; id 0 is the implicit one of the initial subflow */
#define MAX_ADDR_ID		255
#define BITMAP_SZ DIV_ROUND_UP(MAX_ADDR_ID + 1, BITS_PER_LONG)

/* per network namespace path-manager state */
struct pm_nl_pernet {
	/* protects pernet updates */
	spinlock_t		lock;
	struct list_head	local_addr_list;	/* RCU list of mptcp_pm_addr_entry */
	unsigned int		addrs;			/* entries currently in the list */
	unsigned int		add_addr_signal_max;	/* endpoints carrying the SIGNAL flag */
	unsigned int		add_addr_accept_max;	/* peer ADD_ADDRs we accept per connection */
	unsigned int		local_addr_max;		/* endpoints carrying the SUBFLOW flag */
	unsigned int		subflows_max;		/* extra subflows allowed per connection */
	unsigned int		next_id;		/* hint for the next address id allocation */
	unsigned long		id_bitmap[BITMAP_SZ];	/* address ids currently in use */
};
55 
/* hard limit on configured endpoints per namespace */
#define MPTCP_PM_ADDR_MAX	8
/* give up announcing an address after this many retransmissions */
#define ADD_ADDR_RETRANS_MAX	3
58 
/* compare two addresses, optionally including the port when @use_port is
 * set; a v4-mapped v6 address compares equal to the plain v4 one
 */
static bool addresses_equal(const struct mptcp_addr_info *a,
			    struct mptcp_addr_info *b, bool use_port)
{
	bool addr_equals = false;

	if (a->family == b->family) {
		if (a->family == AF_INET)
			addr_equals = a->addr.s_addr == b->addr.s_addr;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
		else
			addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6);
	} else if (a->family == AF_INET) {
		/* mixed families: only equal when the v6 side is v4-mapped */
		if (ipv6_addr_v4mapped(&b->addr6))
			addr_equals = a->addr.s_addr == b->addr6.s6_addr32[3];
	} else if (b->family == AF_INET) {
		if (ipv6_addr_v4mapped(&a->addr6))
			addr_equals = a->addr6.s6_addr32[3] == b->addr.s_addr;
#endif
	}

	if (!addr_equals)
		return false;
	if (!use_port)
		return true;

	return a->port == b->port;
}
86 
87 static bool address_zero(const struct mptcp_addr_info *addr)
88 {
89 	struct mptcp_addr_info zero;
90 
91 	memset(&zero, 0, sizeof(zero));
92 	zero.family = addr->family;
93 
94 	return addresses_equal(addr, &zero, true);
95 }
96 
97 static void local_address(const struct sock_common *skc,
98 			  struct mptcp_addr_info *addr)
99 {
100 	addr->family = skc->skc_family;
101 	addr->port = htons(skc->skc_num);
102 	if (addr->family == AF_INET)
103 		addr->addr.s_addr = skc->skc_rcv_saddr;
104 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
105 	else if (addr->family == AF_INET6)
106 		addr->addr6 = skc->skc_v6_rcv_saddr;
107 #endif
108 }
109 
110 static void remote_address(const struct sock_common *skc,
111 			   struct mptcp_addr_info *addr)
112 {
113 	addr->family = skc->skc_family;
114 	addr->port = skc->skc_dport;
115 	if (addr->family == AF_INET)
116 		addr->addr.s_addr = skc->skc_daddr;
117 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
118 	else if (addr->family == AF_INET6)
119 		addr->addr6 = skc->skc_v6_daddr;
120 #endif
121 }
122 
123 static bool lookup_subflow_by_saddr(const struct list_head *list,
124 				    struct mptcp_addr_info *saddr)
125 {
126 	struct mptcp_subflow_context *subflow;
127 	struct mptcp_addr_info cur;
128 	struct sock_common *skc;
129 
130 	list_for_each_entry(subflow, list, node) {
131 		skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);
132 
133 		local_address(skc, &cur);
134 		if (addresses_equal(&cur, saddr, saddr->port))
135 			return true;
136 	}
137 
138 	return false;
139 }
140 
/* pick the first SUBFLOW-flagged endpoint compatible with the msk family
 * and not yet used by any subflow or pending join; returns an RCU-protected
 * entry - the caller must not hold on to it past the current pm lock scope
 */
static struct mptcp_pm_addr_entry *
select_local_address(const struct pm_nl_pernet *pernet,
		     struct mptcp_sock *msk)
{
	struct mptcp_pm_addr_entry *entry, *ret = NULL;
	struct sock *sk = (struct sock *)msk;

	msk_owned_by_me(msk);

	rcu_read_lock();
	__mptcp_flush_join_list(msk);
	list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
		if (!(entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW))
			continue;

		/* skip family-incompatible endpoints; v4-mapped v6 still match */
		if (entry->addr.family != sk->sk_family) {
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
			if ((entry->addr.family == AF_INET &&
			     !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) ||
			    (sk->sk_family == AF_INET &&
			     !ipv6_addr_v4mapped(&entry->addr.addr6)))
#endif
				continue;
		}

		/* avoid any address already in use by subflows and
		 * pending join
		 */
		if (!lookup_subflow_by_saddr(&msk->conn_list, &entry->addr)) {
			ret = entry;
			break;
		}
	}
	rcu_read_unlock();
	return ret;
}
177 
178 static struct mptcp_pm_addr_entry *
179 select_signal_address(struct pm_nl_pernet *pernet, unsigned int pos)
180 {
181 	struct mptcp_pm_addr_entry *entry, *ret = NULL;
182 	int i = 0;
183 
184 	rcu_read_lock();
185 	/* do not keep any additional per socket state, just signal
186 	 * the address list in order.
187 	 * Note: removal from the local address list during the msk life-cycle
188 	 * can lead to additional addresses not being announced.
189 	 */
190 	list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
191 		if (!(entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL))
192 			continue;
193 		if (i++ == pos) {
194 			ret = entry;
195 			break;
196 		}
197 	}
198 	rcu_read_unlock();
199 	return ret;
200 }
201 
202 unsigned int mptcp_pm_get_add_addr_signal_max(struct mptcp_sock *msk)
203 {
204 	struct pm_nl_pernet *pernet;
205 
206 	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
207 	return READ_ONCE(pernet->add_addr_signal_max);
208 }
209 EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_signal_max);
210 
211 unsigned int mptcp_pm_get_add_addr_accept_max(struct mptcp_sock *msk)
212 {
213 	struct pm_nl_pernet *pernet;
214 
215 	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
216 	return READ_ONCE(pernet->add_addr_accept_max);
217 }
218 EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_accept_max);
219 
220 unsigned int mptcp_pm_get_subflows_max(struct mptcp_sock *msk)
221 {
222 	struct pm_nl_pernet *pernet;
223 
224 	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
225 	return READ_ONCE(pernet->subflows_max);
226 }
227 EXPORT_SYMBOL_GPL(mptcp_pm_get_subflows_max);
228 
229 static unsigned int mptcp_pm_get_local_addr_max(struct mptcp_sock *msk)
230 {
231 	struct pm_nl_pernet *pernet;
232 
233 	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
234 	return READ_ONCE(pernet->local_addr_max);
235 }
236 
237 static void check_work_pending(struct mptcp_sock *msk)
238 {
239 	if (msk->pm.add_addr_signaled == mptcp_pm_get_add_addr_signal_max(msk) &&
240 	    (msk->pm.local_addr_used == mptcp_pm_get_local_addr_max(msk) ||
241 	     msk->pm.subflows == mptcp_pm_get_subflows_max(msk)))
242 		WRITE_ONCE(msk->pm.work_pending, false);
243 }
244 
245 static struct mptcp_pm_add_entry *
246 lookup_anno_list_by_saddr(struct mptcp_sock *msk,
247 			  struct mptcp_addr_info *addr)
248 {
249 	struct mptcp_pm_add_entry *entry;
250 
251 	lockdep_assert_held(&msk->pm.lock);
252 
253 	list_for_each_entry(entry, &msk->pm.anno_list, list) {
254 		if (addresses_equal(&entry->addr, addr, true))
255 			return entry;
256 	}
257 
258 	return NULL;
259 }
260 
261 bool mptcp_pm_sport_in_anno_list(struct mptcp_sock *msk, const struct sock *sk)
262 {
263 	struct mptcp_pm_add_entry *entry;
264 	struct mptcp_addr_info saddr;
265 	bool ret = false;
266 
267 	local_address((struct sock_common *)sk, &saddr);
268 
269 	spin_lock_bh(&msk->pm.lock);
270 	list_for_each_entry(entry, &msk->pm.anno_list, list) {
271 		if (addresses_equal(&entry->addr, &saddr, true)) {
272 			ret = true;
273 			goto out;
274 		}
275 	}
276 
277 out:
278 	spin_unlock_bh(&msk->pm.lock);
279 	return ret;
280 }
281 
/* ADD_ADDR retransmission timer: re-announce the address until the peer
 * echoes it or ADD_ADDR_RETRANS_MAX attempts have been made
 */
static void mptcp_pm_add_timer(struct timer_list *timer)
{
	struct mptcp_pm_add_entry *entry = from_timer(entry, timer, add_timer);
	struct mptcp_sock *msk = entry->sock;
	struct sock *sk = (struct sock *)msk;

	pr_debug("msk=%p", msk);

	/* NOTE(review): the three early returns below skip the __sock_put()
	 * at 'out'; presumably the timer reference is released elsewhere on
	 * those paths - verify against sk_reset_timer()/sk_stop_timer() usage
	 */
	if (!msk)
		return;

	if (inet_sk_state_load(sk) == TCP_CLOSE)
		return;

	if (!entry->addr.id)
		return;

	/* a signal is still in flight: retry shortly rather than queueing
	 * another announce on top of it
	 */
	if (mptcp_pm_should_add_signal(msk)) {
		sk_reset_timer(sk, timer, jiffies + TCP_RTO_MAX / 8);
		goto out;
	}

	spin_lock_bh(&msk->pm.lock);

	/* re-check under the pm lock before retransmitting */
	if (!mptcp_pm_should_add_signal(msk)) {
		pr_debug("retransmit ADD_ADDR id=%d", entry->addr.id);
		mptcp_pm_announce_addr(msk, &entry->addr, false, entry->addr.port);
		mptcp_pm_add_addr_send_ack(msk);
		entry->retrans_times++;
	}

	if (entry->retrans_times < ADD_ADDR_RETRANS_MAX)
		sk_reset_timer(sk, timer,
			       jiffies + mptcp_get_add_addr_timeout(sock_net(sk)));

	spin_unlock_bh(&msk->pm.lock);

out:
	__sock_put(sk);
}
322 
/* stop the ADD_ADDR retransmission timer matching @addr, if any.
 * Returns the still-allocated entry (the caller owns freeing it) or NULL.
 */
struct mptcp_pm_add_entry *
mptcp_pm_del_add_timer(struct mptcp_sock *msk,
		       struct mptcp_addr_info *addr)
{
	struct mptcp_pm_add_entry *entry;
	struct sock *sk = (struct sock *)msk;

	spin_lock_bh(&msk->pm.lock);
	entry = lookup_anno_list_by_saddr(msk, addr);
	if (entry)
		/* saturate the counter so a concurrently firing timer
		 * will not re-arm itself
		 */
		entry->retrans_times = ADD_ADDR_RETRANS_MAX;
	spin_unlock_bh(&msk->pm.lock);

	if (entry)
		/* wait for a possibly running callback to complete */
		sk_stop_timer_sync(sk, &entry->add_timer);

	return entry;
}
341 
/* start tracking a new ADD_ADDR announce for @msk and arm its
 * retransmission timer; returns false when the address is already being
 * announced or on allocation failure. Must be called under the pm lock.
 */
static bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk,
				     struct mptcp_pm_addr_entry *entry)
{
	struct mptcp_pm_add_entry *add_entry = NULL;
	struct sock *sk = (struct sock *)msk;
	struct net *net = sock_net(sk);

	lockdep_assert_held(&msk->pm.lock);

	if (lookup_anno_list_by_saddr(msk, &entry->addr))
		return false;

	/* GFP_ATOMIC: we hold the pm spinlock */
	add_entry = kmalloc(sizeof(*add_entry), GFP_ATOMIC);
	if (!add_entry)
		return false;

	list_add(&add_entry->list, &msk->pm.anno_list);

	add_entry->addr = entry->addr;
	add_entry->sock = msk;
	add_entry->retrans_times = 0;

	timer_setup(&add_entry->add_timer, mptcp_pm_add_timer, 0);
	sk_reset_timer(sk, &add_entry->add_timer,
		       jiffies + mptcp_get_add_addr_timeout(net));

	return true;
}
370 
371 void mptcp_pm_free_anno_list(struct mptcp_sock *msk)
372 {
373 	struct mptcp_pm_add_entry *entry, *tmp;
374 	struct sock *sk = (struct sock *)msk;
375 	LIST_HEAD(free_list);
376 
377 	pr_debug("msk=%p", msk);
378 
379 	spin_lock_bh(&msk->pm.lock);
380 	list_splice_init(&msk->pm.anno_list, &free_list);
381 	spin_unlock_bh(&msk->pm.lock);
382 
383 	list_for_each_entry_safe(entry, tmp, &free_list, list) {
384 		sk_stop_timer_sync(sk, &entry->add_timer);
385 		kfree(entry);
386 	}
387 }
388 
/* announce additional addresses and/or create additional subflows, up to
 * the configured per-netns limits. Called under the pm lock, which is
 * temporarily dropped around __mptcp_subflow_connect().
 */
static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
{
	struct sock *sk = (struct sock *)msk;
	struct mptcp_pm_addr_entry *local;
	unsigned int add_addr_signal_max;
	unsigned int local_addr_max;
	struct pm_nl_pernet *pernet;
	unsigned int subflows_max;

	pernet = net_generic(sock_net(sk), pm_nl_pernet_id);

	add_addr_signal_max = mptcp_pm_get_add_addr_signal_max(msk);
	local_addr_max = mptcp_pm_get_local_addr_max(msk);
	subflows_max = mptcp_pm_get_subflows_max(msk);

	pr_debug("local %d:%d signal %d:%d subflows %d:%d\n",
		 msk->pm.local_addr_used, local_addr_max,
		 msk->pm.add_addr_signaled, add_addr_signal_max,
		 msk->pm.subflows, subflows_max);

	/* check first for announce */
	if (msk->pm.add_addr_signaled < add_addr_signal_max) {
		local = select_signal_address(pernet,
					      msk->pm.add_addr_signaled);

		if (local) {
			if (mptcp_pm_alloc_anno_list(msk, local)) {
				msk->pm.add_addr_signaled++;
				mptcp_pm_announce_addr(msk, &local->addr, false, local->addr.port);
				mptcp_pm_nl_add_addr_send_ack(msk);
			}
		} else {
			/* pick failed, avoid further attempts later */
			msk->pm.local_addr_used = add_addr_signal_max;
		}

		check_work_pending(msk);
	}

	/* check if should create a new subflow */
	if (msk->pm.local_addr_used < local_addr_max &&
	    msk->pm.subflows < subflows_max) {
		local = select_local_address(pernet, msk);
		if (local) {
			struct mptcp_addr_info remote = { 0 };

			msk->pm.local_addr_used++;
			msk->pm.subflows++;
			check_work_pending(msk);
			remote_address((struct sock_common *)sk, &remote);
			/* drop the pm lock across the (sleeping) connect */
			spin_unlock_bh(&msk->pm.lock);
			__mptcp_subflow_connect(sk, &local->addr, &remote);
			spin_lock_bh(&msk->pm.lock);
			return;
		}

		/* lookup failed, avoid further attempts later */
		msk->pm.local_addr_used = local_addr_max;
		check_work_pending(msk);
	}
}
450 
/* PM hook: connection fully established, try announces/new subflows */
void mptcp_pm_nl_fully_established(struct mptcp_sock *msk)
{
	mptcp_pm_create_subflow_or_signal_addr(msk);
}
455 
/* PM hook: a subflow got established, try announces/new subflows */
void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk)
{
	mptcp_pm_create_subflow_or_signal_addr(msk);
}
460 
/* handle an ADD_ADDR received from the peer: open a subflow towards the
 * announced address and echo the ADD_ADDR back. Called under the pm lock,
 * which is dropped around the connect attempt.
 */
void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
{
	struct sock *sk = (struct sock *)msk;
	unsigned int add_addr_accept_max;
	struct mptcp_addr_info remote;
	struct mptcp_addr_info local;
	unsigned int subflows_max;
	bool use_port = false;

	add_addr_accept_max = mptcp_pm_get_add_addr_accept_max(msk);
	subflows_max = mptcp_pm_get_subflows_max(msk);

	pr_debug("accepted %d:%d remote family %d",
		 msk->pm.add_addr_accepted, add_addr_accept_max,
		 msk->pm.remote.family);
	msk->pm.add_addr_accepted++;
	msk->pm.subflows++;
	/* stop accepting further announces once either limit is hit */
	if (msk->pm.add_addr_accepted >= add_addr_accept_max ||
	    msk->pm.subflows >= subflows_max)
		WRITE_ONCE(msk->pm.accept_addr, false);

	/* connect to the specified remote address, using whatever
	 * local address the routing configuration will pick.
	 */
	remote = msk->pm.remote;
	if (!remote.port)
		remote.port = sk->sk_dport;
	else
		use_port = true;
	memset(&local, 0, sizeof(local));
	local.family = remote.family;

	spin_unlock_bh(&msk->pm.lock);
	__mptcp_subflow_connect(sk, &local, &remote);
	spin_lock_bh(&msk->pm.lock);

	/* echo the ADD_ADDR back to the peer */
	mptcp_pm_announce_addr(msk, &remote, true, use_port);
	mptcp_pm_nl_add_addr_send_ack(msk);
}
500 
/* push a pending ADD_ADDR signal out by forcing an ack on the first
 * subflow. Called under the pm lock, which is dropped around the
 * (sleeping) subflow socket lock.
 */
void mptcp_pm_nl_add_addr_send_ack(struct mptcp_sock *msk)
{
	struct mptcp_subflow_context *subflow;

	msk_owned_by_me(msk);
	lockdep_assert_held(&msk->pm.lock);

	if (!mptcp_pm_should_add_signal(msk))
		return;

	__mptcp_flush_join_list(msk);
	subflow = list_first_entry_or_null(&msk->conn_list, typeof(*subflow), node);
	if (subflow) {
		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
		u8 add_addr;

		spin_unlock_bh(&msk->pm.lock);
		pr_debug("send ack for add_addr%s%s",
			 mptcp_pm_should_add_signal_ipv6(msk) ? " [ipv6]" : "",
			 mptcp_pm_should_add_signal_port(msk) ? " [port]" : "");

		lock_sock(ssk);
		tcp_send_ack(ssk);
		release_sock(ssk);
		spin_lock_bh(&msk->pm.lock);

		/* clear the ipv6/port signal bits now that the ack went out */
		add_addr = READ_ONCE(msk->pm.addr_signal);
		if (mptcp_pm_should_add_signal_ipv6(msk))
			add_addr &= ~BIT(MPTCP_ADD_ADDR_IPV6);
		if (mptcp_pm_should_add_signal_port(msk))
			add_addr &= ~BIT(MPTCP_ADD_ADDR_PORT);
		WRITE_ONCE(msk->pm.addr_signal, add_addr);
	}
}
535 
/* set the backup flag on the subflow bound to @addr and push the MP_PRIO
 * option out via a forced ack. Called under the pm lock, dropped around
 * the subflow socket lock. Returns 0 on success, -EINVAL when no subflow
 * matches the given address.
 */
int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
				 struct mptcp_addr_info *addr,
				 u8 bkup)
{
	struct mptcp_subflow_context *subflow;

	pr_debug("bkup=%d", bkup);

	mptcp_for_each_subflow(msk, subflow) {
		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
		struct sock *sk = (struct sock *)msk;
		struct mptcp_addr_info local;

		local_address((struct sock_common *)ssk, &local);
		if (!addresses_equal(&local, addr, addr->port))
			continue;

		subflow->backup = bkup;
		subflow->send_mp_prio = 1;
		subflow->request_bkup = bkup;
		__MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPPRIOTX);

		spin_unlock_bh(&msk->pm.lock);
		pr_debug("send ack for mp_prio");
		lock_sock(ssk);
		tcp_send_ack(ssk);
		release_sock(ssk);
		spin_lock_bh(&msk->pm.lock);

		return 0;
	}

	return -EINVAL;
}
570 
/* handle a peer RM_ADDR: shut down and close the subflow whose remote id
 * matches msk->pm.rm_id. Called under the pm lock, dropped around the
 * subflow shutdown/close.
 */
void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk)
{
	struct mptcp_subflow_context *subflow, *tmp;
	struct sock *sk = (struct sock *)msk;

	pr_debug("address rm_id %d", msk->pm.rm_id);

	msk_owned_by_me(msk);

	/* id 0 identifies the initial subflow's address, never removed here */
	if (!msk->pm.rm_id)
		return;

	if (list_empty(&msk->conn_list))
		return;

	list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
		int how = RCV_SHUTDOWN | SEND_SHUTDOWN;

		if (msk->pm.rm_id != subflow->remote_id)
			continue;

		spin_unlock_bh(&msk->pm.lock);
		mptcp_subflow_shutdown(sk, ssk, how);
		__mptcp_close_ssk(sk, ssk, subflow);
		spin_lock_bh(&msk->pm.lock);

		/* give back the acceptance/subflow budget */
		msk->pm.add_addr_accepted--;
		msk->pm.subflows--;
		WRITE_ONCE(msk->pm.accept_addr, true);

		__MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMADDR);

		break;
	}
}
607 
/* remove the local subflow whose local id matches @rm_id, typically after
 * the corresponding endpoint was deleted. Called under the pm lock,
 * dropped around the subflow shutdown/close.
 */
void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk, u8 rm_id)
{
	struct mptcp_subflow_context *subflow, *tmp;
	struct sock *sk = (struct sock *)msk;

	pr_debug("subflow rm_id %d", rm_id);

	msk_owned_by_me(msk);

	/* id 0 identifies the initial subflow, never removed here */
	if (!rm_id)
		return;

	if (list_empty(&msk->conn_list))
		return;

	list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
		int how = RCV_SHUTDOWN | SEND_SHUTDOWN;

		if (rm_id != subflow->local_id)
			continue;

		spin_unlock_bh(&msk->pm.lock);
		mptcp_subflow_shutdown(sk, ssk, how);
		__mptcp_close_ssk(sk, ssk, subflow);
		spin_lock_bh(&msk->pm.lock);

		/* give back the local-address/subflow budget */
		msk->pm.local_addr_used--;
		msk->pm.subflows--;

		__MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMSUBFLOW);

		break;
	}
}
643 
644 static bool address_use_port(struct mptcp_pm_addr_entry *entry)
645 {
646 	return (entry->addr.flags &
647 		(MPTCP_PM_ADDR_FLAG_SIGNAL | MPTCP_PM_ADDR_FLAG_SUBFLOW)) ==
648 		MPTCP_PM_ADDR_FLAG_SIGNAL;
649 }
650 
/* insert @entry in the pernet local address list, allocating an address id
 * when none was provided. Returns the entry id (>0) on success, -EINVAL on
 * duplicate address or exhausted id/endpoint space.
 */
static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,
					     struct mptcp_pm_addr_entry *entry)
{
	struct mptcp_pm_addr_entry *cur;
	unsigned int addr_max;
	int ret = -EINVAL;

	spin_lock_bh(&pernet->lock);
	/* to keep the code simple, don't do IDR-like allocation for address ID,
	 * just bail when we exceed limits
	 */
	if (pernet->next_id == MAX_ADDR_ID)
		pernet->next_id = 1;
	if (pernet->addrs >= MPTCP_PM_ADDR_MAX)
		goto out;
	if (test_bit(entry->addr.id, pernet->id_bitmap))
		goto out;

	/* do not insert duplicate address, differentiate on port only
	 * for signal-only endpoints
	 */
	list_for_each_entry(cur, &pernet->local_addr_list, list) {
		if (addresses_equal(&cur->addr, &entry->addr,
				    address_use_port(entry) &&
				    address_use_port(cur)))
			goto out;
	}

	if (!entry->addr.id) {
find_next:
		/* look for a free id from the next_id hint onward, wrapping
		 * around to 1 at most once
		 */
		entry->addr.id = find_next_zero_bit(pernet->id_bitmap,
						    MAX_ADDR_ID + 1,
						    pernet->next_id);
		if ((!entry->addr.id || entry->addr.id > MAX_ADDR_ID) &&
		    pernet->next_id != 1) {
			pernet->next_id = 1;
			goto find_next;
		}
	}

	if (!entry->addr.id || entry->addr.id > MAX_ADDR_ID)
		goto out;

	__set_bit(entry->addr.id, pernet->id_bitmap);
	if (entry->addr.id > pernet->next_id)
		pernet->next_id = entry->addr.id;

	/* the *_max counters are read locklessly, hence WRITE_ONCE */
	if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
		addr_max = pernet->add_addr_signal_max;
		WRITE_ONCE(pernet->add_addr_signal_max, addr_max + 1);
	}
	if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
		addr_max = pernet->local_addr_max;
		WRITE_ONCE(pernet->local_addr_max, addr_max + 1);
	}

	pernet->addrs++;
	list_add_tail_rcu(&entry->list, &pernet->local_addr_list);
	ret = entry->addr.id;

out:
	spin_unlock_bh(&pernet->lock);
	return ret;
}
715 
716 static int mptcp_pm_nl_create_listen_socket(struct sock *sk,
717 					    struct mptcp_pm_addr_entry *entry)
718 {
719 	struct sockaddr_storage addr;
720 	struct mptcp_sock *msk;
721 	struct socket *ssock;
722 	int backlog = 1024;
723 	int err;
724 
725 	err = sock_create_kern(sock_net(sk), entry->addr.family,
726 			       SOCK_STREAM, IPPROTO_MPTCP, &entry->lsk);
727 	if (err)
728 		return err;
729 
730 	msk = mptcp_sk(entry->lsk->sk);
731 	if (!msk) {
732 		err = -EINVAL;
733 		goto out;
734 	}
735 
736 	ssock = __mptcp_nmpc_socket(msk);
737 	if (!ssock) {
738 		err = -EINVAL;
739 		goto out;
740 	}
741 
742 	mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family);
743 	err = kernel_bind(ssock, (struct sockaddr *)&addr,
744 			  sizeof(struct sockaddr_in));
745 	if (err) {
746 		pr_warn("kernel_bind error, err=%d", err);
747 		goto out;
748 	}
749 
750 	err = kernel_listen(ssock, backlog);
751 	if (err) {
752 		pr_warn("kernel_listen error, err=%d", err);
753 		goto out;
754 	}
755 
756 	return 0;
757 
758 out:
759 	sock_release(entry->lsk);
760 	return err;
761 }
762 
/* map the local address of @skc to an endpoint id, creating a new
 * zero-flag endpoint when the address is unknown. Returns the id (0 for
 * the msk's own/zero address), or a negative errno.
 */
int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
{
	struct mptcp_pm_addr_entry *entry;
	struct mptcp_addr_info skc_local;
	struct mptcp_addr_info msk_local;
	struct pm_nl_pernet *pernet;
	int ret = -1;

	if (WARN_ON_ONCE(!msk))
		return -1;

	/* The 0 ID mapping is defined by the first subflow, copied into the msk
	 * addr
	 */
	local_address((struct sock_common *)msk, &msk_local);
	local_address((struct sock_common *)skc, &skc_local);
	if (addresses_equal(&msk_local, &skc_local, false))
		return 0;

	if (address_zero(&skc_local))
		return 0;

	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);

	rcu_read_lock();
	list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
		if (addresses_equal(&entry->addr, &skc_local, entry->addr.port)) {
			ret = entry->addr.id;
			break;
		}
	}
	rcu_read_unlock();
	if (ret >= 0)
		return ret;

	/* address not found, add to local list */
	entry = kmalloc(sizeof(*entry), GFP_ATOMIC);
	if (!entry)
		return -ENOMEM;

	entry->addr = skc_local;
	entry->addr.ifindex = 0;
	entry->addr.flags = 0;
	entry->addr.id = 0;	/* let the append helper pick a free id */
	entry->addr.port = 0;
	entry->lsk = NULL;
	ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
	if (ret < 0)
		kfree(entry);

	return ret;
}
815 
816 void mptcp_pm_nl_data_init(struct mptcp_sock *msk)
817 {
818 	struct mptcp_pm_data *pm = &msk->pm;
819 	bool subflows;
820 
821 	subflows = !!mptcp_pm_get_subflows_max(msk);
822 	WRITE_ONCE(pm->work_pending, (!!mptcp_pm_get_local_addr_max(msk) && subflows) ||
823 		   !!mptcp_pm_get_add_addr_signal_max(msk));
824 	WRITE_ONCE(pm->accept_addr, !!mptcp_pm_get_add_addr_accept_max(msk) && subflows);
825 	WRITE_ONCE(pm->accept_subflow, subflows);
826 }
827 
#define MPTCP_PM_CMD_GRP_OFFSET	0

/* generic netlink multicast group carrying PM event notifications */
static const struct genl_multicast_group mptcp_pm_mcgrps[] = {
	[MPTCP_PM_CMD_GRP_OFFSET]	= { .name = MPTCP_PM_CMD_GRP_NAME, },
};
833 
/* netlink policy for the nested MPTCP_PM_ATTR_ADDR attribute */
static const struct nla_policy
mptcp_pm_addr_policy[MPTCP_PM_ADDR_ATTR_MAX + 1] = {
	[MPTCP_PM_ADDR_ATTR_FAMILY]	= { .type	= NLA_U16,	},
	[MPTCP_PM_ADDR_ATTR_ID]		= { .type	= NLA_U8,	},
	[MPTCP_PM_ADDR_ATTR_ADDR4]	= { .type	= NLA_U32,	},
	[MPTCP_PM_ADDR_ATTR_ADDR6]	=
		NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)),
	[MPTCP_PM_ADDR_ATTR_PORT]	= { .type	= NLA_U16	},
	[MPTCP_PM_ADDR_ATTR_FLAGS]	= { .type	= NLA_U32	},
	[MPTCP_PM_ADDR_ATTR_IF_IDX]     = { .type	= NLA_S32	},
};
845 
/* top-level netlink policy for PM commands */
static const struct nla_policy mptcp_pm_policy[MPTCP_PM_ATTR_MAX + 1] = {
	[MPTCP_PM_ATTR_ADDR]		=
					NLA_POLICY_NESTED(mptcp_pm_addr_policy),
	[MPTCP_PM_ATTR_RCV_ADD_ADDRS]	= { .type	= NLA_U32,	},
	[MPTCP_PM_ATTR_SUBFLOWS]	= { .type	= NLA_U32,	},
};
852 
853 static int mptcp_pm_family_to_addr(int family)
854 {
855 #if IS_ENABLED(CONFIG_MPTCP_IPV6)
856 	if (family == AF_INET6)
857 		return MPTCP_PM_ADDR_ATTR_ADDR6;
858 #endif
859 	return MPTCP_PM_ADDR_ATTR_ADDR4;
860 }
861 
/* parse a nested address attribute into @entry; @require_family makes a
 * missing family attribute an error instead of leaving the address part
 * untouched. Returns 0 or a negative errno, with extack filled on error.
 */
static int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info,
			       bool require_family,
			       struct mptcp_pm_addr_entry *entry)
{
	struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1];
	int err, addr_addr;

	if (!attr) {
		GENL_SET_ERR_MSG(info, "missing address info");
		return -EINVAL;
	}

	/* no validation needed - was already done via nested policy */
	err = nla_parse_nested_deprecated(tb, MPTCP_PM_ADDR_ATTR_MAX, attr,
					  mptcp_pm_addr_policy, info->extack);
	if (err)
		return err;

	memset(entry, 0, sizeof(*entry));
	if (!tb[MPTCP_PM_ADDR_ATTR_FAMILY]) {
		if (!require_family)
			goto skip_family;

		NL_SET_ERR_MSG_ATTR(info->extack, attr,
				    "missing family");
		return -EINVAL;
	}

	entry->addr.family = nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_FAMILY]);
	if (entry->addr.family != AF_INET
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	    && entry->addr.family != AF_INET6
#endif
	    ) {
		NL_SET_ERR_MSG_ATTR(info->extack, attr,
				    "unknown address family");
		return -EINVAL;
	}
	addr_addr = mptcp_pm_family_to_addr(entry->addr.family);
	if (!tb[addr_addr]) {
		NL_SET_ERR_MSG_ATTR(info->extack, attr,
				    "missing address data");
		return -EINVAL;
	}

#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	if (entry->addr.family == AF_INET6)
		entry->addr.addr6 = nla_get_in6_addr(tb[addr_addr]);
	else
#endif
		entry->addr.addr.s_addr = nla_get_in_addr(tb[addr_addr]);

skip_family:
	if (tb[MPTCP_PM_ADDR_ATTR_IF_IDX]) {
		/* NOTE(review): the s32 attribute is stored via a u32 local;
		 * presumably negative ifindex values are not expected here -
		 * verify against the attribute users
		 */
		u32 val = nla_get_s32(tb[MPTCP_PM_ADDR_ATTR_IF_IDX]);

		entry->addr.ifindex = val;
	}

	if (tb[MPTCP_PM_ADDR_ATTR_ID])
		entry->addr.id = nla_get_u8(tb[MPTCP_PM_ADDR_ATTR_ID]);

	if (tb[MPTCP_PM_ADDR_ATTR_FLAGS])
		entry->addr.flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]);

	if (tb[MPTCP_PM_ADDR_ATTR_PORT])
		entry->addr.port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT]));

	return 0;
}
932 
/* fetch the PM pernet state of the netns the netlink request came from */
static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info)
{
	return net_generic(genl_info_net(info), pm_nl_pernet_id);
}
937 
/* walk every established msk in @net and let it pick up the newly added
 * endpoint (announce and/or new subflow)
 */
static int mptcp_nl_add_subflow_or_signal_addr(struct net *net)
{
	struct mptcp_sock *msk;
	long s_slot = 0, s_num = 0;

	while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
		struct sock *sk = (struct sock *)msk;

		if (!READ_ONCE(msk->fully_established))
			goto next;

		lock_sock(sk);
		spin_lock_bh(&msk->pm.lock);
		mptcp_pm_create_subflow_or_signal_addr(msk);
		spin_unlock_bh(&msk->pm.lock);
		release_sock(sk);

next:
		sock_put(sk);	/* drop the iterator's reference */
		cond_resched();
	}

	return 0;
}
962 
/* MPTCP_PM_CMD_ADD_ADDR handler: register a new endpoint and notify all
 * existing connections
 */
static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info)
{
	struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
	struct mptcp_pm_addr_entry addr, *entry;
	int ret;

	ret = mptcp_pm_parse_addr(attr, info, true, &addr);
	if (ret < 0)
		return ret;

	entry = kmalloc(sizeof(*entry), GFP_KERNEL);
	if (!entry) {
		GENL_SET_ERR_MSG(info, "can't allocate addr");
		return -ENOMEM;
	}

	*entry = addr;
	/* port-based endpoints need a dedicated listener socket */
	if (entry->addr.port) {
		ret = mptcp_pm_nl_create_listen_socket(skb->sk, entry);
		if (ret) {
			GENL_SET_ERR_MSG(info, "create listen socket error");
			kfree(entry);
			return ret;
		}
	}
	ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
	if (ret < 0) {
		GENL_SET_ERR_MSG(info, "too many addresses or duplicate one");
		if (entry->lsk)
			sock_release(entry->lsk);
		kfree(entry);
		return ret;
	}

	mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk));

	return 0;
}
1002 
1003 static struct mptcp_pm_addr_entry *
1004 __lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id)
1005 {
1006 	struct mptcp_pm_addr_entry *entry;
1007 
1008 	list_for_each_entry(entry, &pernet->local_addr_list, list) {
1009 		if (entry->addr.id == id)
1010 			return entry;
1011 	}
1012 	return NULL;
1013 }
1014 
1015 static bool remove_anno_list_by_saddr(struct mptcp_sock *msk,
1016 				      struct mptcp_addr_info *addr)
1017 {
1018 	struct mptcp_pm_add_entry *entry;
1019 
1020 	entry = mptcp_pm_del_add_timer(msk, addr);
1021 	if (entry) {
1022 		list_del(&entry->list);
1023 		kfree(entry);
1024 		return true;
1025 	}
1026 
1027 	return false;
1028 }
1029 
1030 static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk,
1031 				      struct mptcp_addr_info *addr,
1032 				      bool force)
1033 {
1034 	bool ret;
1035 
1036 	ret = remove_anno_list_by_saddr(msk, addr);
1037 	if (ret || force) {
1038 		spin_lock_bh(&msk->pm.lock);
1039 		mptcp_pm_remove_addr(msk, addr->id);
1040 		spin_unlock_bh(&msk->pm.lock);
1041 	}
1042 	return ret;
1043 }
1044 
/* propagate the removal of a local endpoint to every msk in @net:
 * cancel pending announces, signal RM_ADDR and tear down matching subflows
 */
static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net,
						   struct mptcp_addr_info *addr)
{
	struct mptcp_sock *msk;
	long s_slot = 0, s_num = 0;

	pr_debug("remove_id=%d", addr->id);

	while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
		struct sock *sk = (struct sock *)msk;
		bool remove_subflow;

		if (list_empty(&msk->conn_list)) {
			mptcp_pm_remove_anno_addr(msk, addr, false);
			goto next;
		}

		lock_sock(sk);
		remove_subflow = lookup_subflow_by_saddr(&msk->conn_list, addr);
		/* force the RM_ADDR signal when a subflow is also going away */
		mptcp_pm_remove_anno_addr(msk, addr, remove_subflow);
		if (remove_subflow)
			mptcp_pm_remove_subflow(msk, addr->id);
		release_sock(sk);

next:
		sock_put(sk);	/* drop the iterator's reference */
		cond_resched();
	}

	return 0;
}
1076 
/* wrapper to free an address entry after an RCU grace period */
struct addr_entry_release_work {
	struct rcu_work	rwork;			/* queued via queue_rcu_work() */
	struct mptcp_pm_addr_entry *entry;	/* entry to dispose of, may be NULL */
};
1081 
1082 static void mptcp_pm_release_addr_entry(struct work_struct *work)
1083 {
1084 	struct addr_entry_release_work *w;
1085 	struct mptcp_pm_addr_entry *entry;
1086 
1087 	w = container_of(to_rcu_work(work), struct addr_entry_release_work, rwork);
1088 	entry = w->entry;
1089 	if (entry) {
1090 		if (entry->lsk)
1091 			sock_release(entry->lsk);
1092 		kfree(entry);
1093 	}
1094 	kfree(w);
1095 }
1096 
/* schedule the RCU-deferred release of @entry after it has been unlinked
 * from the pernet list
 */
static void mptcp_pm_free_addr_entry(struct mptcp_pm_addr_entry *entry)
{
	struct addr_entry_release_work *w;

	w = kmalloc(sizeof(*w), GFP_ATOMIC);
	if (w) {
		INIT_RCU_WORK(&w->rwork, mptcp_pm_release_addr_entry);
		w->entry = entry;
		queue_rcu_work(system_wq, &w->rwork);
	}
	/* NOTE(review): if the wrapper allocation fails, @entry (and its
	 * listener socket) is leaked - an immediate kfree() is not safe as
	 * RCU readers may still reference it; verify whether a fallback
	 * (e.g. synchronize_rcu() + free) is warranted
	 */
}
1108 
/* MPTCP_PM_CMD_DEL_ADDR handler: unlink the endpoint with the given id,
 * propagate the removal to all connections and free the entry via RCU
 */
static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
{
	struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
	struct mptcp_pm_addr_entry addr, *entry;
	unsigned int addr_max;
	int ret;

	ret = mptcp_pm_parse_addr(attr, info, false, &addr);
	if (ret < 0)
		return ret;

	spin_lock_bh(&pernet->lock);
	entry = __lookup_addr_by_id(pernet, addr.addr.id);
	if (!entry) {
		GENL_SET_ERR_MSG(info, "address not found");
		spin_unlock_bh(&pernet->lock);
		return -EINVAL;
	}
	/* the *_max counters are read locklessly, hence WRITE_ONCE */
	if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
		addr_max = pernet->add_addr_signal_max;
		WRITE_ONCE(pernet->add_addr_signal_max, addr_max - 1);
	}
	if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
		addr_max = pernet->local_addr_max;
		WRITE_ONCE(pernet->local_addr_max, addr_max - 1);
	}

	pernet->addrs--;
	list_del_rcu(&entry->list);
	__clear_bit(entry->addr.id, pernet->id_bitmap);
	spin_unlock_bh(&pernet->lock);

	mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), &entry->addr);
	mptcp_pm_free_addr_entry(entry);

	return ret;
}
1147 
1148 static void __flush_addrs(struct net *net, struct list_head *list)
1149 {
1150 	while (!list_empty(list)) {
1151 		struct mptcp_pm_addr_entry *cur;
1152 
1153 		cur = list_entry(list->next,
1154 				 struct mptcp_pm_addr_entry, list);
1155 		mptcp_nl_remove_subflow_and_signal_addr(net, &cur->addr);
1156 		list_del_rcu(&cur->list);
1157 		mptcp_pm_free_addr_entry(cur);
1158 	}
1159 }
1160 
1161 static void __reset_counters(struct pm_nl_pernet *pernet)
1162 {
1163 	WRITE_ONCE(pernet->add_addr_signal_max, 0);
1164 	WRITE_ONCE(pernet->add_addr_accept_max, 0);
1165 	WRITE_ONCE(pernet->local_addr_max, 0);
1166 	pernet->addrs = 0;
1167 }
1168 
1169 static int mptcp_nl_cmd_flush_addrs(struct sk_buff *skb, struct genl_info *info)
1170 {
1171 	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1172 	LIST_HEAD(free_list);
1173 
1174 	spin_lock_bh(&pernet->lock);
1175 	list_splice_init(&pernet->local_addr_list, &free_list);
1176 	__reset_counters(pernet);
1177 	pernet->next_id = 1;
1178 	bitmap_zero(pernet->id_bitmap, MAX_ADDR_ID + 1);
1179 	spin_unlock_bh(&pernet->lock);
1180 	__flush_addrs(sock_net(skb->sk), &free_list);
1181 	return 0;
1182 }
1183 
/* Serialize @entry as a nested MPTCP_PM_ATTR_ADDR attribute into @skb.
 * Returns 0 on success, -EMSGSIZE when the message runs out of room (the
 * partially written nest is cancelled).
 */
static int mptcp_nl_fill_addr(struct sk_buff *skb,
			      struct mptcp_pm_addr_entry *entry)
{
	struct mptcp_addr_info *addr = &entry->addr;
	struct nlattr *attr;

	attr = nla_nest_start(skb, MPTCP_PM_ATTR_ADDR);
	if (!attr)
		return -EMSGSIZE;

	if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_FAMILY, addr->family))
		goto nla_put_failure;
	if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_PORT, ntohs(addr->port)))
		goto nla_put_failure;
	if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id))
		goto nla_put_failure;
	if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->addr.flags))
		goto nla_put_failure;
	/* ifindex is optional: only emitted when set */
	if (entry->addr.ifindex &&
	    nla_put_s32(skb, MPTCP_PM_ADDR_ATTR_IF_IDX, entry->addr.ifindex))
		goto nla_put_failure;

	if (addr->family == AF_INET &&
	    nla_put_in_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR4,
			    addr->addr.s_addr))
		goto nla_put_failure;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	else if (addr->family == AF_INET6 &&
		 nla_put_in6_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR6, &addr->addr6))
		goto nla_put_failure;
#endif
	nla_nest_end(skb, attr);
	return 0;

nla_put_failure:
	nla_nest_cancel(skb, attr);
	return -EMSGSIZE;
}
1222 
/* MPTCP_PM_CMD_GET_ADDR doit handler: reply with the single endpoint
 * matching the id in the request, or -EINVAL if none exists.
 */
static int mptcp_nl_cmd_get_addr(struct sk_buff *skb, struct genl_info *info)
{
	struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
	struct mptcp_pm_addr_entry addr, *entry;
	struct sk_buff *msg;
	void *reply;
	int ret;

	ret = mptcp_pm_parse_addr(attr, info, false, &addr);
	if (ret < 0)
		return ret;

	msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
	if (!msg)
		return -ENOMEM;

	reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
				  info->genlhdr->cmd);
	if (!reply) {
		GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
		ret = -EMSGSIZE;
		goto fail;
	}

	/* the lock keeps @entry alive while it is serialized and sent;
	 * NOTE(review): genlmsg_reply() runs with the spinlock held —
	 * presumably safe because the netlink unicast path is
	 * non-blocking here; confirm against nlmsg_unicast()
	 */
	spin_lock_bh(&pernet->lock);
	entry = __lookup_addr_by_id(pernet, addr.addr.id);
	if (!entry) {
		GENL_SET_ERR_MSG(info, "address not found");
		ret = -EINVAL;
		goto unlock_fail;
	}

	ret = mptcp_nl_fill_addr(msg, entry);
	if (ret)
		goto unlock_fail;

	genlmsg_end(msg, reply);
	ret = genlmsg_reply(msg, info);
	spin_unlock_bh(&pernet->lock);
	return ret;

unlock_fail:
	spin_unlock_bh(&pernet->lock);

fail:
	nlmsg_free(msg);
	return ret;
}
1272 
/* MPTCP_PM_CMD_GET_ADDR dumpit handler: emit all endpoints in id order.
 * cb->args[0] stores the last id already dumped so a multi-part dump
 * resumes after it on the next invocation.
 */
static int mptcp_nl_cmd_dump_addrs(struct sk_buff *msg,
				   struct netlink_callback *cb)
{
	struct net *net = sock_net(msg->sk);
	struct mptcp_pm_addr_entry *entry;
	struct pm_nl_pernet *pernet;
	int id = cb->args[0];
	void *hdr;
	int i;

	pernet = net_generic(net, pm_nl_pernet_id);

	spin_lock_bh(&pernet->lock);
	for (i = id; i < MAX_ADDR_ID + 1; i++) {
		if (test_bit(i, pernet->id_bitmap)) {
			entry = __lookup_addr_by_id(pernet, i);
			if (!entry)
				break;

			/* already dumped in a previous pass */
			if (entry->addr.id <= id)
				continue;

			hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid,
					  cb->nlh->nlmsg_seq, &mptcp_genl_family,
					  NLM_F_MULTI, MPTCP_PM_CMD_GET_ADDR);
			/* message full: stop and resume from @id next time */
			if (!hdr)
				break;

			if (mptcp_nl_fill_addr(msg, entry) < 0) {
				genlmsg_cancel(msg, hdr);
				break;
			}

			id = entry->addr.id;
			genlmsg_end(msg, hdr);
		}
	}
	spin_unlock_bh(&pernet->lock);

	cb->args[0] = id;
	return msg->len;
}
1315 
1316 static int parse_limit(struct genl_info *info, int id, unsigned int *limit)
1317 {
1318 	struct nlattr *attr = info->attrs[id];
1319 
1320 	if (!attr)
1321 		return 0;
1322 
1323 	*limit = nla_get_u32(attr);
1324 	if (*limit > MPTCP_PM_ADDR_MAX) {
1325 		GENL_SET_ERR_MSG(info, "limit greater than maximum");
1326 		return -EINVAL;
1327 	}
1328 	return 0;
1329 }
1330 
1331 static int
1332 mptcp_nl_cmd_set_limits(struct sk_buff *skb, struct genl_info *info)
1333 {
1334 	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1335 	unsigned int rcv_addrs, subflows;
1336 	int ret;
1337 
1338 	spin_lock_bh(&pernet->lock);
1339 	rcv_addrs = pernet->add_addr_accept_max;
1340 	ret = parse_limit(info, MPTCP_PM_ATTR_RCV_ADD_ADDRS, &rcv_addrs);
1341 	if (ret)
1342 		goto unlock;
1343 
1344 	subflows = pernet->subflows_max;
1345 	ret = parse_limit(info, MPTCP_PM_ATTR_SUBFLOWS, &subflows);
1346 	if (ret)
1347 		goto unlock;
1348 
1349 	WRITE_ONCE(pernet->add_addr_accept_max, rcv_addrs);
1350 	WRITE_ONCE(pernet->subflows_max, subflows);
1351 
1352 unlock:
1353 	spin_unlock_bh(&pernet->lock);
1354 	return ret;
1355 }
1356 
1357 static int
1358 mptcp_nl_cmd_get_limits(struct sk_buff *skb, struct genl_info *info)
1359 {
1360 	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1361 	struct sk_buff *msg;
1362 	void *reply;
1363 
1364 	msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
1365 	if (!msg)
1366 		return -ENOMEM;
1367 
1368 	reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
1369 				  MPTCP_PM_CMD_GET_LIMITS);
1370 	if (!reply)
1371 		goto fail;
1372 
1373 	if (nla_put_u32(msg, MPTCP_PM_ATTR_RCV_ADD_ADDRS,
1374 			READ_ONCE(pernet->add_addr_accept_max)))
1375 		goto fail;
1376 
1377 	if (nla_put_u32(msg, MPTCP_PM_ATTR_SUBFLOWS,
1378 			READ_ONCE(pernet->subflows_max)))
1379 		goto fail;
1380 
1381 	genlmsg_end(msg, reply);
1382 	return genlmsg_reply(msg, info);
1383 
1384 fail:
1385 	GENL_SET_ERR_MSG(info, "not enough space in Netlink message");
1386 	nlmsg_free(msg);
1387 	return -EMSGSIZE;
1388 }
1389 
/* Propagate a backup-priority change for @addr to every established
 * MPTCP connection in @net by sending an MP_PRIO ack.
 *
 * Returns the last per-socket result; -EINVAL when no connection was
 * processed at all.
 */
static int mptcp_nl_addr_backup(struct net *net,
				struct mptcp_addr_info *addr,
				u8 bkup)
{
	long s_slot = 0, s_num = 0;
	struct mptcp_sock *msk;
	int ret = -EINVAL;

	while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
		struct sock *sk = (struct sock *)msk;

		/* nothing to signal on a connection with no subflows */
		if (list_empty(&msk->conn_list))
			goto next;

		/* socket lock first, then pm lock — the ordering used
		 * throughout this file
		 */
		lock_sock(sk);
		spin_lock_bh(&msk->pm.lock);
		ret = mptcp_pm_nl_mp_prio_send_ack(msk, addr, bkup);
		spin_unlock_bh(&msk->pm.lock);
		release_sock(sk);

next:
		sock_put(sk);
		cond_resched();
	}

	return ret;
}
1417 
1418 static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info)
1419 {
1420 	struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
1421 	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
1422 	struct mptcp_pm_addr_entry addr, *entry;
1423 	struct net *net = sock_net(skb->sk);
1424 	u8 bkup = 0;
1425 	int ret;
1426 
1427 	ret = mptcp_pm_parse_addr(attr, info, true, &addr);
1428 	if (ret < 0)
1429 		return ret;
1430 
1431 	if (addr.addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP)
1432 		bkup = 1;
1433 
1434 	list_for_each_entry(entry, &pernet->local_addr_list, list) {
1435 		if (addresses_equal(&entry->addr, &addr.addr, true)) {
1436 			ret = mptcp_nl_addr_backup(net, &entry->addr, bkup);
1437 			if (ret)
1438 				return ret;
1439 
1440 			if (bkup)
1441 				entry->addr.flags |= MPTCP_PM_ADDR_FLAG_BACKUP;
1442 			else
1443 				entry->addr.flags &= ~MPTCP_PM_ADDR_FLAG_BACKUP;
1444 		}
1445 	}
1446 
1447 	return 0;
1448 }
1449 
/* Generic netlink operation table for the MPTCP path manager.
 * All state-changing commands require CAP_NET_ADMIN (GENL_ADMIN_PERM);
 * the read-only GET commands are deliberately unprivileged.
 */
static const struct genl_small_ops mptcp_pm_ops[] = {
	{
		.cmd    = MPTCP_PM_CMD_ADD_ADDR,
		.doit   = mptcp_nl_cmd_add_addr,
		.flags  = GENL_ADMIN_PERM,
	},
	{
		.cmd    = MPTCP_PM_CMD_DEL_ADDR,
		.doit   = mptcp_nl_cmd_del_addr,
		.flags  = GENL_ADMIN_PERM,
	},
	{
		.cmd    = MPTCP_PM_CMD_FLUSH_ADDRS,
		.doit   = mptcp_nl_cmd_flush_addrs,
		.flags  = GENL_ADMIN_PERM,
	},
	{
		.cmd    = MPTCP_PM_CMD_GET_ADDR,
		.doit   = mptcp_nl_cmd_get_addr,
		.dumpit   = mptcp_nl_cmd_dump_addrs,
	},
	{
		.cmd    = MPTCP_PM_CMD_SET_LIMITS,
		.doit   = mptcp_nl_cmd_set_limits,
		.flags  = GENL_ADMIN_PERM,
	},
	{
		.cmd    = MPTCP_PM_CMD_GET_LIMITS,
		.doit   = mptcp_nl_cmd_get_limits,
	},
	{
		.cmd    = MPTCP_PM_CMD_SET_FLAGS,
		.doit   = mptcp_nl_cmd_set_flags,
		.flags  = GENL_ADMIN_PERM,
	},
};
1486 
/* The MPTCP path-manager generic netlink family; netnsok allows use
 * from non-initial network namespaces.
 */
static struct genl_family mptcp_genl_family __ro_after_init = {
	.name		= MPTCP_PM_NAME,
	.version	= MPTCP_PM_VER,
	.maxattr	= MPTCP_PM_ATTR_MAX,
	.policy		= mptcp_pm_policy,
	.netnsok	= true,
	.module		= THIS_MODULE,
	.small_ops	= mptcp_pm_ops,
	.n_small_ops	= ARRAY_SIZE(mptcp_pm_ops),
	.mcgrps		= mptcp_pm_mcgrps,
	.n_mcgrps	= ARRAY_SIZE(mptcp_pm_mcgrps),
};
1499 
1500 static int __net_init pm_nl_init_net(struct net *net)
1501 {
1502 	struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id);
1503 
1504 	INIT_LIST_HEAD_RCU(&pernet->local_addr_list);
1505 	__reset_counters(pernet);
1506 	pernet->next_id = 1;
1507 	bitmap_zero(pernet->id_bitmap, MAX_ADDR_ID + 1);
1508 	spin_lock_init(&pernet->lock);
1509 	return 0;
1510 }
1511 
/* Batched netns teardown: release every configured endpoint of each
 * dying namespace.
 */
static void __net_exit pm_nl_exit_net(struct list_head *net_list)
{
	struct net *net;

	list_for_each_entry(net, net_list, exit_list) {
		struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id);

		/* net is removed from namespace list, can't race with
		 * other modifiers
		 */
		__flush_addrs(net, &pernet->local_addr_list);
	}
}
1525 
/* Pernet registration: the core allocates a zeroed pm_nl_pernet per
 * namespace, reachable via net_generic(net, pm_nl_pernet_id).
 */
static struct pernet_operations mptcp_pm_pernet_ops = {
	.init = pm_nl_init_net,
	.exit_batch = pm_nl_exit_net,
	.id = &pm_nl_pernet_id,
	.size = sizeof(struct pm_nl_pernet),
};
1532 
/* Boot-time init: register the pernet state and the netlink family.
 * MPTCP cannot function without either, hence panic() on failure.
 */
void __init mptcp_pm_nl_init(void)
{
	if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0)
		panic("Failed to register MPTCP PM pernet subsystem.\n");

	if (genl_register_family(&mptcp_genl_family))
		panic("Failed to register MPTCP PM netlink family\n");
}
1541