// SPDX-License-Identifier: GPL-2.0
/* Multipath TCP
 *
 * Copyright (c) 2020, Red Hat, Inc.
 */

#define pr_fmt(fmt) "MPTCP: " fmt

#include <linux/inet.h>
#include <linux/kernel.h>
#include <net/tcp.h>
#include <net/netns/generic.h>
#include <net/mptcp.h>
#include <net/genetlink.h>
#include <uapi/linux/mptcp.h>

#include "protocol.h"
#include "mib.h"

/* forward declaration */
static struct genl_family mptcp_genl_family;

static int pm_nl_pernet_id;

/* One netlink-configured local address; linked on the pernet
 * local_addr_list and reclaimed via RCU (see the rcu head below).
 */
struct mptcp_pm_addr_entry {
	struct list_head	list;
	struct mptcp_addr_info	addr;
	struct rcu_head		rcu;
	struct socket		*lsk;	/* in-kernel listener, only set when addr has a port */
};

/* Per-msk bookkeeping for one announced (ADD_ADDR) address and its
 * retransmission timer; linked on msk->pm.anno_list.
 */
struct mptcp_pm_add_entry {
	struct list_head	list;
	struct mptcp_addr_info	addr;
	struct timer_list	add_timer;
	struct mptcp_sock	*sock;
	u8			retrans_times;
};

#define MAX_ADDR_ID		255
#define BITMAP_SZ DIV_ROUND_UP(MAX_ADDR_ID + 1, BITS_PER_LONG)

/* Per-netns path-manager state: the configured local addresses plus the
 * cached per-category limits.
 */
struct pm_nl_pernet {
	/* protects pernet updates */
	spinlock_t		lock;
	struct list_head	local_addr_list;
	unsigned int		addrs;
	unsigned int		add_addr_signal_max;
	unsigned int		add_addr_accept_max;
	unsigned int		local_addr_max;
	unsigned int		subflows_max;
	unsigned int		next_id;
	unsigned long		id_bitmap[BITMAP_SZ];	/* tracks in-use address ids */
};

#define MPTCP_PM_ADDR_MAX	8
#define ADD_ADDR_RETRANS_MAX	3

/* Compare two addresses; a v4 address matches the equivalent v4-mapped
 * v6 address. Ports are compared only when @use_port is true.
 */
static bool addresses_equal(const struct mptcp_addr_info *a,
			    struct mptcp_addr_info *b, bool use_port)
{
	bool addr_equals = false;

	if (a->family == b->family) {
		if (a->family == AF_INET)
			addr_equals = a->addr.s_addr == b->addr.s_addr;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
		else
			addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6);
	} else if (a->family == AF_INET) {
		if (ipv6_addr_v4mapped(&b->addr6))
			addr_equals = a->addr.s_addr == b->addr6.s6_addr32[3];
	} else if (b->family == AF_INET) {
		if (ipv6_addr_v4mapped(&a->addr6))
			addr_equals = a->addr6.s6_addr32[3] == b->addr.s_addr;
#endif
	}

	if
(!addr_equals)
		return false;
	if (!use_port)
		return true;

	return a->port == b->port;
}

/* true when @addr is the all-zero address (and port 0) for its family */
static bool address_zero(const struct mptcp_addr_info *addr)
{
	struct mptcp_addr_info zero;

	memset(&zero, 0, sizeof(zero));
	zero.family = addr->family;

	return addresses_equal(addr, &zero, true);
}

/* Fill @addr with the local address/port of @skc */
static void local_address(const struct sock_common *skc,
			  struct mptcp_addr_info *addr)
{
	addr->family = skc->skc_family;
	addr->port = htons(skc->skc_num);	/* skc_num is host order */
	if (addr->family == AF_INET)
		addr->addr.s_addr = skc->skc_rcv_saddr;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	else if (addr->family == AF_INET6)
		addr->addr6 = skc->skc_v6_rcv_saddr;
#endif
}

/* Fill @addr with the remote address/port of @skc */
static void remote_address(const struct sock_common *skc,
			   struct mptcp_addr_info *addr)
{
	addr->family = skc->skc_family;
	addr->port = skc->skc_dport;	/* already network order */
	if (addr->family == AF_INET)
		addr->addr.s_addr = skc->skc_daddr;
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	else if (addr->family == AF_INET6)
		addr->addr6 = skc->skc_v6_daddr;
#endif
}

/* true if some subflow on @list already uses @saddr as local address;
 * the port is only significant when @saddr carries a non-zero one.
 */
static bool lookup_subflow_by_saddr(const struct list_head *list,
				    struct mptcp_addr_info *saddr)
{
	struct mptcp_subflow_context *subflow;
	struct mptcp_addr_info cur;
	struct sock_common *skc;

	list_for_each_entry(subflow, list, node) {
		skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);

		local_address(skc, &cur);
		if (addresses_equal(&cur, saddr, saddr->port))
			return true;
	}

	return false;
}

/* true if some subflow on @list is already connected to @daddr */
static bool lookup_subflow_by_daddr(const struct list_head *list,
				    struct mptcp_addr_info *daddr)
{
	struct mptcp_subflow_context *subflow;
	struct mptcp_addr_info cur;
	struct sock_common *skc;

	list_for_each_entry(subflow, list, node) {
		skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow);

		remote_address(skc, &cur);
		if (addresses_equal(&cur, daddr, daddr->port))
			return true;
	}

	return false;
}

/* Pick the first SUBFLOW-flagged local address that is family-compatible
 * with the msk and not already used by an existing subflow; NULL if none.
 */
static struct mptcp_pm_addr_entry *
select_local_address(const struct pm_nl_pernet *pernet,
		     struct mptcp_sock *msk)
{
	struct mptcp_pm_addr_entry *entry, *ret = NULL;
	struct sock *sk = (struct sock *)msk;

	msk_owned_by_me(msk);

	rcu_read_lock();
	__mptcp_flush_join_list(msk);
	list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
		if (!(entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW))
			continue;

		if (entry->addr.family != sk->sk_family) {
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
			/* a cross-family pair is usable only via v4-mapped v6 */
			if ((entry->addr.family == AF_INET &&
			     !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) ||
			    (sk->sk_family == AF_INET &&
			     !ipv6_addr_v4mapped(&entry->addr.addr6)))
#endif
				continue;
		}

		/* avoid any address already in use by subflows and
		 * pending join
		 */
		if (!lookup_subflow_by_saddr(&msk->conn_list, &entry->addr)) {
			ret = entry;
			break;
		}
	}
	rcu_read_unlock();
	return ret;
}

/* Return the @pos-th SIGNAL-flagged address on the pernet list, or NULL */
static struct mptcp_pm_addr_entry *
select_signal_address(struct pm_nl_pernet *pernet, unsigned int pos)
{
	struct mptcp_pm_addr_entry *entry, *ret = NULL;
	int i = 0;

	rcu_read_lock();
	/* do not keep any additional per socket state, just signal
	 * the address list in order.
	 * Note: removal from the local address list during the msk life-cycle
	 * can lead to additional addresses not being announced.
	 */
	list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
		if (!(entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL))
			continue;
		if (i++ == pos) {
			ret = entry;
			break;
		}
	}
	rcu_read_unlock();
	return ret;
}

/* Max number of ADD_ADDR announcements this netns will send per msk */
unsigned int mptcp_pm_get_add_addr_signal_max(struct mptcp_sock *msk)
{
	struct pm_nl_pernet *pernet;

	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
	return READ_ONCE(pernet->add_addr_signal_max);
}
EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_signal_max);

/* Max number of peer ADD_ADDR announcements this netns will accept per msk */
unsigned int mptcp_pm_get_add_addr_accept_max(struct mptcp_sock *msk)
{
	struct pm_nl_pernet *pernet;

	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
	return READ_ONCE(pernet->add_addr_accept_max);
}
EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_accept_max);

/* Max number of additional subflows per msk in this netns */
unsigned int mptcp_pm_get_subflows_max(struct mptcp_sock *msk)
{
	struct pm_nl_pernet *pernet;

	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
	return READ_ONCE(pernet->subflows_max);
}
EXPORT_SYMBOL_GPL(mptcp_pm_get_subflows_max);

/* Max number of local addresses usable for new subflows per msk */
unsigned int mptcp_pm_get_local_addr_max(struct mptcp_sock *msk)
{
	struct pm_nl_pernet *pernet;

	pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id);
	return READ_ONCE(pernet->local_addr_max);
}
EXPORT_SYMBOL_GPL(mptcp_pm_get_local_addr_max);

/* Clear work_pending once the announce budget is spent and either the
 * local address or the subflow budget is spent too.
 */
static void check_work_pending(struct mptcp_sock *msk)
{
	if (msk->pm.add_addr_signaled == mptcp_pm_get_add_addr_signal_max(msk) &&
	    (msk->pm.local_addr_used == mptcp_pm_get_local_addr_max(msk) ||
	     msk->pm.subflows == mptcp_pm_get_subflows_max(msk)))
		WRITE_ONCE(msk->pm.work_pending, false);
}

/* Find the anno_list entry matching @addr (address AND port); requires
 * the pm lock.
 */
struct mptcp_pm_add_entry *
mptcp_lookup_anno_list_by_saddr(struct mptcp_sock *msk,
				struct mptcp_addr_info *addr)
{
	struct mptcp_pm_add_entry *entry;

	lockdep_assert_held(&msk->pm.lock);

	list_for_each_entry(entry, &msk->pm.anno_list, list) {
		if (addresses_equal(&entry->addr, addr, true))
			return entry;
	}

	return NULL;
}

/* true if @sk's local address+port is currently being announced by @msk */
bool mptcp_pm_sport_in_anno_list(struct mptcp_sock *msk, const struct sock *sk)
{
	struct mptcp_pm_add_entry *entry;
	struct mptcp_addr_info saddr;
	bool ret = false;

	local_address((struct sock_common *)sk, &saddr);

	spin_lock_bh(&msk->pm.lock);
	list_for_each_entry(entry, &msk->pm.anno_list, list) {
		if (addresses_equal(&entry->addr, &saddr, true)) {
			ret = true;
			goto out;
		}
	}

out:
	spin_unlock_bh(&msk->pm.lock);
	return ret;
}

/* ADD_ADDR retransmission timer: re-announce the address up to
 * ADD_ADDR_RETRANS_MAX times until the peer echoes it.
 */
static void mptcp_pm_add_timer(struct timer_list *timer)
{
	struct mptcp_pm_add_entry *entry = from_timer(entry, timer, add_timer);
	struct mptcp_sock *msk = entry->sock;
	struct sock *sk = (struct sock *)msk;

	pr_debug("msk=%p", msk);

	if (!msk)
		return;

	/* NOTE(review): the early returns below skip the __sock_put() done
	 * at 'out'; confirm sk_reset_timer's reference is not leaked on
	 * these paths.
	 */
	if (inet_sk_state_load(sk) == TCP_CLOSE)
		return;

	if (!entry->addr.id)
		return;

	if (mptcp_pm_should_add_signal(msk)) {
		/* signal still pending in the option space: retry shortly */
		sk_reset_timer(sk, timer, jiffies + TCP_RTO_MAX / 8);
		goto out;
	}

	spin_lock_bh(&msk->pm.lock);

	if (!mptcp_pm_should_add_signal(msk)) {
		pr_debug("retransmit ADD_ADDR id=%d", entry->addr.id);
		mptcp_pm_announce_addr(msk, &entry->addr, false);
		mptcp_pm_add_addr_send_ack(msk);
		entry->retrans_times++;
	}

	if (entry->retrans_times < ADD_ADDR_RETRANS_MAX)
		sk_reset_timer(sk, timer,
			       jiffies + mptcp_get_add_addr_timeout(sock_net(sk)));

	spin_unlock_bh(&msk->pm.lock);

	if (entry->retrans_times == ADD_ADDR_RETRANS_MAX)
		mptcp_pm_subflow_established(msk);

out:
	__sock_put(sk);
}

/* Stop the ADD_ADDR timer for @addr, if any; returns the (still
 * allocated) entry so the caller can unlink/free it.
 */
struct mptcp_pm_add_entry *
mptcp_pm_del_add_timer(struct mptcp_sock *msk,
		       struct mptcp_addr_info *addr)
{
	struct mptcp_pm_add_entry *entry;
	struct sock *sk = (struct sock *)msk;

	spin_lock_bh(&msk->pm.lock);
	entry = mptcp_lookup_anno_list_by_saddr(msk, addr);
	if (entry)
		/* saturate the counter so a concurrent timer run won't re-arm */
		entry->retrans_times = ADD_ADDR_RETRANS_MAX;
	spin_unlock_bh(&msk->pm.lock);

	if (entry)
		sk_stop_timer_sync(sk, &entry->add_timer);

	return entry;
}

/* Queue @entry for announcement on @msk and arm its retransmission
 * timer; returns false if the address is already queued or on OOM.
 */
static bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk,
				     struct mptcp_pm_addr_entry *entry)
{
	struct mptcp_pm_add_entry *add_entry = NULL;
	struct sock *sk = (struct sock *)msk;
	struct net *net = sock_net(sk);

	lockdep_assert_held(&msk->pm.lock);

	if (mptcp_lookup_anno_list_by_saddr(msk, &entry->addr))
		return false;

	/* GFP_ATOMIC: called under the pm spinlock */
	add_entry = kmalloc(sizeof(*add_entry), GFP_ATOMIC);
	if (!add_entry)
		return false;

	list_add(&add_entry->list, &msk->pm.anno_list);

	add_entry->addr = entry->addr;
	add_entry->sock = msk;
	add_entry->retrans_times = 0;

	timer_setup(&add_entry->add_timer, mptcp_pm_add_timer, 0);
	sk_reset_timer(sk, &add_entry->add_timer,
		       jiffies + mptcp_get_add_addr_timeout(net));

	return true;
}

/* Tear down all pending announcements for @msk, stopping their timers */
void mptcp_pm_free_anno_list(struct mptcp_sock *msk)
{
	struct mptcp_pm_add_entry *entry, *tmp;
	struct sock *sk = (struct sock *)msk;
	LIST_HEAD(free_list);

	pr_debug("msk=%p", msk);

	/* detach the whole list under the lock, free it outside */
	spin_lock_bh(&msk->pm.lock);
	list_splice_init(&msk->pm.anno_list, &free_list);
	spin_unlock_bh(&msk->pm.lock);

	list_for_each_entry_safe(entry, tmp, &free_list, list) {
		sk_stop_timer_sync(sk, &entry->add_timer);
		kfree(entry);
	}
}

/* Announce additional addresses and/or create new subflows, within the
 * configured pernet limits; called with the pm lock held.
 */
static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk)
{
	struct sock *sk = (struct sock *)msk;
	struct mptcp_pm_addr_entry *local;
	unsigned int add_addr_signal_max;
	unsigned int local_addr_max;
	struct pm_nl_pernet *pernet;
	unsigned int subflows_max;

	pernet = net_generic(sock_net(sk), pm_nl_pernet_id);

	add_addr_signal_max =
mptcp_pm_get_add_addr_signal_max(msk); 423 local_addr_max = mptcp_pm_get_local_addr_max(msk); 424 subflows_max = mptcp_pm_get_subflows_max(msk); 425 426 pr_debug("local %d:%d signal %d:%d subflows %d:%d\n", 427 msk->pm.local_addr_used, local_addr_max, 428 msk->pm.add_addr_signaled, add_addr_signal_max, 429 msk->pm.subflows, subflows_max); 430 431 /* check first for announce */ 432 if (msk->pm.add_addr_signaled < add_addr_signal_max) { 433 local = select_signal_address(pernet, 434 msk->pm.add_addr_signaled); 435 436 if (local) { 437 if (mptcp_pm_alloc_anno_list(msk, local)) { 438 msk->pm.add_addr_signaled++; 439 mptcp_pm_announce_addr(msk, &local->addr, false); 440 mptcp_pm_nl_addr_send_ack(msk); 441 } 442 } else { 443 /* pick failed, avoid fourther attempts later */ 444 msk->pm.local_addr_used = add_addr_signal_max; 445 } 446 447 check_work_pending(msk); 448 } 449 450 /* check if should create a new subflow */ 451 if (msk->pm.local_addr_used < local_addr_max && 452 msk->pm.subflows < subflows_max) { 453 local = select_local_address(pernet, msk); 454 if (local) { 455 struct mptcp_addr_info remote = { 0 }; 456 457 msk->pm.local_addr_used++; 458 msk->pm.subflows++; 459 check_work_pending(msk); 460 remote_address((struct sock_common *)sk, &remote); 461 spin_unlock_bh(&msk->pm.lock); 462 __mptcp_subflow_connect(sk, &local->addr, &remote); 463 spin_lock_bh(&msk->pm.lock); 464 return; 465 } 466 467 /* lookup failed, avoid fourther attempts later */ 468 msk->pm.local_addr_used = local_addr_max; 469 check_work_pending(msk); 470 } 471 } 472 473 static void mptcp_pm_nl_fully_established(struct mptcp_sock *msk) 474 { 475 mptcp_pm_create_subflow_or_signal_addr(msk); 476 } 477 478 static void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk) 479 { 480 mptcp_pm_create_subflow_or_signal_addr(msk); 481 } 482 483 static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk) 484 { 485 struct sock *sk = (struct sock *)msk; 486 unsigned int add_addr_accept_max; 487 struct 
mptcp_addr_info remote;
	struct mptcp_addr_info local;
	unsigned int subflows_max;

	add_addr_accept_max = mptcp_pm_get_add_addr_accept_max(msk);
	subflows_max = mptcp_pm_get_subflows_max(msk);

	pr_debug("accepted %d:%d remote family %d",
		 msk->pm.add_addr_accepted, add_addr_accept_max,
		 msk->pm.remote.family);

	/* already connected to that address: just echo */
	if (lookup_subflow_by_daddr(&msk->conn_list, &msk->pm.remote))
		goto add_addr_echo;

	msk->pm.add_addr_accepted++;
	msk->pm.subflows++;
	if (msk->pm.add_addr_accepted >= add_addr_accept_max ||
	    msk->pm.subflows >= subflows_max)
		WRITE_ONCE(msk->pm.accept_addr, false);

	/* connect to the specified remote address, using whatever
	 * local address the routing configuration will pick.
	 */
	remote = msk->pm.remote;
	if (!remote.port)
		remote.port = sk->sk_dport;
	memset(&local, 0, sizeof(local));
	local.family = remote.family;

	spin_unlock_bh(&msk->pm.lock);
	__mptcp_subflow_connect(sk, &local, &remote);
	spin_lock_bh(&msk->pm.lock);

add_addr_echo:
	mptcp_pm_announce_addr(msk, &msk->pm.remote, true);
	mptcp_pm_nl_addr_send_ack(msk);
}

/* Push a pending ADD_ADDR/RM_ADDR signal out via a bare TCP ack on the
 * first subflow; called with the pm lock held, drops it around the ack.
 */
void mptcp_pm_nl_addr_send_ack(struct mptcp_sock *msk)
{
	struct mptcp_subflow_context *subflow;

	msk_owned_by_me(msk);
	lockdep_assert_held(&msk->pm.lock);

	if (!mptcp_pm_should_add_signal(msk) &&
	    !mptcp_pm_should_rm_signal(msk))
		return;

	__mptcp_flush_join_list(msk);
	subflow = list_first_entry_or_null(&msk->conn_list, typeof(*subflow), node);
	if (subflow) {
		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);

		spin_unlock_bh(&msk->pm.lock);
		pr_debug("send ack for %s%s%s",
			 mptcp_pm_should_add_signal(msk) ? "add_addr" : "rm_addr",
			 mptcp_pm_should_add_signal_ipv6(msk) ? " [ipv6]" : "",
			 mptcp_pm_should_add_signal_port(msk) ? " [port]" : "");

		lock_sock(ssk);
		tcp_send_ack(ssk);
		release_sock(ssk);
		spin_lock_bh(&msk->pm.lock);
	}
}

/* Flag an MP_PRIO change on the subflow bound to @addr and push it via
 * an ack; returns 0 on success, -EINVAL when no subflow matches.
 * Called with the pm lock held, drops it around the ack.
 */
int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk,
				 struct mptcp_addr_info *addr,
				 u8 bkup)
{
	struct mptcp_subflow_context *subflow;

	pr_debug("bkup=%d", bkup);

	mptcp_for_each_subflow(msk, subflow) {
		struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
		struct sock *sk = (struct sock *)msk;
		struct mptcp_addr_info local;

		local_address((struct sock_common *)ssk, &local);
		if (!addresses_equal(&local, addr, addr->port))
			continue;

		subflow->backup = bkup;
		subflow->send_mp_prio = 1;
		subflow->request_bkup = bkup;
		__MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPPRIOTX);

		spin_unlock_bh(&msk->pm.lock);
		pr_debug("send ack for mp_prio");
		lock_sock(ssk);
		tcp_send_ack(ssk);
		release_sock(ssk);
		spin_lock_bh(&msk->pm.lock);

		return 0;
	}

	return -EINVAL;
}

/* Handle a peer RM_ADDR: close every subflow whose remote id is listed
 * in rm_list_rx; called with the pm lock held, drops it around close.
 */
static void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk)
{
	struct mptcp_subflow_context *subflow, *tmp;
	struct sock *sk = (struct sock *)msk;
	u8 i;

	pr_debug("address rm_list_nr %d", msk->pm.rm_list_rx.nr);

	msk_owned_by_me(msk);

	if (!msk->pm.rm_list_rx.nr)
		return;

	if (list_empty(&msk->conn_list))
		return;

	for (i = 0; i < msk->pm.rm_list_rx.nr; i++) {
		list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
			struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
			int how = RCV_SHUTDOWN | SEND_SHUTDOWN;

			if (msk->pm.rm_list_rx.ids[i] != subflow->remote_id)
				continue;

			pr_debug(" -> address rm_list_ids[%d]=%u", i, msk->pm.rm_list_rx.ids[i]);
			spin_unlock_bh(&msk->pm.lock);
			mptcp_subflow_shutdown(sk, ssk, how);
			mptcp_close_ssk(sk, ssk, subflow);
			spin_lock_bh(&msk->pm.lock);

			/* NOTE(review): decremented for any matching subflow,
			 * even ones not created from an accepted ADD_ADDR —
			 * confirm accounting stays consistent.
			 */
			msk->pm.add_addr_accepted--;
			msk->pm.subflows--;
			WRITE_ONCE(msk->pm.accept_addr, true);

			__MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMADDR);

			break;
		}
	}
}

/* Dispatch all pending pm events for @msk; runs from the msk worker
 * with the msk socket lock held.
 */
void mptcp_pm_nl_work(struct mptcp_sock *msk)
{
	struct mptcp_pm_data *pm = &msk->pm;

	msk_owned_by_me(msk);

	spin_lock_bh(&msk->pm.lock);

	pr_debug("msk=%p status=%x", msk, pm->status);
	if (pm->status & BIT(MPTCP_PM_ADD_ADDR_RECEIVED)) {
		pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_RECEIVED);
		mptcp_pm_nl_add_addr_received(msk);
	}
	if (pm->status & BIT(MPTCP_PM_ADD_ADDR_SEND_ACK)) {
		pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_SEND_ACK);
		mptcp_pm_nl_addr_send_ack(msk);
	}
	if (pm->status & BIT(MPTCP_PM_RM_ADDR_RECEIVED)) {
		pm->status &= ~BIT(MPTCP_PM_RM_ADDR_RECEIVED);
		mptcp_pm_nl_rm_addr_received(msk);
	}
	if (pm->status & BIT(MPTCP_PM_ESTABLISHED)) {
		pm->status &= ~BIT(MPTCP_PM_ESTABLISHED);
		mptcp_pm_nl_fully_established(msk);
	}
	if (pm->status & BIT(MPTCP_PM_SUBFLOW_ESTABLISHED)) {
		pm->status &= ~BIT(MPTCP_PM_SUBFLOW_ESTABLISHED);
		mptcp_pm_nl_subflow_established(msk);
	}

	spin_unlock_bh(&msk->pm.lock);
}

/* Close every subflow whose local id is listed in @rm_list; called with
 * the pm lock held, drops it around the close.
 */
void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk,
				     const struct mptcp_rm_list *rm_list)
{
	struct mptcp_subflow_context *subflow, *tmp;
	struct sock *sk = (struct sock *)msk;
	u8 i;

	pr_debug("subflow rm_list_nr %d", rm_list->nr);

	msk_owned_by_me(msk);

	if (!rm_list->nr)
		return;

	if (list_empty(&msk->conn_list))
		return;

	for (i = 0; i < rm_list->nr; i++) {
		list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) {
			struct sock *ssk = mptcp_subflow_tcp_sock(subflow);
			int how = RCV_SHUTDOWN | SEND_SHUTDOWN;

			if (rm_list->ids[i] != subflow->local_id)
				continue;

			pr_debug(" -> subflow rm_list_ids[%d]=%u", i, rm_list->ids[i]);
			spin_unlock_bh(&msk->pm.lock);
			mptcp_subflow_shutdown(sk, ssk, how);
			mptcp_close_ssk(sk,
ssk, subflow);
			spin_lock_bh(&msk->pm.lock);

			msk->pm.local_addr_used--;
			msk->pm.subflows--;

			__MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMSUBFLOW);

			break;
		}
	}
}

/* true when @entry is signal-only: such addresses are distinguished by
 * port in the duplicate check below.
 */
static bool address_use_port(struct mptcp_pm_addr_entry *entry)
{
	return (entry->addr.flags &
		(MPTCP_PM_ADDR_FLAG_SIGNAL | MPTCP_PM_ADDR_FLAG_SUBFLOW)) ==
		MPTCP_PM_ADDR_FLAG_SIGNAL;
}

/* Insert @entry on the pernet list, allocating an id when needed and
 * bumping the cached per-flag limits; returns the id or -EINVAL on
 * duplicates/limit overflow.
 */
static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet,
					     struct mptcp_pm_addr_entry *entry)
{
	struct mptcp_pm_addr_entry *cur;
	unsigned int addr_max;
	int ret = -EINVAL;

	spin_lock_bh(&pernet->lock);
	/* to keep the code simple, don't do IDR-like allocation for address ID,
	 * just bail when we exceed limits
	 */
	if (pernet->next_id == MAX_ADDR_ID)
		pernet->next_id = 1;
	if (pernet->addrs >= MPTCP_PM_ADDR_MAX)
		goto out;
	if (test_bit(entry->addr.id, pernet->id_bitmap))
		goto out;

	/* do not insert duplicate address, differentiate on port only
	 * singled addresses
	 */
	list_for_each_entry(cur, &pernet->local_addr_list, list) {
		if (addresses_equal(&cur->addr, &entry->addr,
				    address_use_port(entry) &&
				    address_use_port(cur)))
			goto out;
	}

	if (!entry->addr.id) {
find_next:
		/* pick the first free id at or after next_id, wrapping once */
		entry->addr.id = find_next_zero_bit(pernet->id_bitmap,
						    MAX_ADDR_ID + 1,
						    pernet->next_id);
		if ((!entry->addr.id || entry->addr.id > MAX_ADDR_ID) &&
		    pernet->next_id != 1) {
			pernet->next_id = 1;
			goto find_next;
		}
	}

	if (!entry->addr.id || entry->addr.id > MAX_ADDR_ID)
		goto out;

	__set_bit(entry->addr.id, pernet->id_bitmap);
	if (entry->addr.id > pernet->next_id)
		pernet->next_id = entry->addr.id;

	if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
		addr_max = pernet->add_addr_signal_max;
		WRITE_ONCE(pernet->add_addr_signal_max, addr_max + 1);
	}
	if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW)
{ 763 addr_max = pernet->local_addr_max; 764 WRITE_ONCE(pernet->local_addr_max, addr_max + 1); 765 } 766 767 pernet->addrs++; 768 list_add_tail_rcu(&entry->list, &pernet->local_addr_list); 769 ret = entry->addr.id; 770 771 out: 772 spin_unlock_bh(&pernet->lock); 773 return ret; 774 } 775 776 static int mptcp_pm_nl_create_listen_socket(struct sock *sk, 777 struct mptcp_pm_addr_entry *entry) 778 { 779 struct sockaddr_storage addr; 780 struct mptcp_sock *msk; 781 struct socket *ssock; 782 int backlog = 1024; 783 int err; 784 785 err = sock_create_kern(sock_net(sk), entry->addr.family, 786 SOCK_STREAM, IPPROTO_MPTCP, &entry->lsk); 787 if (err) 788 return err; 789 790 msk = mptcp_sk(entry->lsk->sk); 791 if (!msk) { 792 err = -EINVAL; 793 goto out; 794 } 795 796 ssock = __mptcp_nmpc_socket(msk); 797 if (!ssock) { 798 err = -EINVAL; 799 goto out; 800 } 801 802 mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family); 803 err = kernel_bind(ssock, (struct sockaddr *)&addr, 804 sizeof(struct sockaddr_in)); 805 if (err) { 806 pr_warn("kernel_bind error, err=%d", err); 807 goto out; 808 } 809 810 err = kernel_listen(ssock, backlog); 811 if (err) { 812 pr_warn("kernel_listen error, err=%d", err); 813 goto out; 814 } 815 816 return 0; 817 818 out: 819 sock_release(entry->lsk); 820 return err; 821 } 822 823 int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc) 824 { 825 struct mptcp_pm_addr_entry *entry; 826 struct mptcp_addr_info skc_local; 827 struct mptcp_addr_info msk_local; 828 struct pm_nl_pernet *pernet; 829 int ret = -1; 830 831 if (WARN_ON_ONCE(!msk)) 832 return -1; 833 834 /* The 0 ID mapping is defined by the first subflow, copied into the msk 835 * addr 836 */ 837 local_address((struct sock_common *)msk, &msk_local); 838 local_address((struct sock_common *)skc, &skc_local); 839 if (addresses_equal(&msk_local, &skc_local, false)) 840 return 0; 841 842 if (address_zero(&skc_local)) 843 return 0; 844 845 pernet = net_generic(sock_net((struct 
sock *)msk), pm_nl_pernet_id);

	rcu_read_lock();
	list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) {
		if (addresses_equal(&entry->addr, &skc_local, entry->addr.port)) {
			ret = entry->addr.id;
			break;
		}
	}
	rcu_read_unlock();
	if (ret >= 0)
		return ret;

	/* address not found, add to local list */
	entry = kmalloc(sizeof(*entry), GFP_ATOMIC);
	if (!entry)
		return -ENOMEM;

	entry->addr = skc_local;
	entry->addr.ifindex = 0;
	entry->addr.flags = 0;
	entry->addr.id = 0;	/* let the append helper allocate an id */
	entry->addr.port = 0;
	entry->lsk = NULL;
	ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
	if (ret < 0)
		kfree(entry);

	return ret;
}

/* Initialize the per-msk pm flags from the current pernet limits */
void mptcp_pm_nl_data_init(struct mptcp_sock *msk)
{
	struct mptcp_pm_data *pm = &msk->pm;
	bool subflows;

	subflows = !!mptcp_pm_get_subflows_max(msk);
	WRITE_ONCE(pm->work_pending, (!!mptcp_pm_get_local_addr_max(msk) && subflows) ||
		   !!mptcp_pm_get_add_addr_signal_max(msk));
	WRITE_ONCE(pm->accept_addr, !!mptcp_pm_get_add_addr_accept_max(msk) && subflows);
	WRITE_ONCE(pm->accept_subflow, subflows);
}

#define MPTCP_PM_CMD_GRP_OFFSET	0
#define MPTCP_PM_EV_GRP_OFFSET	1

static const struct genl_multicast_group mptcp_pm_mcgrps[] = {
	[MPTCP_PM_CMD_GRP_OFFSET]	= { .name = MPTCP_PM_CMD_GRP_NAME, },
	[MPTCP_PM_EV_GRP_OFFSET]	= { .name = MPTCP_PM_EV_GRP_NAME,
					    .flags = GENL_UNS_ADMIN_PERM,
					  },
};

/* netlink policy for the nested address attribute */
static const struct nla_policy
mptcp_pm_addr_policy[MPTCP_PM_ADDR_ATTR_MAX + 1] = {
	[MPTCP_PM_ADDR_ATTR_FAMILY]	= { .type	= NLA_U16,	},
	[MPTCP_PM_ADDR_ATTR_ID]		= { .type	= NLA_U8,	},
	[MPTCP_PM_ADDR_ATTR_ADDR4]	= { .type	= NLA_U32,	},
	[MPTCP_PM_ADDR_ATTR_ADDR6]	=
		NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)),
	[MPTCP_PM_ADDR_ATTR_PORT]	= { .type	= NLA_U16	},
	[MPTCP_PM_ADDR_ATTR_FLAGS]	= { .type	= NLA_U32	},
	[MPTCP_PM_ADDR_ATTR_IF_IDX]	= { .type	= NLA_S32
	},
};

/* top-level netlink policy for pm commands */
static const struct nla_policy mptcp_pm_policy[MPTCP_PM_ATTR_MAX + 1] = {
	[MPTCP_PM_ATTR_ADDR]		=
					NLA_POLICY_NESTED(mptcp_pm_addr_policy),
	[MPTCP_PM_ATTR_RCV_ADD_ADDRS]	= { .type	= NLA_U32,	},
	[MPTCP_PM_ATTR_SUBFLOWS]	= { .type	= NLA_U32,	},
};

/* Map an address family to the matching netlink address attribute */
static int mptcp_pm_family_to_addr(int family)
{
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	if (family == AF_INET6)
		return MPTCP_PM_ADDR_ATTR_ADDR6;
#endif
	return MPTCP_PM_ADDR_ATTR_ADDR4;
}

/* Parse the nested address attribute @attr into @entry; when
 * @require_family is false a family-less (id/flags only) attribute is
 * accepted. Returns 0 or a negative error with extack set.
 */
static int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info,
			       bool require_family,
			       struct mptcp_pm_addr_entry *entry)
{
	struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1];
	int err, addr_addr;

	if (!attr) {
		GENL_SET_ERR_MSG(info, "missing address info");
		return -EINVAL;
	}

	/* no validation needed - was already done via nested policy */
	err = nla_parse_nested_deprecated(tb, MPTCP_PM_ADDR_ATTR_MAX, attr,
					  mptcp_pm_addr_policy, info->extack);
	if (err)
		return err;

	memset(entry, 0, sizeof(*entry));
	if (!tb[MPTCP_PM_ADDR_ATTR_FAMILY]) {
		if (!require_family)
			goto skip_family;

		NL_SET_ERR_MSG_ATTR(info->extack, attr,
				    "missing family");
		return -EINVAL;
	}

	entry->addr.family = nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_FAMILY]);
	if (entry->addr.family != AF_INET
#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	    && entry->addr.family != AF_INET6
#endif
	    ) {
		NL_SET_ERR_MSG_ATTR(info->extack, attr,
				    "unknown address family");
		return -EINVAL;
	}
	addr_addr = mptcp_pm_family_to_addr(entry->addr.family);
	if (!tb[addr_addr]) {
		NL_SET_ERR_MSG_ATTR(info->extack, attr,
				    "missing address data");
		return -EINVAL;
	}

#if IS_ENABLED(CONFIG_MPTCP_IPV6)
	if (entry->addr.family == AF_INET6)
		entry->addr.addr6 = nla_get_in6_addr(tb[addr_addr]);
	else
#endif
		entry->addr.addr.s_addr = nla_get_in_addr(tb[addr_addr]);

skip_family:
	if (tb[MPTCP_PM_ADDR_ATTR_IF_IDX]) {
		u32 val = nla_get_s32(tb[MPTCP_PM_ADDR_ATTR_IF_IDX]);

		entry->addr.ifindex = val;
	}

	if (tb[MPTCP_PM_ADDR_ATTR_ID])
		entry->addr.id = nla_get_u8(tb[MPTCP_PM_ADDR_ATTR_ID]);

	if (tb[MPTCP_PM_ADDR_ATTR_FLAGS])
		entry->addr.flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]);

	if (tb[MPTCP_PM_ADDR_ATTR_PORT])
		entry->addr.port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT]));

	return 0;
}

/* Resolve the pernet pm state for a genetlink request */
static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info)
{
	return net_generic(genl_info_net(info), pm_nl_pernet_id);
}

/* Re-run the subflow/announce logic on every established msk in @net,
 * e.g. after a new local address was configured.
 */
static int mptcp_nl_add_subflow_or_signal_addr(struct net *net)
{
	struct mptcp_sock *msk;
	long s_slot = 0, s_num = 0;

	while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
		struct sock *sk = (struct sock *)msk;

		if (!READ_ONCE(msk->fully_established))
			goto next;

		lock_sock(sk);
		spin_lock_bh(&msk->pm.lock);
		mptcp_pm_create_subflow_or_signal_addr(msk);
		spin_unlock_bh(&msk->pm.lock);
		release_sock(sk);

next:
		sock_put(sk);
		cond_resched();
	}

	return 0;
}

/* Netlink MPTCP_PM_CMD_ADD_ADDR handler: add a new local address and
 * propagate it to existing connections.
 */
static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info)
{
	struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
	struct mptcp_pm_addr_entry addr, *entry;
	int ret;

	ret = mptcp_pm_parse_addr(attr, info, true, &addr);
	if (ret < 0)
		return ret;

	entry = kmalloc(sizeof(*entry), GFP_KERNEL);
	if (!entry) {
		GENL_SET_ERR_MSG(info, "can't allocate addr");
		return -ENOMEM;
	}

	*entry = addr;
	if (entry->addr.port) {
		/* an explicit port needs a dedicated in-kernel listener */
		ret = mptcp_pm_nl_create_listen_socket(skb->sk, entry);
		if (ret) {
			GENL_SET_ERR_MSG(info, "create listen socket error");
			kfree(entry);
			return ret;
		}
	}
	ret = mptcp_pm_nl_append_new_local_addr(pernet, entry);
	if (ret < 0) {
		GENL_SET_ERR_MSG(info, "too many addresses or duplicate one");
		if (entry->lsk)
			sock_release(entry->lsk);
		kfree(entry);
		return ret;
	}

	mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk));

	return 0;
}

/* Find the pernet address entry with the given id; caller holds the
 * pernet lock (or RCU for readers).
 */
static struct mptcp_pm_addr_entry *
__lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id)
{
	struct mptcp_pm_addr_entry *entry;

	list_for_each_entry(entry, &pernet->local_addr_list, list) {
		if (entry->addr.id == id)
			return entry;
	}
	return NULL;
}

/* Stop and free the pending announcement for @addr, if any; returns
 * true when an entry was removed.
 */
static bool remove_anno_list_by_saddr(struct mptcp_sock *msk,
				      struct mptcp_addr_info *addr)
{
	struct mptcp_pm_add_entry *entry;

	entry = mptcp_pm_del_add_timer(msk, addr);
	if (entry) {
		list_del(&entry->list);
		kfree(entry);
		return true;
	}

	return false;
}

/* Drop the pending announcement for @addr and, when one existed or
 * @force is set, signal RM_ADDR for its id to the peer.
 */
static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk,
				      struct mptcp_addr_info *addr,
				      bool force)
{
	struct mptcp_rm_list list = { .nr = 0 };
	bool ret;

	list.ids[list.nr++] = addr->id;

	ret = remove_anno_list_by_saddr(msk, addr);
	if (ret || force) {
		spin_lock_bh(&msk->pm.lock);
		mptcp_pm_remove_addr(msk, &list);
		spin_unlock_bh(&msk->pm.lock);
	}
	return ret;
}

/* Walk every msk in @net and remove announcements/subflows bound to the
 * deleted address @addr.
 */
static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net,
						   struct mptcp_addr_info *addr)
{
	struct mptcp_sock *msk;
	long s_slot = 0, s_num = 0;
	struct mptcp_rm_list list = { .nr = 0 };

	pr_debug("remove_id=%d", addr->id);

	list.ids[list.nr++] = addr->id;

	while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) {
		struct sock *sk = (struct sock *)msk;
		bool remove_subflow;

		if (list_empty(&msk->conn_list)) {
			mptcp_pm_remove_anno_addr(msk, addr,
false);
			goto next;
		}

		lock_sock(sk);
		remove_subflow = lookup_subflow_by_saddr(&msk->conn_list, addr);
		mptcp_pm_remove_anno_addr(msk, addr, remove_subflow);
		if (remove_subflow)
			mptcp_pm_remove_subflow(msk, &list);
		release_sock(sk);

next:
		sock_put(sk);
		cond_resched();
	}

	return 0;
}

/* Deferred (RCU grace period + workqueue) release of an address entry */
struct addr_entry_release_work {
	struct rcu_work			rwork;
	struct mptcp_pm_addr_entry	*entry;
};

/* Workqueue callback: free the entry (and its listener) after the RCU
 * grace period has elapsed.
 */
static void mptcp_pm_release_addr_entry(struct work_struct *work)
{
	struct addr_entry_release_work *w;
	struct mptcp_pm_addr_entry *entry;

	w = container_of(to_rcu_work(work), struct addr_entry_release_work, rwork);
	entry = w->entry;
	if (entry) {
		if (entry->lsk)
			sock_release(entry->lsk);
		kfree(entry);
	}
	kfree(w);
}

/* Schedule RCU-deferred release of @entry.
 * NOTE(review): if the kmalloc fails, @entry is never freed — confirm
 * whether this leak is acceptable on this path.
 */
static void mptcp_pm_free_addr_entry(struct mptcp_pm_addr_entry *entry)
{
	struct addr_entry_release_work *w;

	w = kmalloc(sizeof(*w), GFP_ATOMIC);
	if (w) {
		INIT_RCU_WORK(&w->rwork, mptcp_pm_release_addr_entry);
		w->entry = entry;
		queue_rcu_work(system_wq, &w->rwork);
	}
}

/* Netlink MPTCP_PM_CMD_DEL_ADDR handler: remove a local address by id
 * and tear down its announcements/subflows on all connections.
 */
static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info)
{
	struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR];
	struct pm_nl_pernet *pernet = genl_info_pm_nl(info);
	struct mptcp_pm_addr_entry addr, *entry;
	unsigned int addr_max;
	int ret;

	ret = mptcp_pm_parse_addr(attr, info, false, &addr);
	if (ret < 0)
		return ret;

	spin_lock_bh(&pernet->lock);
	entry = __lookup_addr_by_id(pernet, addr.addr.id);
	if (!entry) {
		GENL_SET_ERR_MSG(info, "address not found");
		spin_unlock_bh(&pernet->lock);
		return -EINVAL;
	}
	if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) {
		addr_max = pernet->add_addr_signal_max;
		WRITE_ONCE(pernet->add_addr_signal_max, addr_max - 1);
} 1202 if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) { 1203 addr_max = pernet->local_addr_max; 1204 WRITE_ONCE(pernet->local_addr_max, addr_max - 1); 1205 } 1206 1207 pernet->addrs--; 1208 list_del_rcu(&entry->list); 1209 __clear_bit(entry->addr.id, pernet->id_bitmap); 1210 spin_unlock_bh(&pernet->lock); 1211 1212 mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), &entry->addr); 1213 mptcp_pm_free_addr_entry(entry); 1214 1215 return ret; 1216 } 1217 1218 static void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk, 1219 struct list_head *rm_list) 1220 { 1221 struct mptcp_rm_list alist = { .nr = 0 }, slist = { .nr = 0 }; 1222 struct mptcp_pm_addr_entry *entry; 1223 1224 list_for_each_entry(entry, rm_list, list) { 1225 if (lookup_subflow_by_saddr(&msk->conn_list, &entry->addr) && 1226 alist.nr < MPTCP_RM_IDS_MAX && 1227 slist.nr < MPTCP_RM_IDS_MAX) { 1228 alist.ids[alist.nr++] = entry->addr.id; 1229 slist.ids[slist.nr++] = entry->addr.id; 1230 } else if (remove_anno_list_by_saddr(msk, &entry->addr) && 1231 alist.nr < MPTCP_RM_IDS_MAX) { 1232 alist.ids[alist.nr++] = entry->addr.id; 1233 } 1234 } 1235 1236 if (alist.nr) { 1237 spin_lock_bh(&msk->pm.lock); 1238 mptcp_pm_remove_addr(msk, &alist); 1239 spin_unlock_bh(&msk->pm.lock); 1240 } 1241 if (slist.nr) 1242 mptcp_pm_remove_subflow(msk, &slist); 1243 } 1244 1245 static void mptcp_nl_remove_addrs_list(struct net *net, 1246 struct list_head *rm_list) 1247 { 1248 long s_slot = 0, s_num = 0; 1249 struct mptcp_sock *msk; 1250 1251 if (list_empty(rm_list)) 1252 return; 1253 1254 while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) { 1255 struct sock *sk = (struct sock *)msk; 1256 1257 lock_sock(sk); 1258 mptcp_pm_remove_addrs_and_subflows(msk, rm_list); 1259 release_sock(sk); 1260 1261 sock_put(sk); 1262 cond_resched(); 1263 } 1264 } 1265 1266 static void __flush_addrs(struct list_head *list) 1267 { 1268 while (!list_empty(list)) { 1269 struct mptcp_pm_addr_entry *cur; 1270 1271 cur = 
list_entry(list->next, 1272 struct mptcp_pm_addr_entry, list); 1273 list_del_rcu(&cur->list); 1274 mptcp_pm_free_addr_entry(cur); 1275 } 1276 } 1277 1278 static void __reset_counters(struct pm_nl_pernet *pernet) 1279 { 1280 WRITE_ONCE(pernet->add_addr_signal_max, 0); 1281 WRITE_ONCE(pernet->add_addr_accept_max, 0); 1282 WRITE_ONCE(pernet->local_addr_max, 0); 1283 pernet->addrs = 0; 1284 } 1285 1286 static int mptcp_nl_cmd_flush_addrs(struct sk_buff *skb, struct genl_info *info) 1287 { 1288 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1289 LIST_HEAD(free_list); 1290 1291 spin_lock_bh(&pernet->lock); 1292 list_splice_init(&pernet->local_addr_list, &free_list); 1293 __reset_counters(pernet); 1294 pernet->next_id = 1; 1295 bitmap_zero(pernet->id_bitmap, MAX_ADDR_ID + 1); 1296 spin_unlock_bh(&pernet->lock); 1297 mptcp_nl_remove_addrs_list(sock_net(skb->sk), &free_list); 1298 __flush_addrs(&free_list); 1299 return 0; 1300 } 1301 1302 static int mptcp_nl_fill_addr(struct sk_buff *skb, 1303 struct mptcp_pm_addr_entry *entry) 1304 { 1305 struct mptcp_addr_info *addr = &entry->addr; 1306 struct nlattr *attr; 1307 1308 attr = nla_nest_start(skb, MPTCP_PM_ATTR_ADDR); 1309 if (!attr) 1310 return -EMSGSIZE; 1311 1312 if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_FAMILY, addr->family)) 1313 goto nla_put_failure; 1314 if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_PORT, ntohs(addr->port))) 1315 goto nla_put_failure; 1316 if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id)) 1317 goto nla_put_failure; 1318 if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->addr.flags)) 1319 goto nla_put_failure; 1320 if (entry->addr.ifindex && 1321 nla_put_s32(skb, MPTCP_PM_ADDR_ATTR_IF_IDX, entry->addr.ifindex)) 1322 goto nla_put_failure; 1323 1324 if (addr->family == AF_INET && 1325 nla_put_in_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR4, 1326 addr->addr.s_addr)) 1327 goto nla_put_failure; 1328 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 1329 else if (addr->family == AF_INET6 && 1330 nla_put_in6_addr(skb, 
MPTCP_PM_ADDR_ATTR_ADDR6, &addr->addr6)) 1331 goto nla_put_failure; 1332 #endif 1333 nla_nest_end(skb, attr); 1334 return 0; 1335 1336 nla_put_failure: 1337 nla_nest_cancel(skb, attr); 1338 return -EMSGSIZE; 1339 } 1340 1341 static int mptcp_nl_cmd_get_addr(struct sk_buff *skb, struct genl_info *info) 1342 { 1343 struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; 1344 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1345 struct mptcp_pm_addr_entry addr, *entry; 1346 struct sk_buff *msg; 1347 void *reply; 1348 int ret; 1349 1350 ret = mptcp_pm_parse_addr(attr, info, false, &addr); 1351 if (ret < 0) 1352 return ret; 1353 1354 msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); 1355 if (!msg) 1356 return -ENOMEM; 1357 1358 reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0, 1359 info->genlhdr->cmd); 1360 if (!reply) { 1361 GENL_SET_ERR_MSG(info, "not enough space in Netlink message"); 1362 ret = -EMSGSIZE; 1363 goto fail; 1364 } 1365 1366 spin_lock_bh(&pernet->lock); 1367 entry = __lookup_addr_by_id(pernet, addr.addr.id); 1368 if (!entry) { 1369 GENL_SET_ERR_MSG(info, "address not found"); 1370 ret = -EINVAL; 1371 goto unlock_fail; 1372 } 1373 1374 ret = mptcp_nl_fill_addr(msg, entry); 1375 if (ret) 1376 goto unlock_fail; 1377 1378 genlmsg_end(msg, reply); 1379 ret = genlmsg_reply(msg, info); 1380 spin_unlock_bh(&pernet->lock); 1381 return ret; 1382 1383 unlock_fail: 1384 spin_unlock_bh(&pernet->lock); 1385 1386 fail: 1387 nlmsg_free(msg); 1388 return ret; 1389 } 1390 1391 static int mptcp_nl_cmd_dump_addrs(struct sk_buff *msg, 1392 struct netlink_callback *cb) 1393 { 1394 struct net *net = sock_net(msg->sk); 1395 struct mptcp_pm_addr_entry *entry; 1396 struct pm_nl_pernet *pernet; 1397 int id = cb->args[0]; 1398 void *hdr; 1399 int i; 1400 1401 pernet = net_generic(net, pm_nl_pernet_id); 1402 1403 spin_lock_bh(&pernet->lock); 1404 for (i = id; i < MAX_ADDR_ID + 1; i++) { 1405 if (test_bit(i, pernet->id_bitmap)) { 1406 entry = __lookup_addr_by_id(pernet, 
i); 1407 if (!entry) 1408 break; 1409 1410 if (entry->addr.id <= id) 1411 continue; 1412 1413 hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid, 1414 cb->nlh->nlmsg_seq, &mptcp_genl_family, 1415 NLM_F_MULTI, MPTCP_PM_CMD_GET_ADDR); 1416 if (!hdr) 1417 break; 1418 1419 if (mptcp_nl_fill_addr(msg, entry) < 0) { 1420 genlmsg_cancel(msg, hdr); 1421 break; 1422 } 1423 1424 id = entry->addr.id; 1425 genlmsg_end(msg, hdr); 1426 } 1427 } 1428 spin_unlock_bh(&pernet->lock); 1429 1430 cb->args[0] = id; 1431 return msg->len; 1432 } 1433 1434 static int parse_limit(struct genl_info *info, int id, unsigned int *limit) 1435 { 1436 struct nlattr *attr = info->attrs[id]; 1437 1438 if (!attr) 1439 return 0; 1440 1441 *limit = nla_get_u32(attr); 1442 if (*limit > MPTCP_PM_ADDR_MAX) { 1443 GENL_SET_ERR_MSG(info, "limit greater than maximum"); 1444 return -EINVAL; 1445 } 1446 return 0; 1447 } 1448 1449 static int 1450 mptcp_nl_cmd_set_limits(struct sk_buff *skb, struct genl_info *info) 1451 { 1452 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1453 unsigned int rcv_addrs, subflows; 1454 int ret; 1455 1456 spin_lock_bh(&pernet->lock); 1457 rcv_addrs = pernet->add_addr_accept_max; 1458 ret = parse_limit(info, MPTCP_PM_ATTR_RCV_ADD_ADDRS, &rcv_addrs); 1459 if (ret) 1460 goto unlock; 1461 1462 subflows = pernet->subflows_max; 1463 ret = parse_limit(info, MPTCP_PM_ATTR_SUBFLOWS, &subflows); 1464 if (ret) 1465 goto unlock; 1466 1467 WRITE_ONCE(pernet->add_addr_accept_max, rcv_addrs); 1468 WRITE_ONCE(pernet->subflows_max, subflows); 1469 1470 unlock: 1471 spin_unlock_bh(&pernet->lock); 1472 return ret; 1473 } 1474 1475 static int 1476 mptcp_nl_cmd_get_limits(struct sk_buff *skb, struct genl_info *info) 1477 { 1478 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1479 struct sk_buff *msg; 1480 void *reply; 1481 1482 msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); 1483 if (!msg) 1484 return -ENOMEM; 1485 1486 reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0, 1487 
MPTCP_PM_CMD_GET_LIMITS); 1488 if (!reply) 1489 goto fail; 1490 1491 if (nla_put_u32(msg, MPTCP_PM_ATTR_RCV_ADD_ADDRS, 1492 READ_ONCE(pernet->add_addr_accept_max))) 1493 goto fail; 1494 1495 if (nla_put_u32(msg, MPTCP_PM_ATTR_SUBFLOWS, 1496 READ_ONCE(pernet->subflows_max))) 1497 goto fail; 1498 1499 genlmsg_end(msg, reply); 1500 return genlmsg_reply(msg, info); 1501 1502 fail: 1503 GENL_SET_ERR_MSG(info, "not enough space in Netlink message"); 1504 nlmsg_free(msg); 1505 return -EMSGSIZE; 1506 } 1507 1508 static int mptcp_nl_addr_backup(struct net *net, 1509 struct mptcp_addr_info *addr, 1510 u8 bkup) 1511 { 1512 long s_slot = 0, s_num = 0; 1513 struct mptcp_sock *msk; 1514 int ret = -EINVAL; 1515 1516 while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) { 1517 struct sock *sk = (struct sock *)msk; 1518 1519 if (list_empty(&msk->conn_list)) 1520 goto next; 1521 1522 lock_sock(sk); 1523 spin_lock_bh(&msk->pm.lock); 1524 ret = mptcp_pm_nl_mp_prio_send_ack(msk, addr, bkup); 1525 spin_unlock_bh(&msk->pm.lock); 1526 release_sock(sk); 1527 1528 next: 1529 sock_put(sk); 1530 cond_resched(); 1531 } 1532 1533 return ret; 1534 } 1535 1536 static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info) 1537 { 1538 struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; 1539 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1540 struct mptcp_pm_addr_entry addr, *entry; 1541 struct net *net = sock_net(skb->sk); 1542 u8 bkup = 0; 1543 int ret; 1544 1545 ret = mptcp_pm_parse_addr(attr, info, true, &addr); 1546 if (ret < 0) 1547 return ret; 1548 1549 if (addr.addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP) 1550 bkup = 1; 1551 1552 list_for_each_entry(entry, &pernet->local_addr_list, list) { 1553 if (addresses_equal(&entry->addr, &addr.addr, true)) { 1554 ret = mptcp_nl_addr_backup(net, &entry->addr, bkup); 1555 if (ret) 1556 return ret; 1557 1558 if (bkup) 1559 entry->addr.flags |= MPTCP_PM_ADDR_FLAG_BACKUP; 1560 else 1561 entry->addr.flags &= 
~MPTCP_PM_ADDR_FLAG_BACKUP; 1562 } 1563 } 1564 1565 return 0; 1566 } 1567 1568 static void mptcp_nl_mcast_send(struct net *net, struct sk_buff *nlskb, gfp_t gfp) 1569 { 1570 genlmsg_multicast_netns(&mptcp_genl_family, net, 1571 nlskb, 0, MPTCP_PM_EV_GRP_OFFSET, gfp); 1572 } 1573 1574 static int mptcp_event_add_subflow(struct sk_buff *skb, const struct sock *ssk) 1575 { 1576 const struct inet_sock *issk = inet_sk(ssk); 1577 const struct mptcp_subflow_context *sf; 1578 1579 if (nla_put_u16(skb, MPTCP_ATTR_FAMILY, ssk->sk_family)) 1580 return -EMSGSIZE; 1581 1582 switch (ssk->sk_family) { 1583 case AF_INET: 1584 if (nla_put_in_addr(skb, MPTCP_ATTR_SADDR4, issk->inet_saddr)) 1585 return -EMSGSIZE; 1586 if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, issk->inet_daddr)) 1587 return -EMSGSIZE; 1588 break; 1589 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 1590 case AF_INET6: { 1591 const struct ipv6_pinfo *np = inet6_sk(ssk); 1592 1593 if (nla_put_in6_addr(skb, MPTCP_ATTR_SADDR6, &np->saddr)) 1594 return -EMSGSIZE; 1595 if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &ssk->sk_v6_daddr)) 1596 return -EMSGSIZE; 1597 break; 1598 } 1599 #endif 1600 default: 1601 WARN_ON_ONCE(1); 1602 return -EMSGSIZE; 1603 } 1604 1605 if (nla_put_be16(skb, MPTCP_ATTR_SPORT, issk->inet_sport)) 1606 return -EMSGSIZE; 1607 if (nla_put_be16(skb, MPTCP_ATTR_DPORT, issk->inet_dport)) 1608 return -EMSGSIZE; 1609 1610 sf = mptcp_subflow_ctx(ssk); 1611 if (WARN_ON_ONCE(!sf)) 1612 return -EINVAL; 1613 1614 if (nla_put_u8(skb, MPTCP_ATTR_LOC_ID, sf->local_id)) 1615 return -EMSGSIZE; 1616 1617 if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, sf->remote_id)) 1618 return -EMSGSIZE; 1619 1620 return 0; 1621 } 1622 1623 static int mptcp_event_put_token_and_ssk(struct sk_buff *skb, 1624 const struct mptcp_sock *msk, 1625 const struct sock *ssk) 1626 { 1627 const struct sock *sk = (const struct sock *)msk; 1628 const struct mptcp_subflow_context *sf; 1629 u8 sk_err; 1630 1631 if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token)) 1632 return 
-EMSGSIZE; 1633 1634 if (mptcp_event_add_subflow(skb, ssk)) 1635 return -EMSGSIZE; 1636 1637 sf = mptcp_subflow_ctx(ssk); 1638 if (WARN_ON_ONCE(!sf)) 1639 return -EINVAL; 1640 1641 if (nla_put_u8(skb, MPTCP_ATTR_BACKUP, sf->backup)) 1642 return -EMSGSIZE; 1643 1644 if (ssk->sk_bound_dev_if && 1645 nla_put_s32(skb, MPTCP_ATTR_IF_IDX, ssk->sk_bound_dev_if)) 1646 return -EMSGSIZE; 1647 1648 sk_err = ssk->sk_err; 1649 if (sk_err && sk->sk_state == TCP_ESTABLISHED && 1650 nla_put_u8(skb, MPTCP_ATTR_ERROR, sk_err)) 1651 return -EMSGSIZE; 1652 1653 return 0; 1654 } 1655 1656 static int mptcp_event_sub_established(struct sk_buff *skb, 1657 const struct mptcp_sock *msk, 1658 const struct sock *ssk) 1659 { 1660 return mptcp_event_put_token_and_ssk(skb, msk, ssk); 1661 } 1662 1663 static int mptcp_event_sub_closed(struct sk_buff *skb, 1664 const struct mptcp_sock *msk, 1665 const struct sock *ssk) 1666 { 1667 if (mptcp_event_put_token_and_ssk(skb, msk, ssk)) 1668 return -EMSGSIZE; 1669 1670 return 0; 1671 } 1672 1673 static int mptcp_event_created(struct sk_buff *skb, 1674 const struct mptcp_sock *msk, 1675 const struct sock *ssk) 1676 { 1677 int err = nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token); 1678 1679 if (err) 1680 return err; 1681 1682 return mptcp_event_add_subflow(skb, ssk); 1683 } 1684 1685 void mptcp_event_addr_removed(const struct mptcp_sock *msk, uint8_t id) 1686 { 1687 struct net *net = sock_net((const struct sock *)msk); 1688 struct nlmsghdr *nlh; 1689 struct sk_buff *skb; 1690 1691 if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET)) 1692 return; 1693 1694 skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); 1695 if (!skb) 1696 return; 1697 1698 nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, MPTCP_EVENT_REMOVED); 1699 if (!nlh) 1700 goto nla_put_failure; 1701 1702 if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token)) 1703 goto nla_put_failure; 1704 1705 if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, id)) 1706 goto nla_put_failure; 1707 1708 
genlmsg_end(skb, nlh); 1709 mptcp_nl_mcast_send(net, skb, GFP_ATOMIC); 1710 return; 1711 1712 nla_put_failure: 1713 kfree_skb(skb); 1714 } 1715 1716 void mptcp_event_addr_announced(const struct mptcp_sock *msk, 1717 const struct mptcp_addr_info *info) 1718 { 1719 struct net *net = sock_net((const struct sock *)msk); 1720 struct nlmsghdr *nlh; 1721 struct sk_buff *skb; 1722 1723 if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET)) 1724 return; 1725 1726 skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); 1727 if (!skb) 1728 return; 1729 1730 nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, 1731 MPTCP_EVENT_ANNOUNCED); 1732 if (!nlh) 1733 goto nla_put_failure; 1734 1735 if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token)) 1736 goto nla_put_failure; 1737 1738 if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, info->id)) 1739 goto nla_put_failure; 1740 1741 if (nla_put_be16(skb, MPTCP_ATTR_DPORT, info->port)) 1742 goto nla_put_failure; 1743 1744 switch (info->family) { 1745 case AF_INET: 1746 if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, info->addr.s_addr)) 1747 goto nla_put_failure; 1748 break; 1749 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 1750 case AF_INET6: 1751 if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &info->addr6)) 1752 goto nla_put_failure; 1753 break; 1754 #endif 1755 default: 1756 WARN_ON_ONCE(1); 1757 goto nla_put_failure; 1758 } 1759 1760 genlmsg_end(skb, nlh); 1761 mptcp_nl_mcast_send(net, skb, GFP_ATOMIC); 1762 return; 1763 1764 nla_put_failure: 1765 kfree_skb(skb); 1766 } 1767 1768 void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk, 1769 const struct sock *ssk, gfp_t gfp) 1770 { 1771 struct net *net = sock_net((const struct sock *)msk); 1772 struct nlmsghdr *nlh; 1773 struct sk_buff *skb; 1774 1775 if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET)) 1776 return; 1777 1778 skb = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); 1779 if (!skb) 1780 return; 1781 1782 nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, type); 
1783 if (!nlh) 1784 goto nla_put_failure; 1785 1786 switch (type) { 1787 case MPTCP_EVENT_UNSPEC: 1788 WARN_ON_ONCE(1); 1789 break; 1790 case MPTCP_EVENT_CREATED: 1791 case MPTCP_EVENT_ESTABLISHED: 1792 if (mptcp_event_created(skb, msk, ssk) < 0) 1793 goto nla_put_failure; 1794 break; 1795 case MPTCP_EVENT_CLOSED: 1796 if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token) < 0) 1797 goto nla_put_failure; 1798 break; 1799 case MPTCP_EVENT_ANNOUNCED: 1800 case MPTCP_EVENT_REMOVED: 1801 /* call mptcp_event_addr_announced()/removed instead */ 1802 WARN_ON_ONCE(1); 1803 break; 1804 case MPTCP_EVENT_SUB_ESTABLISHED: 1805 case MPTCP_EVENT_SUB_PRIORITY: 1806 if (mptcp_event_sub_established(skb, msk, ssk) < 0) 1807 goto nla_put_failure; 1808 break; 1809 case MPTCP_EVENT_SUB_CLOSED: 1810 if (mptcp_event_sub_closed(skb, msk, ssk) < 0) 1811 goto nla_put_failure; 1812 break; 1813 } 1814 1815 genlmsg_end(skb, nlh); 1816 mptcp_nl_mcast_send(net, skb, gfp); 1817 return; 1818 1819 nla_put_failure: 1820 kfree_skb(skb); 1821 } 1822 1823 static const struct genl_small_ops mptcp_pm_ops[] = { 1824 { 1825 .cmd = MPTCP_PM_CMD_ADD_ADDR, 1826 .doit = mptcp_nl_cmd_add_addr, 1827 .flags = GENL_ADMIN_PERM, 1828 }, 1829 { 1830 .cmd = MPTCP_PM_CMD_DEL_ADDR, 1831 .doit = mptcp_nl_cmd_del_addr, 1832 .flags = GENL_ADMIN_PERM, 1833 }, 1834 { 1835 .cmd = MPTCP_PM_CMD_FLUSH_ADDRS, 1836 .doit = mptcp_nl_cmd_flush_addrs, 1837 .flags = GENL_ADMIN_PERM, 1838 }, 1839 { 1840 .cmd = MPTCP_PM_CMD_GET_ADDR, 1841 .doit = mptcp_nl_cmd_get_addr, 1842 .dumpit = mptcp_nl_cmd_dump_addrs, 1843 }, 1844 { 1845 .cmd = MPTCP_PM_CMD_SET_LIMITS, 1846 .doit = mptcp_nl_cmd_set_limits, 1847 .flags = GENL_ADMIN_PERM, 1848 }, 1849 { 1850 .cmd = MPTCP_PM_CMD_GET_LIMITS, 1851 .doit = mptcp_nl_cmd_get_limits, 1852 }, 1853 { 1854 .cmd = MPTCP_PM_CMD_SET_FLAGS, 1855 .doit = mptcp_nl_cmd_set_flags, 1856 .flags = GENL_ADMIN_PERM, 1857 }, 1858 }; 1859 1860 static struct genl_family mptcp_genl_family __ro_after_init = { 1861 .name = 
MPTCP_PM_NAME, 1862 .version = MPTCP_PM_VER, 1863 .maxattr = MPTCP_PM_ATTR_MAX, 1864 .policy = mptcp_pm_policy, 1865 .netnsok = true, 1866 .module = THIS_MODULE, 1867 .small_ops = mptcp_pm_ops, 1868 .n_small_ops = ARRAY_SIZE(mptcp_pm_ops), 1869 .mcgrps = mptcp_pm_mcgrps, 1870 .n_mcgrps = ARRAY_SIZE(mptcp_pm_mcgrps), 1871 }; 1872 1873 static int __net_init pm_nl_init_net(struct net *net) 1874 { 1875 struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id); 1876 1877 INIT_LIST_HEAD_RCU(&pernet->local_addr_list); 1878 __reset_counters(pernet); 1879 pernet->next_id = 1; 1880 bitmap_zero(pernet->id_bitmap, MAX_ADDR_ID + 1); 1881 spin_lock_init(&pernet->lock); 1882 return 0; 1883 } 1884 1885 static void __net_exit pm_nl_exit_net(struct list_head *net_list) 1886 { 1887 struct net *net; 1888 1889 list_for_each_entry(net, net_list, exit_list) { 1890 struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id); 1891 1892 /* net is removed from namespace list, can't race with 1893 * other modifiers 1894 */ 1895 __flush_addrs(&pernet->local_addr_list); 1896 } 1897 } 1898 1899 static struct pernet_operations mptcp_pm_pernet_ops = { 1900 .init = pm_nl_init_net, 1901 .exit_batch = pm_nl_exit_net, 1902 .id = &pm_nl_pernet_id, 1903 .size = sizeof(struct pm_nl_pernet), 1904 }; 1905 1906 void __init mptcp_pm_nl_init(void) 1907 { 1908 if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0) 1909 panic("Failed to register MPTCP PM pernet subsystem.\n"); 1910 1911 if (genl_register_family(&mptcp_genl_family)) 1912 panic("Failed to register MPTCP PM netlink family\n"); 1913 } 1914