// SPDX-License-Identifier: GPL-2.0
/* Multipath TCP
 *
 * Copyright (c) 2020, Red Hat, Inc.
 */

#define pr_fmt(fmt) "MPTCP: " fmt

#include <linux/inet.h>
#include <linux/kernel.h>
#include <net/tcp.h>
#include <net/netns/generic.h>
#include <net/mptcp.h>
#include <net/genetlink.h>
#include <uapi/linux/mptcp.h>

#include "protocol.h"
#include "mib.h"

/* forward declaration */
static struct genl_family mptcp_genl_family;

static int pm_nl_pernet_id;

/* One configured local endpoint; linked (RCU-protected) into the pernet
 * local_addr_list and freed via the rcu head.
 */
struct mptcp_pm_addr_entry {
	struct list_head list;
	struct mptcp_addr_info addr;
	struct rcu_head rcu;
	struct socket *lsk;	/* listening socket, only when addr.port != 0 */
};

/* Per-msk state for one announced (ADD_ADDR) address, driving the
 * retransmit timer until the peer echoes the announce.
 */
struct mptcp_pm_add_entry {
	struct list_head list;		/* linked into msk->pm.anno_list */
	struct mptcp_addr_info addr;
	struct timer_list add_timer;	/* ADD_ADDR retransmit timer */
	struct mptcp_sock *sock;	/* owning msk */
	u8 retrans_times;		/* retransmissions done so far */
};

#define MAX_ADDR_ID		255
#define BITMAP_SZ DIV_ROUND_UP(MAX_ADDR_ID + 1, BITS_PER_LONG)

/* Per-netns path-manager state: the endpoint list plus the limits that
 * mptcp_pm_get_*_max() expose to the msk fast path via READ_ONCE.
 */
struct pm_nl_pernet {
	/* protects pernet updates */
	spinlock_t		lock;
	struct list_head	local_addr_list;
	unsigned int		addrs;
	unsigned int		add_addr_signal_max;
	unsigned int		add_addr_accept_max;
	unsigned int		local_addr_max;
	unsigned int		subflows_max;
	unsigned int		next_id;
	unsigned long		id_bitmap[BITMAP_SZ];	/* allocated address ids */
};

#define MPTCP_PM_ADDR_MAX	8
#define ADD_ADDR_RETRANS_MAX	3

/* Compare two addresses, treating an IPv4 address and its v4-mapped IPv6
 * form as equal; ports are compared only when @use_port is set.
 */
static bool addresses_equal(const struct mptcp_addr_info *a,
			    struct mptcp_addr_info *b, bool use_port)
{
	bool addr_equals = false;

	if (a->family == b->family) {
		if (a->family == AF_INET)
			addr_equals = a->addr.s_addr == b->addr.s_addr;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
		else
			addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6);
	} else if (a->family == AF_INET) {
		if (ipv6_addr_v4mapped(&b->addr6))
			addr_equals = a->addr.s_addr == b->addr6.s6_addr32[3];
	} else if (b->family == AF_INET) {
		if (ipv6_addr_v4mapped(&a->addr6))
			addr_equals = a->addr6.s6_addr32[3] == b->addr.s_addr;
#endif
	}

	if (!addr_equals)
		return false;
	if (!use_port)
		return true;

	return a->port == b->port;
}

/* True when @addr is the all-zero address (port included) of its family */
static bool address_zero(const struct mptcp_addr_info *addr)
{
	struct mptcp_addr_info zero;

	memset(&zero, 0, sizeof(zero));
	zero.family = addr->family;

	return addresses_equal(addr, &zero, true);
}

/* Fill @addr with the local (source) address/port of @skc */
static void local_address(const struct sock_common *skc,
			  struct mptcp_addr_info *addr)
{
	addr->family = skc->skc_family;
	addr->port = htons(skc->skc_num);
	if (addr->family == AF_INET)
		addr->addr.s_addr = skc->skc_rcv_saddr;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	else if (addr->family == AF_INET6)
		addr->addr6 = skc->skc_v6_rcv_saddr;
#endif
}

/* Fill @addr with the remote (destination) address/port of @skc */
static void remote_address(const struct sock_common *skc,
			   struct mptcp_addr_info *addr)
{
	addr->family = skc->skc_family;
	addr->port = skc->skc_dport;
	if (addr->family == AF_INET)
		addr->addr.s_addr = skc->skc_daddr;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	else if (addr->family == AF_INET6)
		addr->addr6 = skc->skc_v6_daddr;
#endif
}

/* True when some subflow on @list already uses @saddr as its local address.
 * The port is compared only when @saddr carries a non-zero port.
 */
static bool lookup_subflow_by_saddr(const struct list_head *list,
				    struct mptcp_addr_info *saddr)
{
	struct mptcp_subflow_context *subflow;
	struct mptcp_addr_info cur;
	struct sock_common *skc;

	list_for_each_entry(subflow, list, node) {
		skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);

		local_address(skc, &cur);
		if (addresses_equal(&cur, saddr, saddr->port))
			return true;
	}

	return false;
}

/* Pick the first SUBFLOW-flagged endpoint that is family-compatible with
 * the msk and not already in use by an existing or pending subflow.
 * Runs under RCU; the returned entry is only valid within this grace period.
 */
static struct mptcp_pm_addr_entry *
select_local_address(const struct pm_nl_pernet *pernet,
		     struct mptcp_sock *msk)
{
	struct mptcp_pm_addr_entry *entry, *ret = NULL;
	struct sock *sk = (struct sock *)msk;

	msk_owned_by_me(msk);

	rcu_read_lock();
	/* move just-joined subflows into conn_list so the in-use check
	 * below sees them too
	 */
	__mptcp_flush_join_list(msk);
	list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
		if (!(entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW))
			continue;

		if (entry->addr.family != sk->sk_family) {
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
			/* a v4-mapped v6 connection can still use a v4
			 * endpoint, and vice versa
			 */
			if ((entry->addr.family == AF_INET &&
			     !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) ||
			    (sk->sk_family == AF_INET &&
			     !ipv6_addr_v4mapped(&entry->addr.addr6)))
#endif
				continue;
		}

		/* avoid any address already in use by subflows and
		 * pending join
		 */
		if (!lookup_subflow_by_saddr(&msk->conn_list, &entry->addr)) {
			ret = entry;
			break;
		}
	}
	rcu_read_unlock();
	return ret;
}

/* Return the @pos-th SIGNAL-flagged endpoint, or NULL when fewer exist */
static struct mptcp_pm_addr_entry *
select_signal_address(struct pm_nl_pernet *pernet, unsigned int pos)
{
	struct mptcp_pm_addr_entry *entry, *ret = NULL;
	int i = 0;

	rcu_read_lock();
	/* do not keep any additional per socket state, just signal
	 * the address list in order.
	 * Note: removal from the local address list during the msk life-cycle
	 * can lead to additional addresses not being announced.
	 */
	list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
		if (!(entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL))
			continue;
		if (i++ == pos) {
			ret = entry;
			break;
		}
	}
	rcu_read_unlock();
	return ret;
}

/* Max number of addresses this netns will announce (ADD_ADDR) */
unsigned int mptcp_pm_get_add_addr_signal_max(struct mptcp_sock *msk)
{
	struct pm_nl_pernet *pernet;

	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
	return READ_ONCE(pernet->add_addr_signal_max);
}
EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_signal_max);

/* Max number of peer-announced addresses this netns will accept */
unsigned int mptcp_pm_get_add_addr_accept_max(struct mptcp_sock *msk)
{
	struct pm_nl_pernet *pernet;

	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
	return READ_ONCE(pernet->add_addr_accept_max);
}
EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_accept_max);

/* Max number of additional subflows per connection for this netns */
unsigned int mptcp_pm_get_subflows_max(struct mptcp_sock *msk)
{
	struct pm_nl_pernet *pernet;

	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
	return READ_ONCE(pernet->subflows_max);
}
EXPORT_SYMBOL_GPL(mptcp_pm_get_subflows_max);

/* Max number of local endpoints usable to create subflows */
static unsigned int mptcp_pm_get_local_addr_max(struct mptcp_sock *msk)
{
	struct pm_nl_pernet *pernet;

	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
	return READ_ONCE(pernet->local_addr_max);
}

/* Clear pm.work_pending once both the announce budget and the
 * subflow-creation budget are exhausted.
 */
static void check_work_pending(struct mptcp_sock *msk)
{
	if (msk->pm.add_addr_signaled == mptcp_pm_get_add_addr_signal_max(msk) &&
	    (msk->pm.local_addr_used == mptcp_pm_get_local_addr_max(msk) ||
	     msk->pm.subflows == mptcp_pm_get_subflows_max(msk)))
		WRITE_ONCE(msk->pm.work_pending, false);
}

/* Find the anno_list entry matching @addr (port included).
 * Caller must hold msk->pm.lock.
 */
static struct mptcp_pm_add_entry *
lookup_anno_list_by_saddr(struct mptcp_sock *msk,
			  struct mptcp_addr_info *addr)
{
	struct mptcp_pm_add_entry *entry;

	lockdep_assert_held(&msk->pm.lock);

	list_for_each_entry(entry, &msk->pm.anno_list, list) {
		if (addresses_equal(&entry->addr, addr, true))
			return entry;
	}

	return NULL;
}

/* True when @sk's local source address/port matches an announced address */
bool mptcp_pm_sport_in_anno_list(struct mptcp_sock *msk, const struct sock *sk)
{
	struct mptcp_pm_add_entry *entry;
	struct mptcp_addr_info saddr;
	bool ret = false;

	local_address((struct sock_common *)sk, &saddr);

	spin_lock_bh(&msk->pm.lock);
	list_for_each_entry(entry, &msk->pm.anno_list, list) {
		if (addresses_equal(&entry->addr, &saddr, true)) {
			ret = true;
			goto out;
		}
	}

out:
	spin_unlock_bh(&msk->pm.lock);
	return ret;
}

/* ADD_ADDR retransmit timer: re-announce the address up to
 * ADD_ADDR_RETRANS_MAX times until the peer echoes it.
 *
 * NOTE(review): sk_reset_timer() holds a sock reference that is dropped by
 * __sock_put() at "out"; the three early returns below bypass that put —
 * looks like a possible refcount leak, confirm against the sk timer
 * contract before changing.
 */
static void mptcp_pm_add_timer(struct timer_list *timer)
{
	struct mptcp_pm_add_entry *entry = from_timer(entry, timer, add_timer);
	struct mptcp_sock *msk = entry->sock;
	struct sock *sk = (struct sock *)msk;

	pr_debug("msk=%p", msk);

	if (!msk)
		return;

	if (inet_sk_state_load(sk) == TCP_CLOSE)
		return;

	if (!entry->addr.id)
		return;

	if (mptcp_pm_should_add_signal(msk)) {
		/* a previous signal is still pending: retry shortly */
		sk_reset_timer(sk, timer, jiffies + TCP_RTO_MAX / 8);
		goto out;
	}

	spin_lock_bh(&msk->pm.lock);

	if (!mptcp_pm_should_add_signal(msk)) {
		pr_debug("retransmit ADD_ADDR id=%d", entry->addr.id);
		mptcp_pm_announce_addr(msk, &entry->addr, false, entry->addr.port);
		mptcp_pm_add_addr_send_ack(msk);
		entry->retrans_times++;
	}

	if (entry->retrans_times < ADD_ADDR_RETRANS_MAX)
		sk_reset_timer(sk, timer,
			       jiffies + mptcp_get_add_addr_timeout(sock_net(sk)));

	spin_unlock_bh(&msk->pm.lock);

out:
	__sock_put(sk);
}

/* Stop the ADD_ADDR retransmit timer for @addr and return its entry (still
 * on the anno_list, not freed), or NULL when no such announce exists.
 * Setting retrans_times to the max under the pm lock prevents a concurrent
 * timer run from re-arming before sk_stop_timer_sync() completes.
 */
struct mptcp_pm_add_entry *
mptcp_pm_del_add_timer(struct mptcp_sock *msk,
		       struct mptcp_addr_info *addr)
{
	struct mptcp_pm_add_entry *entry;
	struct sock *sk = (struct sock *)msk;

	spin_lock_bh(&msk->pm.lock);
	entry = lookup_anno_list_by_saddr(msk, addr);
	if (entry)
		entry->retrans_times = ADD_ADDR_RETRANS_MAX;
	spin_unlock_bh(&msk->pm.lock);

	if (entry)
		sk_stop_timer_sync(sk, &entry->add_timer);

	return entry;
}

/* Queue @entry's address on the msk announce list and arm the ADD_ADDR
 * retransmit timer. Returns false when the address is already queued or
 * allocation fails. Caller must hold msk->pm.lock (GFP_ATOMIC below).
 */
static bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk,
				     struct mptcp_pm_addr_entry *entry)
{
	struct mptcp_pm_add_entry *add_entry = NULL;
	struct sock *sk = (struct sock *)msk;
	struct net *net = sock_net(sk);

	lockdep_assert_held(&msk->pm.lock);

	if (lookup_anno_list_by_saddr(msk, &entry->addr))
		return false;

	add_entry = kmalloc(sizeof(*add_entry), GFP_ATOMIC);
	if (!add_entry)
		return false;

	list_add(&add_entry->list, &msk->pm.anno_list);

	add_entry->addr = entry->addr;
	add_entry->sock = msk;
	add_entry->retrans_times = 0;

	timer_setup(&add_entry->add_timer, mptcp_pm_add_timer, 0);
	sk_reset_timer(sk, &add_entry->add_timer,
		       jiffies + mptcp_get_add_addr_timeout(net));

	return true;
}

/* Tear down every pending announce: detach the whole list under the pm
 * lock, then stop timers and free entries without holding it.
 */
void mptcp_pm_free_anno_list(struct mptcp_sock *msk)
{
	struct mptcp_pm_add_entry *entry, *tmp;
	struct sock *sk = (struct sock *)msk;
	LIST_HEAD(free_list);

	pr_debug("msk=%p", msk);

	spin_lock_bh(&msk->pm.lock);
	list_splice_init(&msk->pm.anno_list, &free_list);
	spin_unlock_bh(&msk->pm.lock);

	list_for_each_entry_safe(entry, tmp, &free_list, list) {
		sk_stop_timer_sync(sk, &entry->add_timer);
		kfree(entry);
	}
}

/* Core PM worker: first announce one more SIGNAL endpoint if the budget
 * allows, then create one more subflow from a SUBFLOW endpoint.
 * Called with msk->pm.lock held; temporarily drops it around
 * __mptcp_subflow_connect().
 */
static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
{
	struct sock *sk = (struct sock *)msk;
	struct mptcp_pm_addr_entry *local;
	unsigned int add_addr_signal_max;
	unsigned int local_addr_max;
	struct pm_nl_pernet *pernet;
	unsigned int subflows_max;

	pernet = net_generic(sock_net(sk), pm_nl_pernet_id);

	add_addr_signal_max = mptcp_pm_get_add_addr_signal_max(msk);
	local_addr_max = mptcp_pm_get_local_addr_max(msk);
	subflows_max = mptcp_pm_get_subflows_max(msk);

	pr_debug("local %d:%d signal %d:%d subflows %d:%d\n",
		 msk->pm.local_addr_used, local_addr_max,
		 msk->pm.add_addr_signaled, add_addr_signal_max,
		 msk->pm.subflows, subflows_max);

	/* check first for announce */
	if (msk->pm.add_addr_signaled < add_addr_signal_max) {
		local = select_signal_address(pernet,
					      msk->pm.add_addr_signaled);

		if (local) {
			if (mptcp_pm_alloc_anno_list(msk, local)) {
				msk->pm.add_addr_signaled++;
				mptcp_pm_announce_addr(msk, &local->addr, false, local->addr.port);
				mptcp_pm_nl_add_addr_send_ack(msk);
			}
		} else {
			/* pick failed, avoid further attempts later.
			 * NOTE(review): this writes local_addr_used (the
			 * subflow counter) to the *signal* budget — verify
			 * add_addr_signaled wasn't intended instead.
			 */
			msk->pm.local_addr_used = add_addr_signal_max;
		}

		check_work_pending(msk);
	}

	/* check if should create a new subflow */
	if (msk->pm.local_addr_used < local_addr_max &&
	    msk->pm.subflows < subflows_max) {
		local = select_local_address(pernet, msk);
		if (local) {
			struct mptcp_addr_info remote = { 0 };

			msk->pm.local_addr_used++;
			msk->pm.subflows++;
			check_work_pending(msk);
			remote_address((struct sock_common *)sk, &remote);
			/* connect may sleep/take the ssk lock: drop pm lock */
			spin_unlock_bh(&msk->pm.lock);
			__mptcp_subflow_connect(sk, &local->addr, &remote);
			spin_lock_bh(&msk->pm.lock);
			return;
		}

		/* lookup failed, avoid further attempts later */
		msk->pm.local_addr_used = local_addr_max;
		check_work_pending(msk);
	}
}

void mptcp_pm_nl_fully_established(struct mptcp_sock *msk)
{
	mptcp_pm_create_subflow_or_signal_addr(msk);
}

void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk)
{
	mptcp_pm_create_subflow_or_signal_addr(msk);
}

/* Handle a peer ADD_ADDR: open a subflow towards the announced address
 * (letting routing pick the local source) and echo the announce back.
 * Called with msk->pm.lock held; dropped around the connect.
 */
void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk)
{
	struct sock *sk = (struct sock *)msk;
	unsigned int add_addr_accept_max;
	struct mptcp_addr_info remote;
	struct mptcp_addr_info local;
	unsigned int subflows_max;
	bool use_port = false;

	add_addr_accept_max = mptcp_pm_get_add_addr_accept_max(msk);
	subflows_max = mptcp_pm_get_subflows_max(msk);

	pr_debug("accepted %d:%d remote family %d",
		 msk->pm.add_addr_accepted, add_addr_accept_max,
		 msk->pm.remote.family);
	msk->pm.add_addr_accepted++;
	msk->pm.subflows++;
	if (msk->pm.add_addr_accepted >= add_addr_accept_max ||
	    msk->pm.subflows >= subflows_max)
		WRITE_ONCE(msk->pm.accept_addr, false);

	/* connect to the specified remote address, using whatever
	 * local address the routing configuration will pick.
	 */
	remote = msk->pm.remote;
	if (!remote.port)
		remote.port = sk->sk_dport;
	else
		use_port = true;
	memset(&local, 0, sizeof(local));
	local.family = remote.family;

	spin_unlock_bh(&msk->pm.lock);
	__mptcp_subflow_connect(sk, &local, &remote);
	spin_lock_bh(&msk->pm.lock);

	/* echo the address back to the peer */
	mptcp_pm_announce_addr(msk, &remote, true, use_port);
	mptcp_pm_nl_add_addr_send_ack(msk);
}

/* Push a pending ADD_ADDR out by forcing an ack on the first subflow,
 * then clear the ipv6/port signal bits that the ack consumed.
 * Called with msk->pm.lock held; dropped around tcp_send_ack().
 */
void mptcp_pm_nl_add_addr_send_ack(struct mptcp_sock *msk)
{
	struct mptcp_subflow_context *subflow;

	msk_owned_by_me(msk);
	lockdep_assert_held(&msk->pm.lock);

	if (!mptcp_pm_should_add_signal(msk))
		return;

	__mptcp_flush_join_list(msk);
	subflow = list_first_entry_or_null(&msk->conn_list, typeof(*subflow), node);
	if (subflow) {
		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
		u8 add_addr;

		spin_unlock_bh(&msk->pm.lock);
		pr_debug("send ack for add_addr%s%s",
			 mptcp_pm_should_add_signal_ipv6(msk) ? " [ipv6]" : "",
			 mptcp_pm_should_add_signal_port(msk) ? " [port]" : "");

		lock_sock(ssk);
		tcp_send_ack(ssk);
		release_sock(ssk);
		spin_lock_bh(&msk->pm.lock);

		add_addr = READ_ONCE(msk->pm.addr_signal);
		if (mptcp_pm_should_add_signal_ipv6(msk))
			add_addr &= ~BIT(MPTCP_ADD_ADDR_IPV6);
		if (mptcp_pm_should_add_signal_port(msk))
			add_addr &= ~BIT(MPTCP_ADD_ADDR_PORT);
		WRITE_ONCE(msk->pm.addr_signal, add_addr);
	}
}

/* Flag the subflow bound to @addr for an MP_PRIO change to @bkup and force
 * an ack carrying it. Returns 0 on success, -EINVAL when no subflow uses
 * @addr. Called with msk->pm.lock held; dropped around tcp_send_ack().
 */
int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
				 struct mptcp_addr_info *addr,
				 u8 bkup)
{
	struct mptcp_subflow_context *subflow;

	pr_debug("bkup=%d", bkup);

	mptcp_for_each_subflow(msk, subflow) {
		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
		struct sock *sk = (struct sock *)msk;
		struct mptcp_addr_info local;

		local_address((struct sock_common *)ssk, &local);
		if (!addresses_equal(&local, addr, addr->port))
			continue;

		subflow->backup = bkup;
		subflow->send_mp_prio = 1;
		subflow->request_bkup = bkup;
		__MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPPRIOTX);

		spin_unlock_bh(&msk->pm.lock);
		pr_debug("send ack for mp_prio");
		lock_sock(ssk);
		tcp_send_ack(ssk);
		release_sock(ssk);
		spin_lock_bh(&msk->pm.lock);

		return 0;
	}

	return -EINVAL;
}

/* Handle a peer RM_ADDR: close the (single) subflow whose remote id matches
 * pm.rm_id and give back the accept/subflow budget it consumed.
 * Called with msk->pm.lock held; dropped around the subflow teardown.
 */
void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk)
{
	struct mptcp_subflow_context *subflow, *tmp;
	struct sock *sk = (struct sock *)msk;

	pr_debug("address rm_id %d", msk->pm.rm_id);

	msk_owned_by_me(msk);

	if (!msk->pm.rm_id)
		return;

	if (list_empty(&msk->conn_list))
		return;

	list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
		int how = RCV_SHUTDOWN | SEND_SHUTDOWN;

		if (msk->pm.rm_id != subflow->remote_id)
			continue;

		spin_unlock_bh(&msk->pm.lock);
		mptcp_subflow_shutdown(sk, ssk, how);
		__mptcp_close_ssk(sk, ssk, subflow);
		spin_lock_bh(&msk->pm.lock);

		msk->pm.add_addr_accepted--;
		msk->pm.subflows--;
		WRITE_ONCE(msk->pm.accept_addr, true);

		__MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMADDR);

		break;
	}
}

/* Local-side subflow removal: close the subflow whose *local* id is @rm_id
 * and give back the local-addr/subflow budget. Mirrors
 * mptcp_pm_nl_rm_addr_received() but keyed on local_id.
 */
void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk, u8 rm_id)
{
	struct mptcp_subflow_context *subflow, *tmp;
	struct sock *sk = (struct sock *)msk;

	pr_debug("subflow rm_id %d", rm_id);

	msk_owned_by_me(msk);

	if (!rm_id)
		return;

	if (list_empty(&msk->conn_list))
		return;

	list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
		int how = RCV_SHUTDOWN | SEND_SHUTDOWN;

		if (rm_id != subflow->local_id)
			continue;

		spin_unlock_bh(&msk->pm.lock);
		mptcp_subflow_shutdown(sk, ssk, how);
		__mptcp_close_ssk(sk, ssk, subflow);
		spin_lock_bh(&msk->pm.lock);

		msk->pm.local_addr_used--;
		msk->pm.subflows--;

		__MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMSUBFLOW);

		break;
	}
}

/* True when @entry is signal-only (SIGNAL set, SUBFLOW clear): only those
 * endpoints are differentiated by port when checking for duplicates.
 */
static bool address_use_port(struct mptcp_pm_addr_entry *entry)
{
	return (entry->addr.flags &
		(MPTCP_PM_ADDR_FLAG_SIGNAL | MPTCP_PM_ADDR_FLAG_SUBFLOW)) ==
		MPTCP_PM_ADDR_FLAG_SIGNAL;
}

/* Insert @entry into the pernet endpoint list, allocating an id from the
 * bitmap when the caller left addr.id == 0, and bump the relevant limits.
 * Returns the (positive) address id or -EINVAL on duplicates/limits.
 */
static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,
					     struct mptcp_pm_addr_entry *entry)
{
	struct mptcp_pm_addr_entry *cur;
	unsigned int addr_max;
	int ret = -EINVAL;

	spin_lock_bh(&pernet->lock);
	/* to keep the code simple, don't do IDR-like allocation for address ID,
	 * just bail when we exceed limits
	 */
	if (pernet->next_id == MAX_ADDR_ID)
		pernet->next_id = 1;
	if (pernet->addrs >= MPTCP_PM_ADDR_MAX)
		goto out;
	if (test_bit(entry->addr.id, pernet->id_bitmap))
		goto out;

	/* do not insert duplicate addresses; the port differentiates
	 * only signal-only endpoints (see address_use_port())
	 */
	list_for_each_entry(cur, &pernet->local_addr_list, list) {
		if (addresses_equal(&cur->addr, &entry->addr,
				    address_use_port(entry) &&
				    address_use_port(cur)))
			goto out;
	}

	if (!entry->addr.id) {
find_next:
		entry->addr.id = find_next_zero_bit(pernet->id_bitmap,
						    MAX_ADDR_ID + 1,
						    pernet->next_id);
		/* wrap around once when the tail of the bitmap is full */
		if ((!entry->addr.id || entry->addr.id > MAX_ADDR_ID) &&
		    pernet->next_id != 1) {
			pernet->next_id = 1;
			goto find_next;
		}
	}

	if (!entry->addr.id || entry->addr.id > MAX_ADDR_ID)
		goto out;

	__set_bit(entry->addr.id, pernet->id_bitmap);
	if (entry->addr.id > pernet->next_id)
		pernet->next_id = entry->addr.id;

	if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
		addr_max = pernet->add_addr_signal_max;
		WRITE_ONCE(pernet->add_addr_signal_max, addr_max + 1);
	}
	if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
		addr_max = pernet->local_addr_max;
		WRITE_ONCE(pernet->local_addr_max, addr_max + 1);
	}

	pernet->addrs++;
	list_add_tail_rcu(&entry->list, &pernet->local_addr_list);
	ret = entry->addr.id;

out:
	spin_unlock_bh(&pernet->lock);
	return ret;
}

/* Create and bind the in-kernel listening socket backing a port-based
 * endpoint; the socket is stored in entry->lsk (released here on error).
 */
static int mptcp_pm_nl_create_listen_socket(struct sock *sk,
					    struct mptcp_pm_addr_entry *entry)
{
	struct sockaddr_storage addr;
	struct mptcp_sock *msk;
	struct socket *ssock;
	int backlog = 1024;
	int err;

	err = sock_create_kern(sock_net(sk), entry->addr.family,
			       SOCK_STREAM, IPPROTO_MPTCP, &entry->lsk);
	if (err)
		return err;

	msk = mptcp_sk(entry->lsk->sk);
	if (!msk) {
		err = -EINVAL;
		goto out;
	}

	ssock = __mptcp_nmpc_socket(msk);
	if (!ssock) {
		err = -EINVAL;
		goto out;
	}

	mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family);
	/* NOTE(review): sizeof(struct sockaddr_in) is passed even for
	 * AF_INET6 endpoints — that looks too short for sockaddr_in6;
	 * confirm v6 port-based endpoints actually bind correctly.
	 */
	err = kernel_bind(ssock, (struct sockaddr *)&addr,
			  sizeof(struct sockaddr_in));
	if (err) {
		pr_warn("kernel_bind error, err=%d", err);
		goto out;
	}

	err = kernel_listen(ssock, backlog);
	if (err) {
		pr_warn("kernel_listen error, err=%d", err);
		goto out;
	}

	return 0;

out:
	sock_release(entry->lsk);
	return err;
}

/* Map @skc's local address to an endpoint id: 0 for the address of the
 * initial subflow (or the any-address), the configured id when the address
 * is a known endpoint, otherwise implicitly register it and return the new
 * id (or a negative error).
 */
int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc)
{
	struct mptcp_pm_addr_entry *entry;
	struct mptcp_addr_info skc_local;
	struct mptcp_addr_info msk_local;
	struct pm_nl_pernet *pernet;
	int ret = -1;

	if (WARN_ON_ONCE(!msk))
		return -1;

	/* The 0 ID mapping is defined by the first subflow, copied into the msk
	 * addr
	 */
	local_address((struct sock_common *)msk, &msk_local);
	local_address((struct sock_common *)skc, &skc_local);
	if (addresses_equal(&msk_local, &skc_local, false))
		return 0;

	if (address_zero(&skc_local))
		return 0;

	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);

	rcu_read_lock();
	list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
		if (addresses_equal(&entry->addr, &skc_local, entry->addr.port)) {
			ret = entry->addr.id;
			break;
		}
	}
	rcu_read_unlock();
	if (ret >= 0)
		return ret;

	/* address not found, add to local list */
	entry = kmalloc(sizeof(*entry), GFP_ATOMIC);
	if (!entry)
		return -ENOMEM;

	entry->addr = skc_local;
	entry->addr.ifindex = 0;
	entry->addr.flags = 0;
	entry->addr.id = 0;
	entry->addr.port = 0;
	entry->lsk = NULL;
	ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
	if (ret < 0)
		kfree(entry);

	return ret;
}

/* Seed the msk PM flags from the current pernet limits */
void mptcp_pm_nl_data_init(struct mptcp_sock *msk)
{
	struct mptcp_pm_data *pm = &msk->pm;
	bool subflows;

	subflows = !!mptcp_pm_get_subflows_max(msk);
	WRITE_ONCE(pm->work_pending, (!!mptcp_pm_get_local_addr_max(msk) && subflows) ||
		   !!mptcp_pm_get_add_addr_signal_max(msk));
	WRITE_ONCE(pm->accept_addr, !!mptcp_pm_get_add_addr_accept_max(msk) && subflows);
	WRITE_ONCE(pm->accept_subflow, subflows);
}

#define MPTCP_PM_CMD_GRP_OFFSET	0

static const struct genl_multicast_group mptcp_pm_mcgrps[] = {
	[MPTCP_PM_CMD_GRP_OFFSET]	= { .name = MPTCP_PM_CMD_GRP_NAME, },
};

static const struct nla_policy
mptcp_pm_addr_policy[MPTCP_PM_ADDR_ATTR_MAX + 1] = {
	[MPTCP_PM_ADDR_ATTR_FAMILY]	= { .type	= NLA_U16,	},
	[MPTCP_PM_ADDR_ATTR_ID]		= { .type	= NLA_U8,	},
	[MPTCP_PM_ADDR_ATTR_ADDR4]	= { .type	= NLA_U32,	},
	[MPTCP_PM_ADDR_ATTR_ADDR6]	=
		NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)),
	[MPTCP_PM_ADDR_ATTR_PORT]	= { .type	= NLA_U16	},
	[MPTCP_PM_ADDR_ATTR_FLAGS]	= { .type	= NLA_U32	},
	[MPTCP_PM_ADDR_ATTR_IF_IDX]     = { .type	= NLA_S32	},
};

static const struct nla_policy mptcp_pm_policy[MPTCP_PM_ATTR_MAX + 1] = {
	[MPTCP_PM_ATTR_ADDR]		=
					NLA_POLICY_NESTED(mptcp_pm_addr_policy),
	[MPTCP_PM_ATTR_RCV_ADD_ADDRS]	= { .type	= NLA_U32,	},
	[MPTCP_PM_ATTR_SUBFLOWS]	= { .type	= NLA_U32,	},
};

/* Select the netlink address attribute matching @family */
static int mptcp_pm_family_to_addr(int family)
{
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	if (family == AF_INET6)
		return MPTCP_PM_ADDR_ATTR_ADDR6;
#endif
	return MPTCP_PM_ADDR_ATTR_ADDR4;
}

/* Parse a nested MPTCP_PM_ATTR_ADDR attribute into @entry.
 * When @require_family is false, the family (and address payload) may be
 * omitted and only id/flags/ifindex/port are filled in.
 */
static int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info,
			       bool require_family,
			       struct mptcp_pm_addr_entry *entry)
{
	struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1];
	int err, addr_addr;

	if (!attr) {
		GENL_SET_ERR_MSG(info, "missing address info");
		return -EINVAL;
	}

	/* no validation needed - was already done via nested policy */
	err = nla_parse_nested_deprecated(tb, MPTCP_PM_ADDR_ATTR_MAX, attr,
					  mptcp_pm_addr_policy, info->extack);
	if (err)
		return err;

	memset(entry, 0, sizeof(*entry));
	if (!tb[MPTCP_PM_ADDR_ATTR_FAMILY]) {
		if (!require_family)
			goto skip_family;

		NL_SET_ERR_MSG_ATTR(info->extack, attr,
				    "missing family");
		return -EINVAL;
	}

	entry->addr.family = nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_FAMILY]);
	if (entry->addr.family != AF_INET
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	    && entry->addr.family != AF_INET6
#endif
	    ) {
		NL_SET_ERR_MSG_ATTR(info->extack, attr,
				    "unknown address family");
		return -EINVAL;
	}
	addr_addr = mptcp_pm_family_to_addr(entry->addr.family);
	if (!tb[addr_addr]) {
		NL_SET_ERR_MSG_ATTR(info->extack, attr,
				    "missing address data");
		return -EINVAL;
	}

#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	if (entry->addr.family == AF_INET6)
		entry->addr.addr6 = nla_get_in6_addr(tb[addr_addr]);
	else
#endif
		entry->addr.addr.s_addr = nla_get_in_addr(tb[addr_addr]);

skip_family:
	if (tb[MPTCP_PM_ADDR_ATTR_IF_IDX]) {
		u32 val = nla_get_s32(tb[MPTCP_PM_ADDR_ATTR_IF_IDX]);

		entry->addr.ifindex = val;
	}

	if (tb[MPTCP_PM_ADDR_ATTR_ID])
		entry->addr.id = nla_get_u8(tb[MPTCP_PM_ADDR_ATTR_ID]);

	if (tb[MPTCP_PM_ADDR_ATTR_FLAGS])
		entry->addr.flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]);

	if (tb[MPTCP_PM_ADDR_ATTR_PORT])
		entry->addr.port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT]));

	return 0;
}

static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info)
{
	return net_generic(genl_info_net(info), pm_nl_pernet_id);
}

/* Walk every established msk in @net and let the PM react to the endpoint
 * list change (announce/create subflows for a freshly added address).
 */
static int mptcp_nl_add_subflow_or_signal_addr(struct net *net)
{
	struct mptcp_sock *msk;
	long s_slot = 0, s_num = 0;

	while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
		struct sock *sk = (struct sock *)msk;

		if (!READ_ONCE(msk->fully_established))
			goto next;

		lock_sock(sk);
		spin_lock_bh(&msk->pm.lock);
		mptcp_pm_create_subflow_or_signal_addr(msk);
		spin_unlock_bh(&msk->pm.lock);
		release_sock(sk);

next:
		sock_put(sk);
		cond_resched();
	}

	return 0;
}

/* MPTCP_PM_CMD_ADD_ADDR handler: register a new endpoint (creating its
 * listening socket when a port is given) and notify existing connections.
 */
static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info)
{
	struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
	struct mptcp_pm_addr_entry addr, *entry;
	int ret;

	ret = mptcp_pm_parse_addr(attr, info, true, &addr);
	if (ret < 0)
		return ret;

	entry = kmalloc(sizeof(*entry), GFP_KERNEL);
	if (!entry) {
		GENL_SET_ERR_MSG(info, "can't allocate addr");
		return -ENOMEM;
	}

	*entry = addr;
	if (entry->addr.port) {
		ret = mptcp_pm_nl_create_listen_socket(skb->sk, entry);
		if (ret) {
			GENL_SET_ERR_MSG(info, "create listen socket error");
			kfree(entry);
			return ret;
		}
	}
	ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
	if (ret < 0) {
		GENL_SET_ERR_MSG(info, "too many addresses or duplicate one");
		if (entry->lsk)
			sock_release(entry->lsk);
		kfree(entry);
		return ret;
	}

	mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk));

	return 0;
}

/* Caller must hold pernet->lock (or RCU for a read-only peek) */
static struct mptcp_pm_addr_entry *
__lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id)
{
	struct mptcp_pm_addr_entry *entry;

	list_for_each_entry(entry, &pernet->local_addr_list, list) {
		if (entry->addr.id == id)
			return entry;
	}
	return NULL;
}

/* Cancel and free the pending announce for @addr, if any; true when one
 * was actually removed.
 */
static bool remove_anno_list_by_saddr(struct mptcp_sock *msk,
				      struct mptcp_addr_info *addr)
{
	struct mptcp_pm_add_entry *entry;

	entry = mptcp_pm_del_add_timer(msk, addr);
	if (entry) {
		list_del(&entry->list);
		kfree(entry);
		return true;
	}

	return false;
}

/* Drop the pending announce for @addr and, when one existed (or @force),
 * signal RM_ADDR for its id to the peer.
 */
static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk,
				      struct mptcp_addr_info *addr,
				      bool force)
{
	bool ret;

	ret = remove_anno_list_by_saddr(msk, addr);
	if (ret || force) {
		spin_lock_bh(&msk->pm.lock);
		mptcp_pm_remove_addr(msk, addr->id);
		spin_unlock_bh(&msk->pm.lock);
	}
	return ret;
}

/* Propagate endpoint removal to every msk in @net: drop the announce and
 * tear down any subflow using @addr as local address.
 */
static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net,
						   struct mptcp_addr_info *addr)
{
	struct mptcp_sock *msk;
	long s_slot = 0, s_num = 0;

	pr_debug("remove_id=%d", addr->id);

	while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
		struct sock *sk = (struct sock *)msk;
		bool remove_subflow;

		if (list_empty(&msk->conn_list)) {
			mptcp_pm_remove_anno_addr(msk, addr, false);
			goto next;
		}

		lock_sock(sk);
		remove_subflow = lookup_subflow_by_saddr(&msk->conn_list, addr);
		mptcp_pm_remove_anno_addr(msk, addr, remove_subflow);
		if (remove_subflow)
			mptcp_pm_remove_subflow(msk, addr->id);
		release_sock(sk);

next:
		sock_put(sk);
		cond_resched();
	}

	return 0;
}

/* Deferred release of an endpoint entry after an RCU grace period, so
 * concurrent RCU readers of local_addr_list never see a freed entry.
 */
struct addr_entry_release_work {
	struct rcu_work	rwork;
	struct mptcp_pm_addr_entry *entry;
};

static void mptcp_pm_release_addr_entry(struct work_struct *work)
{
	struct addr_entry_release_work *w;
	struct mptcp_pm_addr_entry *entry;

	w = container_of(to_rcu_work(work), struct addr_entry_release_work, rwork);
	entry = w->entry;
	if (entry) {
		if (entry->lsk)
			sock_release(entry->lsk);
		kfree(entry);
	}
	kfree(w);
}

/* Schedule @entry for release after the current RCU grace period.
 * NOTE(review): on kmalloc failure the entry is silently leaked —
 * presumably acceptable under memory pressure, but worth confirming.
 */
static void mptcp_pm_free_addr_entry(struct mptcp_pm_addr_entry *entry)
{
	struct addr_entry_release_work *w;

	w = kmalloc(sizeof(*w), GFP_ATOMIC);
	if (w) {
		INIT_RCU_WORK(&w->rwork, mptcp_pm_release_addr_entry);
		w->entry = entry;
		queue_rcu_work(system_wq, &w->rwork);
	}
}

/* MPTCP_PM_CMD_DEL_ADDR handler: unlink the endpoint with the given id,
 * fix up the limits, notify connections, and free it after RCU.
 */
static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
{
	struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
	struct mptcp_pm_addr_entry addr, *entry;
	unsigned int addr_max;
	int ret;

	ret = mptcp_pm_parse_addr(attr, info, false, &addr);
	if (ret < 0)
		return ret;

	spin_lock_bh(&pernet->lock);
	entry = __lookup_addr_by_id(pernet, addr.addr.id);
	if (!entry) {
		GENL_SET_ERR_MSG(info, "address not found");
		spin_unlock_bh(&pernet->lock);
		return -EINVAL;
	}
	if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
		addr_max = pernet->add_addr_signal_max;
		WRITE_ONCE(pernet->add_addr_signal_max, addr_max - 1);
	}
	if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) {
		addr_max = pernet->local_addr_max;
		WRITE_ONCE(pernet->local_addr_max, addr_max - 1);
	}

	pernet->addrs--;
	list_del_rcu(&entry->list);
	__clear_bit(entry->addr.id, pernet->id_bitmap);
	spin_unlock_bh(&pernet->lock);

	mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), &entry->addr);
	mptcp_pm_free_addr_entry(entry);

	return ret;
}

/* Release every entry on @list (already detached from the pernet list),
 * notifying connections for each address first.
 */
static void __flush_addrs(struct net *net, struct list_head *list)
{
	while (!list_empty(list)) {
		struct mptcp_pm_addr_entry *cur;

		cur = list_entry(list->next,
				 struct mptcp_pm_addr_entry, list);
		mptcp_nl_remove_subflow_and_signal_addr(net, &cur->addr);
		list_del_rcu(&cur->list);
		mptcp_pm_free_addr_entry(cur);
	}
}

/* Reset endpoint-derived limits; caller holds pernet->lock */
static void __reset_counters(struct pm_nl_pernet *pernet)
{
	WRITE_ONCE(pernet->add_addr_signal_max, 0);
	WRITE_ONCE(pernet->add_addr_accept_max, 0);
	WRITE_ONCE(pernet->local_addr_max, 0);
	pernet->addrs = 0;
}

/* MPTCP_PM_CMD_FLUSH_ADDRS handler: atomically detach the whole endpoint
 * list and reset counters, then tear everything down outside the lock.
 */
static int mptcp_nl_cmd_flush_addrs(struct sk_buff *skb, struct genl_info *info)
{
	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
	LIST_HEAD(free_list);

	spin_lock_bh(&pernet->lock);
	list_splice_init(&pernet->local_addr_list, &free_list);
	__reset_counters(pernet);
	pernet->next_id = 1;
	bitmap_zero(pernet->id_bitmap, MAX_ADDR_ID + 1);
	spin_unlock_bh(&pernet->lock);
	__flush_addrs(sock_net(skb->sk), &free_list);
	return 0;
}

/* Serialize @entry as a nested MPTCP_PM_ATTR_ADDR attribute into @skb */
static int mptcp_nl_fill_addr(struct sk_buff *skb,
			      struct mptcp_pm_addr_entry *entry)
{
	struct mptcp_addr_info *addr = &entry->addr;
	struct nlattr *attr;

	attr = nla_nest_start(skb, MPTCP_PM_ATTR_ADDR);
	if (!attr)
		return -EMSGSIZE;

	if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_FAMILY, addr->family))
		goto nla_put_failure;
	if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_PORT, ntohs(addr->port)))
		goto nla_put_failure;
	if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id))
		goto nla_put_failure;
	if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->addr.flags))
		goto nla_put_failure;
	if (entry->addr.ifindex &&
	    nla_put_s32(skb, MPTCP_PM_ADDR_ATTR_IF_IDX, entry->addr.ifindex))
		goto nla_put_failure;

	if (addr->family == AF_INET &&
	    nla_put_in_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR4,
			    addr->addr.s_addr))
		goto nla_put_failure;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	else if (addr->family == AF_INET6 &&
		 nla_put_in6_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR6, &addr->addr6))
		goto nla_put_failure;
#endif
	nla_nest_end(skb, attr);
	return 0;

nla_put_failure:
	nla_nest_cancel(skb, attr);
	return -EMSGSIZE;
}

static int mptcp_nl_cmd_get_addr(struct sk_buff *skb, struct genl_info *info)
{
	struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
	struct mptcp_pm_addr_entry addr, *entry;
	struct sk_buff *msg;
	void *reply;
	int ret;

	ret = mptcp_pm_parse_addr(attr, info, false, &addr);
	if (ret < 0)
		return ret;

	msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL);
	if (!msg)
		return -ENOMEM;

	reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0,
				  info->genlhdr->cmd);
if (!reply) { 1243 GENL_SET_ERR_MSG(info, "not enough space in Netlink message"); 1244 ret = -EMSGSIZE; 1245 goto fail; 1246 } 1247 1248 spin_lock_bh(&pernet->lock); 1249 entry = __lookup_addr_by_id(pernet, addr.addr.id); 1250 if (!entry) { 1251 GENL_SET_ERR_MSG(info, "address not found"); 1252 ret = -EINVAL; 1253 goto unlock_fail; 1254 } 1255 1256 ret = mptcp_nl_fill_addr(msg, entry); 1257 if (ret) 1258 goto unlock_fail; 1259 1260 genlmsg_end(msg, reply); 1261 ret = genlmsg_reply(msg, info); 1262 spin_unlock_bh(&pernet->lock); 1263 return ret; 1264 1265 unlock_fail: 1266 spin_unlock_bh(&pernet->lock); 1267 1268 fail: 1269 nlmsg_free(msg); 1270 return ret; 1271 } 1272 1273 static int mptcp_nl_cmd_dump_addrs(struct sk_buff *msg, 1274 struct netlink_callback *cb) 1275 { 1276 struct net *net = sock_net(msg->sk); 1277 struct mptcp_pm_addr_entry *entry; 1278 struct pm_nl_pernet *pernet; 1279 int id = cb->args[0]; 1280 void *hdr; 1281 int i; 1282 1283 pernet = net_generic(net, pm_nl_pernet_id); 1284 1285 spin_lock_bh(&pernet->lock); 1286 for (i = id; i < MAX_ADDR_ID + 1; i++) { 1287 if (test_bit(i, pernet->id_bitmap)) { 1288 entry = __lookup_addr_by_id(pernet, i); 1289 if (!entry) 1290 break; 1291 1292 if (entry->addr.id <= id) 1293 continue; 1294 1295 hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid, 1296 cb->nlh->nlmsg_seq, &mptcp_genl_family, 1297 NLM_F_MULTI, MPTCP_PM_CMD_GET_ADDR); 1298 if (!hdr) 1299 break; 1300 1301 if (mptcp_nl_fill_addr(msg, entry) < 0) { 1302 genlmsg_cancel(msg, hdr); 1303 break; 1304 } 1305 1306 id = entry->addr.id; 1307 genlmsg_end(msg, hdr); 1308 } 1309 } 1310 spin_unlock_bh(&pernet->lock); 1311 1312 cb->args[0] = id; 1313 return msg->len; 1314 } 1315 1316 static int parse_limit(struct genl_info *info, int id, unsigned int *limit) 1317 { 1318 struct nlattr *attr = info->attrs[id]; 1319 1320 if (!attr) 1321 return 0; 1322 1323 *limit = nla_get_u32(attr); 1324 if (*limit > MPTCP_PM_ADDR_MAX) { 1325 GENL_SET_ERR_MSG(info, "limit greater than 
maximum"); 1326 return -EINVAL; 1327 } 1328 return 0; 1329 } 1330 1331 static int 1332 mptcp_nl_cmd_set_limits(struct sk_buff *skb, struct genl_info *info) 1333 { 1334 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1335 unsigned int rcv_addrs, subflows; 1336 int ret; 1337 1338 spin_lock_bh(&pernet->lock); 1339 rcv_addrs = pernet->add_addr_accept_max; 1340 ret = parse_limit(info, MPTCP_PM_ATTR_RCV_ADD_ADDRS, &rcv_addrs); 1341 if (ret) 1342 goto unlock; 1343 1344 subflows = pernet->subflows_max; 1345 ret = parse_limit(info, MPTCP_PM_ATTR_SUBFLOWS, &subflows); 1346 if (ret) 1347 goto unlock; 1348 1349 WRITE_ONCE(pernet->add_addr_accept_max, rcv_addrs); 1350 WRITE_ONCE(pernet->subflows_max, subflows); 1351 1352 unlock: 1353 spin_unlock_bh(&pernet->lock); 1354 return ret; 1355 } 1356 1357 static int 1358 mptcp_nl_cmd_get_limits(struct sk_buff *skb, struct genl_info *info) 1359 { 1360 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1361 struct sk_buff *msg; 1362 void *reply; 1363 1364 msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); 1365 if (!msg) 1366 return -ENOMEM; 1367 1368 reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0, 1369 MPTCP_PM_CMD_GET_LIMITS); 1370 if (!reply) 1371 goto fail; 1372 1373 if (nla_put_u32(msg, MPTCP_PM_ATTR_RCV_ADD_ADDRS, 1374 READ_ONCE(pernet->add_addr_accept_max))) 1375 goto fail; 1376 1377 if (nla_put_u32(msg, MPTCP_PM_ATTR_SUBFLOWS, 1378 READ_ONCE(pernet->subflows_max))) 1379 goto fail; 1380 1381 genlmsg_end(msg, reply); 1382 return genlmsg_reply(msg, info); 1383 1384 fail: 1385 GENL_SET_ERR_MSG(info, "not enough space in Netlink message"); 1386 nlmsg_free(msg); 1387 return -EMSGSIZE; 1388 } 1389 1390 static int mptcp_nl_addr_backup(struct net *net, 1391 struct mptcp_addr_info *addr, 1392 u8 bkup) 1393 { 1394 long s_slot = 0, s_num = 0; 1395 struct mptcp_sock *msk; 1396 int ret = -EINVAL; 1397 1398 while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) { 1399 struct sock *sk = (struct sock *)msk; 1400 1401 
if (list_empty(&msk->conn_list)) 1402 goto next; 1403 1404 lock_sock(sk); 1405 spin_lock_bh(&msk->pm.lock); 1406 ret = mptcp_pm_nl_mp_prio_send_ack(msk, addr, bkup); 1407 spin_unlock_bh(&msk->pm.lock); 1408 release_sock(sk); 1409 1410 next: 1411 sock_put(sk); 1412 cond_resched(); 1413 } 1414 1415 return ret; 1416 } 1417 1418 static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info) 1419 { 1420 struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; 1421 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1422 struct mptcp_pm_addr_entry addr, *entry; 1423 struct net *net = sock_net(skb->sk); 1424 u8 bkup = 0; 1425 int ret; 1426 1427 ret = mptcp_pm_parse_addr(attr, info, true, &addr); 1428 if (ret < 0) 1429 return ret; 1430 1431 if (addr.addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP) 1432 bkup = 1; 1433 1434 list_for_each_entry(entry, &pernet->local_addr_list, list) { 1435 if (addresses_equal(&entry->addr, &addr.addr, true)) { 1436 ret = mptcp_nl_addr_backup(net, &entry->addr, bkup); 1437 if (ret) 1438 return ret; 1439 1440 if (bkup) 1441 entry->addr.flags |= MPTCP_PM_ADDR_FLAG_BACKUP; 1442 else 1443 entry->addr.flags &= ~MPTCP_PM_ADDR_FLAG_BACKUP; 1444 } 1445 } 1446 1447 return 0; 1448 } 1449 1450 static const struct genl_small_ops mptcp_pm_ops[] = { 1451 { 1452 .cmd = MPTCP_PM_CMD_ADD_ADDR, 1453 .doit = mptcp_nl_cmd_add_addr, 1454 .flags = GENL_ADMIN_PERM, 1455 }, 1456 { 1457 .cmd = MPTCP_PM_CMD_DEL_ADDR, 1458 .doit = mptcp_nl_cmd_del_addr, 1459 .flags = GENL_ADMIN_PERM, 1460 }, 1461 { 1462 .cmd = MPTCP_PM_CMD_FLUSH_ADDRS, 1463 .doit = mptcp_nl_cmd_flush_addrs, 1464 .flags = GENL_ADMIN_PERM, 1465 }, 1466 { 1467 .cmd = MPTCP_PM_CMD_GET_ADDR, 1468 .doit = mptcp_nl_cmd_get_addr, 1469 .dumpit = mptcp_nl_cmd_dump_addrs, 1470 }, 1471 { 1472 .cmd = MPTCP_PM_CMD_SET_LIMITS, 1473 .doit = mptcp_nl_cmd_set_limits, 1474 .flags = GENL_ADMIN_PERM, 1475 }, 1476 { 1477 .cmd = MPTCP_PM_CMD_GET_LIMITS, 1478 .doit = mptcp_nl_cmd_get_limits, 1479 }, 1480 { 1481 .cmd = 
MPTCP_PM_CMD_SET_FLAGS, 1482 .doit = mptcp_nl_cmd_set_flags, 1483 .flags = GENL_ADMIN_PERM, 1484 }, 1485 }; 1486 1487 static struct genl_family mptcp_genl_family __ro_after_init = { 1488 .name = MPTCP_PM_NAME, 1489 .version = MPTCP_PM_VER, 1490 .maxattr = MPTCP_PM_ATTR_MAX, 1491 .policy = mptcp_pm_policy, 1492 .netnsok = true, 1493 .module = THIS_MODULE, 1494 .small_ops = mptcp_pm_ops, 1495 .n_small_ops = ARRAY_SIZE(mptcp_pm_ops), 1496 .mcgrps = mptcp_pm_mcgrps, 1497 .n_mcgrps = ARRAY_SIZE(mptcp_pm_mcgrps), 1498 }; 1499 1500 static int __net_init pm_nl_init_net(struct net *net) 1501 { 1502 struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id); 1503 1504 INIT_LIST_HEAD_RCU(&pernet->local_addr_list); 1505 __reset_counters(pernet); 1506 pernet->next_id = 1; 1507 bitmap_zero(pernet->id_bitmap, MAX_ADDR_ID + 1); 1508 spin_lock_init(&pernet->lock); 1509 return 0; 1510 } 1511 1512 static void __net_exit pm_nl_exit_net(struct list_head *net_list) 1513 { 1514 struct net *net; 1515 1516 list_for_each_entry(net, net_list, exit_list) { 1517 struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id); 1518 1519 /* net is removed from namespace list, can't race with 1520 * other modifiers 1521 */ 1522 __flush_addrs(net, &pernet->local_addr_list); 1523 } 1524 } 1525 1526 static struct pernet_operations mptcp_pm_pernet_ops = { 1527 .init = pm_nl_init_net, 1528 .exit_batch = pm_nl_exit_net, 1529 .id = &pm_nl_pernet_id, 1530 .size = sizeof(struct pm_nl_pernet), 1531 }; 1532 1533 void __init mptcp_pm_nl_init(void) 1534 { 1535 if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0) 1536 panic("Failed to register MPTCP PM pernet subsystem.\n"); 1537 1538 if (genl_register_family(&mptcp_genl_family)) 1539 panic("Failed to register MPTCP PM netlink family\n"); 1540 } 1541