1 // SPDX-License-Identifier: GPL-2.0 2 /* Multipath TCP 3 * 4 * Copyright (c) 2020, Red Hat, Inc. 5 */ 6 7 #define pr_fmt(fmt) "MPTCP: " fmt 8 9 #include <linux/inet.h> 10 #include <linux/kernel.h> 11 #include <net/tcp.h> 12 #include <net/netns/generic.h> 13 #include <net/mptcp.h> 14 #include <net/genetlink.h> 15 #include <uapi/linux/mptcp.h> 16 17 #include "protocol.h" 18 #include "mib.h" 19 20 /* forward declaration */ 21 static struct genl_family mptcp_genl_family; 22 23 static int pm_nl_pernet_id; 24 25 struct mptcp_pm_addr_entry { 26 struct list_head list; 27 struct mptcp_addr_info addr; 28 struct rcu_head rcu; 29 struct socket *lsk; 30 }; 31 32 struct mptcp_pm_add_entry { 33 struct list_head list; 34 struct mptcp_addr_info addr; 35 struct timer_list add_timer; 36 struct mptcp_sock *sock; 37 u8 retrans_times; 38 }; 39 40 #define MAX_ADDR_ID 255 41 #define BITMAP_SZ DIV_ROUND_UP(MAX_ADDR_ID + 1, BITS_PER_LONG) 42 43 struct pm_nl_pernet { 44 /* protects pernet updates */ 45 spinlock_t lock; 46 struct list_head local_addr_list; 47 unsigned int addrs; 48 unsigned int add_addr_signal_max; 49 unsigned int add_addr_accept_max; 50 unsigned int local_addr_max; 51 unsigned int subflows_max; 52 unsigned int next_id; 53 unsigned long id_bitmap[BITMAP_SZ]; 54 }; 55 56 #define MPTCP_PM_ADDR_MAX 8 57 #define ADD_ADDR_RETRANS_MAX 3 58 59 static void mptcp_pm_nl_add_addr_send_ack(struct mptcp_sock *msk); 60 61 static bool addresses_equal(const struct mptcp_addr_info *a, 62 struct mptcp_addr_info *b, bool use_port) 63 { 64 bool addr_equals = false; 65 66 if (a->family == b->family) { 67 if (a->family == AF_INET) 68 addr_equals = a->addr.s_addr == b->addr.s_addr; 69 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 70 else 71 addr_equals = !ipv6_addr_cmp(&a->addr6, &b->addr6); 72 } else if (a->family == AF_INET) { 73 if (ipv6_addr_v4mapped(&b->addr6)) 74 addr_equals = a->addr.s_addr == b->addr6.s6_addr32[3]; 75 } else if (b->family == AF_INET) { 76 if (ipv6_addr_v4mapped(&a->addr6)) 77 addr_equals = a->addr6.s6_addr32[3] == b->addr.s_addr; 78 #endif 79 } 80 81 if (!addr_equals) 82 return false; 83 if (!use_port) 84 return true; 85 86 return a->port == b->port; 87 } 88 89 static bool address_zero(const struct mptcp_addr_info *addr) 90 { 91 struct mptcp_addr_info zero; 92 93 memset(&zero, 0, sizeof(zero)); 94 zero.family = addr->family; 95 96 return addresses_equal(addr, &zero, true); 97 } 98 99 static void local_address(const struct sock_common *skc, 100 struct mptcp_addr_info *addr) 101 { 102 addr->family = skc->skc_family; 103 addr->port = htons(skc->skc_num); 104 if (addr->family == AF_INET) 105 addr->addr.s_addr = skc->skc_rcv_saddr; 106 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 107 else if (addr->family == AF_INET6) 108 addr->addr6 = skc->skc_v6_rcv_saddr; 109 #endif 110 } 111 112 static void remote_address(const struct sock_common *skc, 113 struct mptcp_addr_info *addr) 114 { 115 addr->family = skc->skc_family; 116 addr->port = skc->skc_dport; 117 if (addr->family == AF_INET) 118 addr->addr.s_addr = skc->skc_daddr; 119 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 120 else if (addr->family == AF_INET6) 121 addr->addr6 = skc->skc_v6_daddr; 122 #endif 123 } 124 125 static bool lookup_subflow_by_saddr(const struct list_head *list, 126 struct mptcp_addr_info *saddr) 127 { 128 struct mptcp_subflow_context *subflow; 129 struct mptcp_addr_info cur; 130 struct sock_common *skc; 131 132 list_for_each_entry(subflow, list, node) { 133 skc = (struct sock_common *)mptcp_subflow_tcp_sock(subflow); 134 135 local_address(skc, &cur); 136 if (addresses_equal(&cur, saddr, saddr->port)) 137 return true; 138 } 139 140 return false; 141 } 142 143 static struct mptcp_pm_addr_entry * 144 select_local_address(const struct pm_nl_pernet *pernet, 145 struct mptcp_sock *msk) 146 { 147 struct mptcp_pm_addr_entry *entry, *ret = NULL; 148 struct sock *sk = (struct sock *)msk; 149 150 msk_owned_by_me(msk); 151 152 rcu_read_lock(); 153 __mptcp_flush_join_list(msk); 154 list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) { 155 if (!(entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW)) 156 continue; 157 158 if (entry->addr.family != sk->sk_family) { 159 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 160 if ((entry->addr.family == AF_INET && 161 !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) || 162 (sk->sk_family == AF_INET && 163 !ipv6_addr_v4mapped(&entry->addr.addr6))) 164 #endif 165 continue; 166 } 167 168 /* avoid any address already in use by subflows and 169 * pending join 170 */ 171 if (!lookup_subflow_by_saddr(&msk->conn_list, &entry->addr)) { 172 ret = entry; 173 break; 174 } 175 } 176 rcu_read_unlock(); 177 return ret; 178 } 179 180 static struct mptcp_pm_addr_entry * 181 select_signal_address(struct pm_nl_pernet *pernet, unsigned int pos) 182 { 183 struct mptcp_pm_addr_entry *entry, *ret = NULL; 184 int i = 0; 185 186 rcu_read_lock(); 187 /* do not keep any additional per socket state, just signal 188 * the address list in order. 189 * Note: removal from the local address list during the msk life-cycle 190 * can lead to additional addresses not being announced. 191 */ 192 list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) { 193 if (!(entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL)) 194 continue; 195 if (i++ == pos) { 196 ret = entry; 197 break; 198 } 199 } 200 rcu_read_unlock(); 201 return ret; 202 } 203 204 unsigned int mptcp_pm_get_add_addr_signal_max(struct mptcp_sock *msk) 205 { 206 struct pm_nl_pernet *pernet; 207 208 pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id); 209 return READ_ONCE(pernet->add_addr_signal_max); 210 } 211 EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_signal_max); 212 213 unsigned int mptcp_pm_get_add_addr_accept_max(struct mptcp_sock *msk) 214 { 215 struct pm_nl_pernet *pernet; 216 217 pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id); 218 return READ_ONCE(pernet->add_addr_accept_max); 219 } 220 EXPORT_SYMBOL_GPL(mptcp_pm_get_add_addr_accept_max); 221 222 unsigned int mptcp_pm_get_subflows_max(struct mptcp_sock *msk) 223 { 224 struct pm_nl_pernet *pernet; 225 226 pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id); 227 return READ_ONCE(pernet->subflows_max); 228 } 229 EXPORT_SYMBOL_GPL(mptcp_pm_get_subflows_max); 230 231 unsigned int mptcp_pm_get_local_addr_max(struct mptcp_sock *msk) 232 { 233 struct pm_nl_pernet *pernet; 234 235 pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id); 236 return READ_ONCE(pernet->local_addr_max); 237 } 238 EXPORT_SYMBOL_GPL(mptcp_pm_get_local_addr_max); 239 240 static void check_work_pending(struct mptcp_sock *msk) 241 { 242 if (msk->pm.add_addr_signaled == mptcp_pm_get_add_addr_signal_max(msk) && 243 (msk->pm.local_addr_used == mptcp_pm_get_local_addr_max(msk) || 244 msk->pm.subflows == mptcp_pm_get_subflows_max(msk))) 245 WRITE_ONCE(msk->pm.work_pending, false); 246 } 247 248 static struct mptcp_pm_add_entry * 249 lookup_anno_list_by_saddr(struct mptcp_sock *msk, 250 struct mptcp_addr_info *addr) 251 { 252 struct mptcp_pm_add_entry *entry; 253 254 lockdep_assert_held(&msk->pm.lock); 255 256 list_for_each_entry(entry, &msk->pm.anno_list, list) { 257 if (addresses_equal(&entry->addr, addr, true)) 258 return entry; 259 } 260 261 return NULL; 262 } 263 264 bool mptcp_pm_sport_in_anno_list(struct mptcp_sock *msk, const struct sock *sk) 265 { 266 struct mptcp_pm_add_entry *entry; 267 struct mptcp_addr_info saddr; 268 bool ret = false; 269 270 local_address((struct sock_common *)sk, &saddr); 271 272 spin_lock_bh(&msk->pm.lock); 273 list_for_each_entry(entry, &msk->pm.anno_list, list) { 274 if (addresses_equal(&entry->addr, &saddr, true)) { 275 ret = true; 276 goto out; 277 } 278 } 279 280 out: 281 spin_unlock_bh(&msk->pm.lock); 282 return ret; 283 } 284 285 static void mptcp_pm_add_timer(struct timer_list *timer) 286 { 287 struct mptcp_pm_add_entry *entry = from_timer(entry, timer, add_timer); 288 struct mptcp_sock *msk = entry->sock; 289 struct sock *sk = (struct sock *)msk; 290 291 pr_debug("msk=%p", msk); 292 293 if (!msk) 294 return; 295 296 if (inet_sk_state_load(sk) == TCP_CLOSE) 297 return; 298 299 if (!entry->addr.id) 300 return; 301 302 if (mptcp_pm_should_add_signal(msk)) { 303 sk_reset_timer(sk, timer, jiffies + TCP_RTO_MAX / 8); 304 goto out; 305 } 306 307 spin_lock_bh(&msk->pm.lock); 308 309 if (!mptcp_pm_should_add_signal(msk)) { 310 pr_debug("retransmit ADD_ADDR id=%d", entry->addr.id); 311 mptcp_pm_announce_addr(msk, &entry->addr, false, entry->addr.port); 312 mptcp_pm_add_addr_send_ack(msk); 313 entry->retrans_times++; 314 } 315 316 if (entry->retrans_times < ADD_ADDR_RETRANS_MAX) 317 sk_reset_timer(sk, timer, 318 jiffies + mptcp_get_add_addr_timeout(sock_net(sk))); 319 320 spin_unlock_bh(&msk->pm.lock); 321 322 out: 323 __sock_put(sk); 324 } 325 326 struct mptcp_pm_add_entry * 327 mptcp_pm_del_add_timer(struct mptcp_sock *msk, 328 struct mptcp_addr_info *addr) 329 { 330 struct mptcp_pm_add_entry *entry; 331 struct sock *sk = (struct sock *)msk; 332 333 spin_lock_bh(&msk->pm.lock); 334 entry = lookup_anno_list_by_saddr(msk, addr); 335 if (entry) 336 entry->retrans_times = ADD_ADDR_RETRANS_MAX; 337 spin_unlock_bh(&msk->pm.lock); 338 339 if (entry) 340 sk_stop_timer_sync(sk, &entry->add_timer); 341 342 return entry; 343 } 344 345 static bool mptcp_pm_alloc_anno_list(struct mptcp_sock *msk, 346 struct mptcp_pm_addr_entry *entry) 347 { 348 struct mptcp_pm_add_entry *add_entry = NULL; 349 struct sock *sk = (struct sock *)msk; 350 struct net *net = sock_net(sk); 351 352 lockdep_assert_held(&msk->pm.lock); 353 354 if (lookup_anno_list_by_saddr(msk, &entry->addr)) 355 return false; 356 357 add_entry = kmalloc(sizeof(*add_entry), GFP_ATOMIC); 358 if (!add_entry) 359 return false; 360 361 list_add(&add_entry->list, &msk->pm.anno_list); 362 363 add_entry->addr = entry->addr; 364 add_entry->sock = msk; 365 add_entry->retrans_times = 0; 366 367 timer_setup(&add_entry->add_timer, mptcp_pm_add_timer, 0); 368 sk_reset_timer(sk, &add_entry->add_timer, 369 jiffies + mptcp_get_add_addr_timeout(net)); 370 371 return true; 372 } 373 374 void mptcp_pm_free_anno_list(struct mptcp_sock *msk) 375 { 376 struct mptcp_pm_add_entry *entry, *tmp; 377 struct sock *sk = (struct sock *)msk; 378 LIST_HEAD(free_list); 379 380 pr_debug("msk=%p", msk); 381 382 spin_lock_bh(&msk->pm.lock); 383 list_splice_init(&msk->pm.anno_list, &free_list); 384 spin_unlock_bh(&msk->pm.lock); 385 386 list_for_each_entry_safe(entry, tmp, &free_list, list) { 387 sk_stop_timer_sync(sk, &entry->add_timer); 388 kfree(entry); 389 } 390 } 391 392 static void mptcp_pm_create_subflow_or_signal_addr(struct mptcp_sock *msk) 393 { 394 struct sock *sk = (struct sock *)msk; 395 struct mptcp_pm_addr_entry *local; 396 unsigned int add_addr_signal_max; 397 unsigned int local_addr_max; 398 struct pm_nl_pernet *pernet; 399 unsigned int subflows_max; 400 401 pernet = net_generic(sock_net(sk), pm_nl_pernet_id); 402 403 add_addr_signal_max = mptcp_pm_get_add_addr_signal_max(msk); 404 local_addr_max = mptcp_pm_get_local_addr_max(msk); 405 subflows_max = mptcp_pm_get_subflows_max(msk); 406 407 pr_debug("local %d:%d signal %d:%d subflows %d:%d\n", 408 msk->pm.local_addr_used, local_addr_max, 409 msk->pm.add_addr_signaled, add_addr_signal_max, 410 msk->pm.subflows, subflows_max); 411 412 /* check first for announce */ 413 if (msk->pm.add_addr_signaled < add_addr_signal_max) { 414 local = select_signal_address(pernet, 415 msk->pm.add_addr_signaled); 416 417 if (local) { 418 if (mptcp_pm_alloc_anno_list(msk, local)) { 419 msk->pm.add_addr_signaled++; 420 mptcp_pm_announce_addr(msk, &local->addr, false, local->addr.port); 421 mptcp_pm_nl_add_addr_send_ack(msk); 422 } 423 } else { 424 /* pick failed, avoid fourther attempts later */ 425 msk->pm.local_addr_used = add_addr_signal_max; 426 } 427 428 check_work_pending(msk); 429 } 430 431 /* check if should create a new subflow */ 432 if (msk->pm.local_addr_used < local_addr_max && 433 msk->pm.subflows < subflows_max) { 434 local = select_local_address(pernet, msk); 435 if (local) { 436 struct mptcp_addr_info remote = { 0 }; 437 438 msk->pm.local_addr_used++; 439 msk->pm.subflows++; 440 check_work_pending(msk); 441 remote_address((struct sock_common *)sk, &remote); 442 spin_unlock_bh(&msk->pm.lock); 443 __mptcp_subflow_connect(sk, &local->addr, &remote); 444 spin_lock_bh(&msk->pm.lock); 445 return; 446 } 447 448 /* lookup failed, avoid fourther attempts later */ 449 msk->pm.local_addr_used = local_addr_max; 450 check_work_pending(msk); 451 } 452 } 453 454 static void mptcp_pm_nl_fully_established(struct mptcp_sock *msk) 455 { 456 mptcp_pm_create_subflow_or_signal_addr(msk); 457 } 458 459 static void mptcp_pm_nl_subflow_established(struct mptcp_sock *msk) 460 { 461 mptcp_pm_create_subflow_or_signal_addr(msk); 462 } 463 464 static void mptcp_pm_nl_add_addr_received(struct mptcp_sock *msk) 465 { 466 struct sock *sk = (struct sock *)msk; 467 unsigned int add_addr_accept_max; 468 struct mptcp_addr_info remote; 469 struct mptcp_addr_info local; 470 unsigned int subflows_max; 471 bool use_port = false; 472 473 add_addr_accept_max = mptcp_pm_get_add_addr_accept_max(msk); 474 subflows_max = mptcp_pm_get_subflows_max(msk); 475 476 pr_debug("accepted %d:%d remote family %d", 477 msk->pm.add_addr_accepted, add_addr_accept_max, 478 msk->pm.remote.family); 479 msk->pm.add_addr_accepted++; 480 msk->pm.subflows++; 481 if (msk->pm.add_addr_accepted >= add_addr_accept_max || 482 msk->pm.subflows >= subflows_max) 483 WRITE_ONCE(msk->pm.accept_addr, false); 484 485 /* connect to the specified remote address, using whatever 486 * local address the routing configuration will pick. 487 */ 488 remote = msk->pm.remote; 489 if (!remote.port) 490 remote.port = sk->sk_dport; 491 else 492 use_port = true; 493 memset(&local, 0, sizeof(local)); 494 local.family = remote.family; 495 496 spin_unlock_bh(&msk->pm.lock); 497 __mptcp_subflow_connect(sk, &local, &remote); 498 spin_lock_bh(&msk->pm.lock); 499 500 mptcp_pm_announce_addr(msk, &remote, true, use_port); 501 mptcp_pm_nl_add_addr_send_ack(msk); 502 } 503 504 static void mptcp_pm_nl_add_addr_send_ack(struct mptcp_sock *msk) 505 { 506 struct mptcp_subflow_context *subflow; 507 508 msk_owned_by_me(msk); 509 lockdep_assert_held(&msk->pm.lock); 510 511 if (!mptcp_pm_should_add_signal(msk)) 512 return; 513 514 __mptcp_flush_join_list(msk); 515 subflow = list_first_entry_or_null(&msk->conn_list, typeof(*subflow), node); 516 if (subflow) { 517 struct sock *ssk = mptcp_subflow_tcp_sock(subflow); 518 u8 add_addr; 519 520 spin_unlock_bh(&msk->pm.lock); 521 pr_debug("send ack for add_addr%s%s", 522 mptcp_pm_should_add_signal_ipv6(msk) ? " [ipv6]" : "", 523 mptcp_pm_should_add_signal_port(msk) ? " [port]" : ""); 524 525 lock_sock(ssk); 526 tcp_send_ack(ssk); 527 release_sock(ssk); 528 spin_lock_bh(&msk->pm.lock); 529 530 add_addr = READ_ONCE(msk->pm.addr_signal); 531 if (mptcp_pm_should_add_signal_ipv6(msk)) 532 add_addr &= ~BIT(MPTCP_ADD_ADDR_IPV6); 533 if (mptcp_pm_should_add_signal_port(msk)) 534 add_addr &= ~BIT(MPTCP_ADD_ADDR_PORT); 535 WRITE_ONCE(msk->pm.addr_signal, add_addr); 536 } 537 } 538 539 int mptcp_pm_nl_mp_prio_send_ack(struct mptcp_sock *msk, 540 struct mptcp_addr_info *addr, 541 u8 bkup) 542 { 543 struct mptcp_subflow_context *subflow; 544 545 pr_debug("bkup=%d", bkup); 546 547 mptcp_for_each_subflow(msk, subflow) { 548 struct sock *ssk = mptcp_subflow_tcp_sock(subflow); 549 struct sock *sk = (struct sock *)msk; 550 struct mptcp_addr_info local; 551 552 local_address((struct sock_common *)ssk, &local); 553 if (!addresses_equal(&local, addr, addr->port)) 554 continue; 555 556 subflow->backup = bkup; 557 subflow->send_mp_prio = 1; 558 subflow->request_bkup = bkup; 559 __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_MPPRIOTX); 560 561 spin_unlock_bh(&msk->pm.lock); 562 pr_debug("send ack for mp_prio"); 563 lock_sock(ssk); 564 tcp_send_ack(ssk); 565 release_sock(ssk); 566 spin_lock_bh(&msk->pm.lock); 567 568 return 0; 569 } 570 571 return -EINVAL; 572 } 573 574 static void mptcp_pm_nl_rm_addr_received(struct mptcp_sock *msk) 575 { 576 struct mptcp_subflow_context *subflow, *tmp; 577 struct sock *sk = (struct sock *)msk; 578 u8 i; 579 580 pr_debug("address rm_list_nr %d", msk->pm.rm_list_rx.nr); 581 582 msk_owned_by_me(msk); 583 584 if (!msk->pm.rm_list_rx.nr) 585 return; 586 587 if (list_empty(&msk->conn_list)) 588 return; 589 590 for (i = 0; i < msk->pm.rm_list_rx.nr; i++) { 591 list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) { 592 struct sock *ssk = mptcp_subflow_tcp_sock(subflow); 593 int how = RCV_SHUTDOWN | SEND_SHUTDOWN; 594 595 if (msk->pm.rm_list_rx.ids[i] != subflow->remote_id) 596 continue; 597 598 pr_debug(" -> address rm_list_ids[%d]=%u", i, msk->pm.rm_list_rx.ids[i]); 599 spin_unlock_bh(&msk->pm.lock); 600 mptcp_subflow_shutdown(sk, ssk, how); 601 mptcp_close_ssk(sk, ssk, subflow); 602 spin_lock_bh(&msk->pm.lock); 603 604 msk->pm.add_addr_accepted--; 605 msk->pm.subflows--; 606 WRITE_ONCE(msk->pm.accept_addr, true); 607 608 __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMADDR); 609 610 break; 611 } 612 } 613 } 614 615 void mptcp_pm_nl_work(struct mptcp_sock *msk) 616 { 617 struct mptcp_pm_data *pm = &msk->pm; 618 619 msk_owned_by_me(msk); 620 621 spin_lock_bh(&msk->pm.lock); 622 623 pr_debug("msk=%p status=%x", msk, pm->status); 624 if (pm->status & BIT(MPTCP_PM_ADD_ADDR_RECEIVED)) { 625 pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_RECEIVED); 626 mptcp_pm_nl_add_addr_received(msk); 627 } 628 if (pm->status & BIT(MPTCP_PM_ADD_ADDR_SEND_ACK)) { 629 pm->status &= ~BIT(MPTCP_PM_ADD_ADDR_SEND_ACK); 630 mptcp_pm_nl_add_addr_send_ack(msk); 631 } 632 if (pm->status & BIT(MPTCP_PM_RM_ADDR_RECEIVED)) { 633 pm->status &= ~BIT(MPTCP_PM_RM_ADDR_RECEIVED); 634 mptcp_pm_nl_rm_addr_received(msk); 635 } 636 if (pm->status & BIT(MPTCP_PM_ESTABLISHED)) { 637 pm->status &= ~BIT(MPTCP_PM_ESTABLISHED); 638 mptcp_pm_nl_fully_established(msk); 639 } 640 if (pm->status & BIT(MPTCP_PM_SUBFLOW_ESTABLISHED)) { 641 pm->status &= ~BIT(MPTCP_PM_SUBFLOW_ESTABLISHED); 642 mptcp_pm_nl_subflow_established(msk); 643 } 644 645 spin_unlock_bh(&msk->pm.lock); 646 } 647 648 void mptcp_pm_nl_rm_subflow_received(struct mptcp_sock *msk, 649 const struct mptcp_rm_list *rm_list) 650 { 651 struct mptcp_subflow_context *subflow, *tmp; 652 struct sock *sk = (struct sock *)msk; 653 u8 i; 654 655 pr_debug("subflow rm_list_nr %d", rm_list->nr); 656 657 msk_owned_by_me(msk); 658 659 if (!rm_list->nr) 660 return; 661 662 if (list_empty(&msk->conn_list)) 663 return; 664 665 for (i = 0; i < rm_list->nr; i++) { 666 list_for_each_entry_safe(subflow, tmp, &msk->conn_list, node) { 667 struct sock *ssk = mptcp_subflow_tcp_sock(subflow); 668 int how = RCV_SHUTDOWN | SEND_SHUTDOWN; 669 670 if (rm_list->ids[i] != subflow->local_id) 671 continue; 672 673 pr_debug(" -> subflow rm_list_ids[%d]=%u", i, rm_list->ids[i]); 674 spin_unlock_bh(&msk->pm.lock); 675 mptcp_subflow_shutdown(sk, ssk, how); 676 mptcp_close_ssk(sk, ssk, subflow); 677 spin_lock_bh(&msk->pm.lock); 678 679 msk->pm.local_addr_used--; 680 msk->pm.subflows--; 681 682 __MPTCP_INC_STATS(sock_net(sk), MPTCP_MIB_RMSUBFLOW); 683 684 break; 685 } 686 } 687 } 688 689 static bool address_use_port(struct mptcp_pm_addr_entry *entry) 690 { 691 return (entry->addr.flags & 692 (MPTCP_PM_ADDR_FLAG_SIGNAL | MPTCP_PM_ADDR_FLAG_SUBFLOW)) == 693 MPTCP_PM_ADDR_FLAG_SIGNAL; 694 } 695 696 static int mptcp_pm_nl_append_new_local_addr(struct pm_nl_pernet *pernet, 697 struct mptcp_pm_addr_entry *entry) 698 { 699 struct mptcp_pm_addr_entry *cur; 700 unsigned int addr_max; 701 int ret = -EINVAL; 702 703 spin_lock_bh(&pernet->lock); 704 /* to keep the code simple, don't do IDR-like allocation for address ID, 705 * just bail when we exceed limits 706 */ 707 if (pernet->next_id == MAX_ADDR_ID) 708 pernet->next_id = 1; 709 if (pernet->addrs >= MPTCP_PM_ADDR_MAX) 710 goto out; 711 if (test_bit(entry->addr.id, pernet->id_bitmap)) 712 goto out; 713 714 /* do not insert duplicate address, differentiate on port only 715 * singled addresses 716 */ 717 list_for_each_entry(cur, &pernet->local_addr_list, list) { 718 if (addresses_equal(&cur->addr, &entry->addr, 719 address_use_port(entry) && 720 address_use_port(cur))) 721 goto out; 722 } 723 724 if (!entry->addr.id) { 725 find_next: 726 entry->addr.id = find_next_zero_bit(pernet->id_bitmap, 727 MAX_ADDR_ID + 1, 728 pernet->next_id); 729 if ((!entry->addr.id || entry->addr.id > MAX_ADDR_ID) && 730 pernet->next_id != 1) { 731 pernet->next_id = 1; 732 goto find_next; 733 } 734 } 735 736 if (!entry->addr.id || entry->addr.id > MAX_ADDR_ID) 737 goto out; 738 739 __set_bit(entry->addr.id, pernet->id_bitmap); 740 if (entry->addr.id > pernet->next_id) 741 pernet->next_id = entry->addr.id; 742 743 if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) { 744 addr_max = pernet->add_addr_signal_max; 745 WRITE_ONCE(pernet->add_addr_signal_max, addr_max + 1); 746 } 747 if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) { 748 addr_max = pernet->local_addr_max; 749 WRITE_ONCE(pernet->local_addr_max, addr_max + 1); 750 } 751 752 pernet->addrs++; 753 list_add_tail_rcu(&entry->list, &pernet->local_addr_list); 754 ret = entry->addr.id; 755 756 out: 757 spin_unlock_bh(&pernet->lock); 758 return ret; 759 } 760 761 static int mptcp_pm_nl_create_listen_socket(struct sock *sk, 762 struct mptcp_pm_addr_entry *entry) 763 { 764 struct sockaddr_storage addr; 765 struct mptcp_sock *msk; 766 struct socket *ssock; 767 int backlog = 1024; 768 int err; 769 770 err = sock_create_kern(sock_net(sk), entry->addr.family, 771 SOCK_STREAM, IPPROTO_MPTCP, &entry->lsk); 772 if (err) 773 return err; 774 775 msk = mptcp_sk(entry->lsk->sk); 776 if (!msk) { 777 err = -EINVAL; 778 goto out; 779 } 780 781 ssock = __mptcp_nmpc_socket(msk); 782 if (!ssock) { 783 err = -EINVAL; 784 goto out; 785 } 786 787 mptcp_info2sockaddr(&entry->addr, &addr, entry->addr.family); 788 err = kernel_bind(ssock, (struct sockaddr *)&addr, 789 sizeof(struct sockaddr_in)); 790 if (err) { 791 pr_warn("kernel_bind error, err=%d", err); 792 goto out; 793 } 794 795 err = kernel_listen(ssock, backlog); 796 if (err) { 797 pr_warn("kernel_listen error, err=%d", err); 798 goto out; 799 } 800 801 return 0; 802 803 out: 804 sock_release(entry->lsk); 805 return err; 806 } 807 808 int mptcp_pm_nl_get_local_id(struct mptcp_sock *msk, struct sock_common *skc) 809 { 810 struct mptcp_pm_addr_entry *entry; 811 struct mptcp_addr_info skc_local; 812 struct mptcp_addr_info msk_local; 813 struct pm_nl_pernet *pernet; 814 int ret = -1; 815 816 if (WARN_ON_ONCE(!msk)) 817 return -1; 818 819 /* The 0 ID mapping is defined by the first subflow, copied into the msk 820 * addr 821 */ 822 local_address((struct sock_common *)msk, &msk_local); 823 local_address((struct sock_common *)skc, &skc_local); 824 if (addresses_equal(&msk_local, &skc_local, false)) 825 return 0; 826 827 if (address_zero(&skc_local)) 828 return 0; 829 830 pernet = net_generic(sock_net((struct sock *)msk), pm_nl_pernet_id); 831 832 rcu_read_lock(); 833 list_for_each_entry_rcu(entry, &pernet->local_addr_list, list) { 834 if (addresses_equal(&entry->addr, &skc_local, entry->addr.port)) { 835 ret = entry->addr.id; 836 break; 837 } 838 } 839 rcu_read_unlock(); 840 if (ret >= 0) 841 return ret; 842 843 /* address not found, add to local list */ 844 entry = kmalloc(sizeof(*entry), GFP_ATOMIC); 845 if (!entry) 846 return -ENOMEM; 847 848 entry->addr = skc_local; 849 entry->addr.ifindex = 0; 850 entry->addr.flags = 0; 851 entry->addr.id = 0; 852 entry->addr.port = 0; 853 entry->lsk = NULL; 854 ret = mptcp_pm_nl_append_new_local_addr(pernet, entry); 855 if (ret < 0) 856 kfree(entry); 857 858 return ret; 859 } 860 861 void mptcp_pm_nl_data_init(struct mptcp_sock *msk) 862 { 863 struct mptcp_pm_data *pm = &msk->pm; 864 bool subflows; 865 866 subflows = !!mptcp_pm_get_subflows_max(msk); 867 WRITE_ONCE(pm->work_pending, (!!mptcp_pm_get_local_addr_max(msk) && subflows) || 868 !!mptcp_pm_get_add_addr_signal_max(msk)); 869 WRITE_ONCE(pm->accept_addr, !!mptcp_pm_get_add_addr_accept_max(msk) && subflows); 870 WRITE_ONCE(pm->accept_subflow, subflows); 871 } 872 873 #define MPTCP_PM_CMD_GRP_OFFSET 0 874 #define MPTCP_PM_EV_GRP_OFFSET 1 875 876 static const struct genl_multicast_group mptcp_pm_mcgrps[] = { 877 [MPTCP_PM_CMD_GRP_OFFSET] = { .name = MPTCP_PM_CMD_GRP_NAME, }, 878 [MPTCP_PM_EV_GRP_OFFSET] = { .name = MPTCP_PM_EV_GRP_NAME, 879 .flags = GENL_UNS_ADMIN_PERM, 880 }, 881 }; 882 883 static const struct nla_policy 884 mptcp_pm_addr_policy[MPTCP_PM_ADDR_ATTR_MAX + 1] = { 885 [MPTCP_PM_ADDR_ATTR_FAMILY] = { .type = NLA_U16, }, 886 [MPTCP_PM_ADDR_ATTR_ID] = { .type = NLA_U8, }, 887 [MPTCP_PM_ADDR_ATTR_ADDR4] = { .type = NLA_U32, }, 888 [MPTCP_PM_ADDR_ATTR_ADDR6] = 889 NLA_POLICY_EXACT_LEN(sizeof(struct in6_addr)), 890 [MPTCP_PM_ADDR_ATTR_PORT] = { .type = NLA_U16 }, 891 [MPTCP_PM_ADDR_ATTR_FLAGS] = { .type = NLA_U32 }, 892 [MPTCP_PM_ADDR_ATTR_IF_IDX] = { .type = NLA_S32 }, 893 }; 894 895 static const struct nla_policy mptcp_pm_policy[MPTCP_PM_ATTR_MAX + 1] = { 896 [MPTCP_PM_ATTR_ADDR] = 897 NLA_POLICY_NESTED(mptcp_pm_addr_policy), 898 [MPTCP_PM_ATTR_RCV_ADD_ADDRS] = { .type = NLA_U32, }, 899 [MPTCP_PM_ATTR_SUBFLOWS] = { .type = NLA_U32, }, 900 }; 901 902 static int mptcp_pm_family_to_addr(int family) 903 { 904 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 905 if (family == AF_INET6) 906 return MPTCP_PM_ADDR_ATTR_ADDR6; 907 #endif 908 return MPTCP_PM_ADDR_ATTR_ADDR4; 909 } 910 911 static int mptcp_pm_parse_addr(struct nlattr *attr, struct genl_info *info, 912 bool require_family, 913 struct mptcp_pm_addr_entry *entry) 914 { 915 struct nlattr *tb[MPTCP_PM_ADDR_ATTR_MAX + 1]; 916 int err, addr_addr; 917 918 if (!attr) { 919 GENL_SET_ERR_MSG(info, "missing address info"); 920 return -EINVAL; 921 } 922 923 /* no validation needed - was already done via nested policy */ 924 err = nla_parse_nested_deprecated(tb, MPTCP_PM_ADDR_ATTR_MAX, attr, 925 mptcp_pm_addr_policy, info->extack); 926 if (err) 927 return err; 928 929 memset(entry, 0, sizeof(*entry)); 930 if (!tb[MPTCP_PM_ADDR_ATTR_FAMILY]) { 931 if (!require_family) 932 goto skip_family; 933 934 NL_SET_ERR_MSG_ATTR(info->extack, attr, 935 "missing family"); 936 return -EINVAL; 937 } 938 939 entry->addr.family = nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_FAMILY]); 940 if (entry->addr.family != AF_INET 941 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 942 && entry->addr.family != AF_INET6 943 #endif 944 ) { 945 NL_SET_ERR_MSG_ATTR(info->extack, attr, 946 "unknown address family"); 947 return -EINVAL; 948 } 949 addr_addr = mptcp_pm_family_to_addr(entry->addr.family); 950 if (!tb[addr_addr]) { 951 NL_SET_ERR_MSG_ATTR(info->extack, attr, 952 "missing address data"); 953 return -EINVAL; 954 } 955 956 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 957 if (entry->addr.family == AF_INET6) 958 entry->addr.addr6 = nla_get_in6_addr(tb[addr_addr]); 959 else 960 #endif 961 entry->addr.addr.s_addr = nla_get_in_addr(tb[addr_addr]); 962 963 skip_family: 964 if (tb[MPTCP_PM_ADDR_ATTR_IF_IDX]) { 965 u32 val = nla_get_s32(tb[MPTCP_PM_ADDR_ATTR_IF_IDX]); 966 967 entry->addr.ifindex = val; 968 } 969 970 if (tb[MPTCP_PM_ADDR_ATTR_ID]) 971 entry->addr.id = nla_get_u8(tb[MPTCP_PM_ADDR_ATTR_ID]); 972 973 if (tb[MPTCP_PM_ADDR_ATTR_FLAGS]) 974 entry->addr.flags = nla_get_u32(tb[MPTCP_PM_ADDR_ATTR_FLAGS]); 975 976 if (tb[MPTCP_PM_ADDR_ATTR_PORT]) 977 entry->addr.port = htons(nla_get_u16(tb[MPTCP_PM_ADDR_ATTR_PORT])); 978 979 return 0; 980 } 981 982 static struct pm_nl_pernet *genl_info_pm_nl(struct genl_info *info) 983 { 984 return net_generic(genl_info_net(info), pm_nl_pernet_id); 985 } 986 987 static int mptcp_nl_add_subflow_or_signal_addr(struct net *net) 988 { 989 struct mptcp_sock *msk; 990 long s_slot = 0, s_num = 0; 991 992 while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) { 993 struct sock *sk = (struct sock *)msk; 994 995 if (!READ_ONCE(msk->fully_established)) 996 goto next; 997 998 lock_sock(sk); 999 spin_lock_bh(&msk->pm.lock); 1000 mptcp_pm_create_subflow_or_signal_addr(msk); 1001 spin_unlock_bh(&msk->pm.lock); 1002 release_sock(sk); 1003 1004 next: 1005 sock_put(sk); 1006 cond_resched(); 1007 } 1008 1009 return 0; 1010 } 1011 1012 static int mptcp_nl_cmd_add_addr(struct sk_buff *skb, struct genl_info *info) 1013 { 1014 struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; 1015 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1016 struct mptcp_pm_addr_entry addr, *entry; 1017 int ret; 1018 1019 ret = mptcp_pm_parse_addr(attr, info, true, &addr); 1020 if (ret < 0) 1021 return ret; 1022 1023 entry = kmalloc(sizeof(*entry), GFP_KERNEL); 1024 if (!entry) { 1025 GENL_SET_ERR_MSG(info, "can't allocate addr"); 1026 return -ENOMEM; 1027 } 1028 1029 *entry = addr; 1030 if (entry->addr.port) { 1031 ret = mptcp_pm_nl_create_listen_socket(skb->sk, entry); 1032 if (ret) { 1033 GENL_SET_ERR_MSG(info, "create listen socket error"); 1034 kfree(entry); 1035 return ret; 1036 } 1037 } 1038 ret = mptcp_pm_nl_append_new_local_addr(pernet, entry); 1039 if (ret < 0) { 1040 GENL_SET_ERR_MSG(info, "too many addresses or duplicate one"); 1041 if (entry->lsk) 1042 sock_release(entry->lsk); 1043 kfree(entry); 1044 return ret; 1045 } 1046 1047 mptcp_nl_add_subflow_or_signal_addr(sock_net(skb->sk)); 1048 1049 return 0; 1050 } 1051 1052 static struct mptcp_pm_addr_entry * 1053 __lookup_addr_by_id(struct pm_nl_pernet *pernet, unsigned int id) 1054 { 1055 struct mptcp_pm_addr_entry *entry; 1056 1057 list_for_each_entry(entry, &pernet->local_addr_list, list) { 1058 if (entry->addr.id == id) 1059 return entry; 1060 } 1061 return NULL; 1062 } 1063 1064 static bool remove_anno_list_by_saddr(struct mptcp_sock *msk, 1065 struct mptcp_addr_info *addr) 1066 { 1067 struct mptcp_pm_add_entry *entry; 1068 1069 entry = mptcp_pm_del_add_timer(msk, addr); 1070 if (entry) { 1071 list_del(&entry->list); 1072 kfree(entry); 1073 return true; 1074 } 1075 1076 return false; 1077 } 1078 1079 static bool mptcp_pm_remove_anno_addr(struct mptcp_sock *msk, 1080 struct mptcp_addr_info *addr, 1081 bool force) 1082 { 1083 struct mptcp_rm_list list = { .nr = 0 }; 1084 bool ret; 1085 1086 list.ids[list.nr++] = addr->id; 1087 1088 ret = remove_anno_list_by_saddr(msk, addr); 1089 if (ret || force) { 1090 spin_lock_bh(&msk->pm.lock); 1091 mptcp_pm_remove_addr(msk, &list); 1092 spin_unlock_bh(&msk->pm.lock); 1093 } 1094 return ret; 1095 } 1096 1097 static int mptcp_nl_remove_subflow_and_signal_addr(struct net *net, 1098 struct mptcp_addr_info *addr) 1099 { 1100 struct mptcp_sock *msk; 1101 long s_slot = 0, s_num = 0; 1102 struct mptcp_rm_list list = { .nr = 0 }; 1103 1104 pr_debug("remove_id=%d", addr->id); 1105 1106 list.ids[list.nr++] = addr->id; 1107 1108 while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) { 1109 struct sock *sk = (struct sock *)msk; 1110 bool remove_subflow; 1111 1112 if (list_empty(&msk->conn_list)) { 1113 mptcp_pm_remove_anno_addr(msk, addr, false); 1114 goto next; 1115 } 1116 1117 lock_sock(sk); 1118 remove_subflow = lookup_subflow_by_saddr(&msk->conn_list, addr); 1119 mptcp_pm_remove_anno_addr(msk, addr, remove_subflow); 1120 if (remove_subflow) 1121 mptcp_pm_remove_subflow(msk, &list); 1122 release_sock(sk); 1123 1124 next: 1125 sock_put(sk); 1126 cond_resched(); 1127 } 1128 1129 return 0; 1130 } 1131 1132 struct addr_entry_release_work { 1133 struct rcu_work rwork; 1134 struct mptcp_pm_addr_entry *entry; 1135 }; 1136 1137 static void mptcp_pm_release_addr_entry(struct work_struct *work) 1138 { 1139 struct addr_entry_release_work *w; 1140 struct mptcp_pm_addr_entry *entry; 1141 1142 w = container_of(to_rcu_work(work), struct addr_entry_release_work, rwork); 1143 entry = w->entry; 1144 if (entry) { 1145 if (entry->lsk) 1146 sock_release(entry->lsk); 1147 kfree(entry); 1148 } 1149 kfree(w); 1150 } 1151 1152 static void mptcp_pm_free_addr_entry(struct mptcp_pm_addr_entry *entry) 1153 { 1154 struct addr_entry_release_work *w; 1155 1156 w = kmalloc(sizeof(*w), GFP_ATOMIC); 1157 if (w) { 1158 INIT_RCU_WORK(&w->rwork, mptcp_pm_release_addr_entry); 1159 w->entry = entry; 1160 queue_rcu_work(system_wq, &w->rwork); 1161 } 1162 } 1163 1164 static int mptcp_nl_cmd_del_addr(struct sk_buff *skb, struct genl_info *info) 1165 { 1166 struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; 1167 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1168 struct mptcp_pm_addr_entry addr, *entry; 1169 unsigned int addr_max; 1170 int ret; 1171 1172 ret = mptcp_pm_parse_addr(attr, info, false, &addr); 1173 if (ret < 0) 1174 return ret; 1175 1176 spin_lock_bh(&pernet->lock); 1177 entry = __lookup_addr_by_id(pernet, addr.addr.id); 1178 if (!entry) { 1179 GENL_SET_ERR_MSG(info, "address not found"); 1180 spin_unlock_bh(&pernet->lock); 1181 return -EINVAL; 1182 } 1183 if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SIGNAL) { 1184 addr_max = pernet->add_addr_signal_max; 1185 WRITE_ONCE(pernet->add_addr_signal_max, addr_max - 1); 1186 } 1187 if (entry->addr.flags & MPTCP_PM_ADDR_FLAG_SUBFLOW) { 1188 addr_max = pernet->local_addr_max; 1189 WRITE_ONCE(pernet->local_addr_max, addr_max - 1); 1190 } 1191 1192 pernet->addrs--; 1193 list_del_rcu(&entry->list); 1194 __clear_bit(entry->addr.id, pernet->id_bitmap); 1195 spin_unlock_bh(&pernet->lock); 1196 1197 mptcp_nl_remove_subflow_and_signal_addr(sock_net(skb->sk), &entry->addr); 1198 mptcp_pm_free_addr_entry(entry); 1199 1200 return ret; 1201 } 1202 1203 static void mptcp_pm_remove_addrs_and_subflows(struct mptcp_sock *msk, 1204 struct list_head *rm_list) 1205 { 1206 struct mptcp_rm_list alist = { .nr = 0 }, slist = { .nr = 0 }; 1207 struct mptcp_pm_addr_entry *entry; 1208 1209 list_for_each_entry(entry, rm_list, list) { 1210 if (lookup_subflow_by_saddr(&msk->conn_list, &entry->addr) && 1211 alist.nr < MPTCP_RM_IDS_MAX && 1212 slist.nr < MPTCP_RM_IDS_MAX) { 1213 alist.ids[alist.nr++] = entry->addr.id; 1214 slist.ids[slist.nr++] = entry->addr.id; 1215 } else if (remove_anno_list_by_saddr(msk, &entry->addr) && 1216 alist.nr < MPTCP_RM_IDS_MAX) { 1217 alist.ids[alist.nr++] = entry->addr.id; 1218 } 1219 } 1220 1221 if (alist.nr) { 1222 spin_lock_bh(&msk->pm.lock); 1223 mptcp_pm_remove_addr(msk, &alist); 1224 spin_unlock_bh(&msk->pm.lock); 1225 } 1226 if (slist.nr) 1227 mptcp_pm_remove_subflow(msk, &slist); 1228 } 1229 1230 static void mptcp_nl_remove_addrs_list(struct net *net, 1231 struct list_head *rm_list) 1232 { 1233 long s_slot = 0, s_num = 0; 1234 struct mptcp_sock *msk; 1235 1236 if (list_empty(rm_list)) 1237 return; 1238 1239 while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) { 1240 struct sock *sk = (struct sock *)msk; 1241 1242 lock_sock(sk); 1243 mptcp_pm_remove_addrs_and_subflows(msk, rm_list); 1244 release_sock(sk); 1245 1246 sock_put(sk); 1247 cond_resched(); 1248 } 1249 } 1250 1251 static void __flush_addrs(struct list_head *list) 1252 { 1253 while (!list_empty(list)) { 1254 struct mptcp_pm_addr_entry *cur; 1255 1256 cur = list_entry(list->next, 1257 struct mptcp_pm_addr_entry, list); 1258 list_del_rcu(&cur->list); 1259 mptcp_pm_free_addr_entry(cur); 1260 } 1261 } 1262 1263 static void __reset_counters(struct pm_nl_pernet *pernet) 1264 { 1265 WRITE_ONCE(pernet->add_addr_signal_max, 0); 1266 WRITE_ONCE(pernet->add_addr_accept_max, 0); 1267 WRITE_ONCE(pernet->local_addr_max, 0); 1268 pernet->addrs = 0; 1269 } 1270 1271 static int mptcp_nl_cmd_flush_addrs(struct sk_buff *skb, struct genl_info *info) 1272 { 1273 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1274 LIST_HEAD(free_list); 1275 1276 spin_lock_bh(&pernet->lock); 1277 list_splice_init(&pernet->local_addr_list, &free_list); 1278 __reset_counters(pernet); 1279 pernet->next_id = 1; 1280 bitmap_zero(pernet->id_bitmap, MAX_ADDR_ID + 1); 1281 spin_unlock_bh(&pernet->lock); 1282 mptcp_nl_remove_addrs_list(sock_net(skb->sk), &free_list); 1283 __flush_addrs(&free_list); 1284 return 0; 1285 } 1286 1287 static int mptcp_nl_fill_addr(struct sk_buff *skb, 1288 struct mptcp_pm_addr_entry *entry) 1289 { 1290 struct mptcp_addr_info *addr = &entry->addr; 1291 struct nlattr *attr; 1292 1293 attr = nla_nest_start(skb, MPTCP_PM_ATTR_ADDR); 1294 if (!attr) 1295 return -EMSGSIZE; 1296 1297 if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_FAMILY, addr->family)) 1298 goto nla_put_failure; 1299 if (nla_put_u16(skb, MPTCP_PM_ADDR_ATTR_PORT, ntohs(addr->port))) 1300 goto nla_put_failure; 1301 if (nla_put_u8(skb, MPTCP_PM_ADDR_ATTR_ID, addr->id)) 1302 goto nla_put_failure; 1303 if (nla_put_u32(skb, MPTCP_PM_ADDR_ATTR_FLAGS, entry->addr.flags)) 1304 goto nla_put_failure; 1305 if (entry->addr.ifindex && 1306 nla_put_s32(skb, MPTCP_PM_ADDR_ATTR_IF_IDX, entry->addr.ifindex)) 1307 goto nla_put_failure; 1308 1309 if (addr->family == AF_INET && 1310 nla_put_in_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR4, 1311 addr->addr.s_addr)) 1312 goto nla_put_failure; 1313 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 1314 else if (addr->family == AF_INET6 && 1315 nla_put_in6_addr(skb, MPTCP_PM_ADDR_ATTR_ADDR6, &addr->addr6)) 1316 goto nla_put_failure; 1317 #endif 1318 nla_nest_end(skb, attr); 1319 return 0; 1320 1321 nla_put_failure: 1322 nla_nest_cancel(skb, attr); 1323 return -EMSGSIZE; 1324 } 1325 1326 static int mptcp_nl_cmd_get_addr(struct sk_buff *skb, struct genl_info *info) 1327 { 1328 struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; 1329 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1330 struct mptcp_pm_addr_entry addr, *entry; 1331 struct sk_buff *msg; 1332 void *reply; 1333 int ret; 1334 1335 ret = mptcp_pm_parse_addr(attr, info, false, &addr); 1336 if (ret < 0) 1337 return ret; 1338 1339 msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); 1340 if (!msg) 1341 return -ENOMEM; 1342 1343 reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0, 1344 info->genlhdr->cmd); 1345 if (!reply) { 1346 GENL_SET_ERR_MSG(info, "not enough space in Netlink message"); 1347 ret = -EMSGSIZE; 1348 goto fail; 1349 } 1350 1351 spin_lock_bh(&pernet->lock); 1352 entry = __lookup_addr_by_id(pernet, addr.addr.id); 1353 if (!entry) { 1354 GENL_SET_ERR_MSG(info, "address not found"); 1355 ret = -EINVAL; 1356 goto unlock_fail; 1357 } 1358 1359 ret = mptcp_nl_fill_addr(msg, entry); 1360 if (ret) 1361 goto unlock_fail; 1362 1363 genlmsg_end(msg, reply); 1364 ret = genlmsg_reply(msg, info); 1365 spin_unlock_bh(&pernet->lock); 1366 return ret; 1367 1368 unlock_fail: 1369 spin_unlock_bh(&pernet->lock); 1370 1371 fail: 1372 nlmsg_free(msg); 1373 return ret; 1374 } 1375 1376 static int mptcp_nl_cmd_dump_addrs(struct sk_buff *msg, 1377 struct netlink_callback *cb) 1378 { 1379 struct net *net = sock_net(msg->sk); 1380 struct mptcp_pm_addr_entry *entry; 1381 struct pm_nl_pernet *pernet; 1382 int id = cb->args[0]; 1383 void *hdr; 1384 int i; 1385 1386 pernet = net_generic(net, pm_nl_pernet_id); 1387 1388 spin_lock_bh(&pernet->lock); 1389 for (i = id; i < MAX_ADDR_ID + 1; i++) { 1390 if (test_bit(i, pernet->id_bitmap)) { 1391 entry = __lookup_addr_by_id(pernet, i); 1392 if (!entry) 1393 break; 1394 1395 if (entry->addr.id <= id) 1396 continue; 1397 1398 hdr = genlmsg_put(msg, NETLINK_CB(cb->skb).portid, 1399 cb->nlh->nlmsg_seq, &mptcp_genl_family, 1400 NLM_F_MULTI, MPTCP_PM_CMD_GET_ADDR); 1401 if (!hdr) 1402 break; 1403 1404 if (mptcp_nl_fill_addr(msg, entry) < 0) { 1405 genlmsg_cancel(msg, hdr); 1406 break; 1407 } 1408 1409 id = entry->addr.id; 1410 genlmsg_end(msg, hdr); 1411 } 1412 } 1413 spin_unlock_bh(&pernet->lock); 1414 1415 cb->args[0] = id; 1416 return msg->len; 1417 } 1418 1419 static int parse_limit(struct genl_info *info, int id, unsigned int *limit) 1420 { 1421 struct nlattr *attr = info->attrs[id]; 1422 1423 if (!attr) 1424 return 0; 1425 1426 *limit = nla_get_u32(attr); 1427 if (*limit > MPTCP_PM_ADDR_MAX) { 1428 GENL_SET_ERR_MSG(info, "limit greater than maximum"); 1429 return -EINVAL; 1430 } 1431 return 0; 1432 } 1433 1434 static int 1435 mptcp_nl_cmd_set_limits(struct sk_buff *skb, struct genl_info *info) 1436 { 1437 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1438 unsigned int rcv_addrs, subflows; 1439 int ret; 1440 1441 spin_lock_bh(&pernet->lock); 1442 rcv_addrs = pernet->add_addr_accept_max; 1443 ret = parse_limit(info, MPTCP_PM_ATTR_RCV_ADD_ADDRS, &rcv_addrs); 1444 if (ret) 1445 goto unlock; 1446 1447 subflows = pernet->subflows_max; 1448 ret = parse_limit(info, MPTCP_PM_ATTR_SUBFLOWS, &subflows); 1449 if (ret) 1450 goto unlock; 1451 1452 WRITE_ONCE(pernet->add_addr_accept_max, rcv_addrs); 1453 WRITE_ONCE(pernet->subflows_max, subflows); 1454 1455 unlock: 1456 spin_unlock_bh(&pernet->lock); 1457 return ret; 1458 } 1459 1460 static int 1461 mptcp_nl_cmd_get_limits(struct sk_buff *skb, struct genl_info *info) 1462 { 1463 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1464 struct sk_buff *msg; 1465 void *reply; 1466 1467 msg = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_KERNEL); 1468 if (!msg) 1469 return -ENOMEM; 1470 1471 reply = genlmsg_put_reply(msg, info, &mptcp_genl_family, 0, 1472 MPTCP_PM_CMD_GET_LIMITS); 1473 if (!reply) 1474 goto fail; 1475 1476 if (nla_put_u32(msg, MPTCP_PM_ATTR_RCV_ADD_ADDRS, 1477 READ_ONCE(pernet->add_addr_accept_max))) 1478 goto fail; 1479 1480 if (nla_put_u32(msg, MPTCP_PM_ATTR_SUBFLOWS, 1481 READ_ONCE(pernet->subflows_max))) 1482 goto fail; 1483 1484 genlmsg_end(msg, reply); 1485 return genlmsg_reply(msg, info); 1486 1487 fail: 1488 GENL_SET_ERR_MSG(info, "not enough space in Netlink message"); 1489 nlmsg_free(msg); 1490 return -EMSGSIZE; 1491 } 1492 1493 static int mptcp_nl_addr_backup(struct net *net, 1494 struct mptcp_addr_info *addr, 1495 u8 bkup) 1496 { 1497 long s_slot = 0, s_num = 0; 1498 struct mptcp_sock *msk; 1499 int ret = -EINVAL; 1500 1501 while ((msk = mptcp_token_iter_next(net, &s_slot, &s_num)) != NULL) { 1502 struct sock *sk = (struct sock *)msk; 1503 1504 if (list_empty(&msk->conn_list)) 1505 goto next; 1506 1507 lock_sock(sk); 1508 spin_lock_bh(&msk->pm.lock); 1509 ret = mptcp_pm_nl_mp_prio_send_ack(msk, addr, bkup); 1510 spin_unlock_bh(&msk->pm.lock); 1511 release_sock(sk); 1512 1513 next: 1514 sock_put(sk); 1515 cond_resched(); 1516 } 1517 1518 return ret; 1519 } 1520 1521 static int mptcp_nl_cmd_set_flags(struct sk_buff *skb, struct genl_info *info) 1522 { 1523 struct nlattr *attr = info->attrs[MPTCP_PM_ATTR_ADDR]; 1524 struct pm_nl_pernet *pernet = genl_info_pm_nl(info); 1525 struct mptcp_pm_addr_entry addr, *entry; 1526 struct net *net = sock_net(skb->sk); 1527 u8 bkup = 0; 1528 int ret; 1529 1530 ret = mptcp_pm_parse_addr(attr, info, true, &addr); 1531 if (ret < 0) 1532 return ret; 1533 1534 if (addr.addr.flags & MPTCP_PM_ADDR_FLAG_BACKUP) 1535 bkup = 1; 1536 1537 list_for_each_entry(entry, &pernet->local_addr_list, list) { 1538 if (addresses_equal(&entry->addr, &addr.addr, true)) { 1539 ret = mptcp_nl_addr_backup(net, &entry->addr, bkup); 1540 if (ret) 1541 return ret; 1542 1543 if (bkup) 1544 entry->addr.flags |= MPTCP_PM_ADDR_FLAG_BACKUP; 1545 else 1546 entry->addr.flags &= ~MPTCP_PM_ADDR_FLAG_BACKUP; 1547 } 1548 } 1549 1550 return 0; 1551 } 1552 1553 static void mptcp_nl_mcast_send(struct net *net, struct sk_buff *nlskb, gfp_t gfp) 1554 { 1555 genlmsg_multicast_netns(&mptcp_genl_family, net, 1556 nlskb, 0, MPTCP_PM_EV_GRP_OFFSET, gfp); 1557 } 1558 1559 static int mptcp_event_add_subflow(struct sk_buff *skb, const struct sock *ssk) 1560 { 1561 const struct inet_sock *issk = inet_sk(ssk); 1562 const struct mptcp_subflow_context *sf; 1563 1564 if (nla_put_u16(skb, MPTCP_ATTR_FAMILY, ssk->sk_family)) 1565 return -EMSGSIZE; 1566 1567 switch (ssk->sk_family) { 1568 case AF_INET: 1569 if (nla_put_in_addr(skb, MPTCP_ATTR_SADDR4, issk->inet_saddr)) 1570 return -EMSGSIZE; 1571 if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, issk->inet_daddr)) 1572 return -EMSGSIZE; 1573 break; 1574 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 1575 case AF_INET6: { 1576 const struct ipv6_pinfo *np = inet6_sk(ssk); 1577 1578 if (nla_put_in6_addr(skb, MPTCP_ATTR_SADDR6, &np->saddr)) 1579 return -EMSGSIZE; 1580 if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &ssk->sk_v6_daddr)) 1581 return -EMSGSIZE; 1582 break; 1583 } 1584 #endif 1585 default: 1586 WARN_ON_ONCE(1); 1587 return -EMSGSIZE; 1588 } 1589 1590 if (nla_put_be16(skb, MPTCP_ATTR_SPORT, issk->inet_sport)) 1591 return -EMSGSIZE; 1592 if (nla_put_be16(skb, MPTCP_ATTR_DPORT, issk->inet_dport)) 1593 return -EMSGSIZE; 1594 1595 sf = mptcp_subflow_ctx(ssk); 1596 if (WARN_ON_ONCE(!sf)) 1597 return -EINVAL; 1598 1599 if (nla_put_u8(skb, MPTCP_ATTR_LOC_ID, sf->local_id)) 1600 return -EMSGSIZE; 1601 1602 if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, sf->remote_id)) 1603 return -EMSGSIZE; 1604 1605 return 0; 1606 } 1607 1608 static int mptcp_event_put_token_and_ssk(struct sk_buff *skb, 1609 const struct mptcp_sock *msk, 1610 const struct sock *ssk) 1611 { 1612 const struct sock *sk = (const struct sock *)msk; 1613 const struct mptcp_subflow_context *sf; 1614 u8 sk_err; 1615 1616 if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token)) 1617 return -EMSGSIZE; 1618 1619 if (mptcp_event_add_subflow(skb, ssk)) 1620 return -EMSGSIZE; 1621 1622 sf = mptcp_subflow_ctx(ssk); 1623 if (WARN_ON_ONCE(!sf)) 1624 return -EINVAL; 1625 1626 if (nla_put_u8(skb, MPTCP_ATTR_BACKUP, sf->backup)) 1627 return -EMSGSIZE; 1628 1629 if (ssk->sk_bound_dev_if && 1630 nla_put_s32(skb, MPTCP_ATTR_IF_IDX, ssk->sk_bound_dev_if)) 1631 return -EMSGSIZE; 1632 1633 sk_err = ssk->sk_err; 1634 if (sk_err && sk->sk_state == TCP_ESTABLISHED && 1635 nla_put_u8(skb, MPTCP_ATTR_ERROR, sk_err)) 1636 return -EMSGSIZE; 1637 1638 return 0; 1639 } 1640 1641 static int mptcp_event_sub_established(struct sk_buff *skb, 1642 const struct mptcp_sock *msk, 1643 const struct sock *ssk) 1644 { 1645 return mptcp_event_put_token_and_ssk(skb, msk, ssk); 1646 } 1647 1648 static int mptcp_event_sub_closed(struct sk_buff *skb, 1649 const struct mptcp_sock *msk, 1650 const struct sock *ssk) 1651 { 1652 if (mptcp_event_put_token_and_ssk(skb, msk, ssk)) 1653 return -EMSGSIZE; 1654 1655 return 0; 1656 } 1657 1658 static int mptcp_event_created(struct sk_buff *skb, 1659 const struct mptcp_sock *msk, 1660 const struct sock *ssk) 1661 { 1662 int err = nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token); 1663 1664 if (err) 1665 return err; 1666 1667 return mptcp_event_add_subflow(skb, ssk); 1668 } 1669 1670 void mptcp_event_addr_removed(const struct mptcp_sock *msk, uint8_t id) 1671 { 1672 struct net *net = sock_net((const struct sock *)msk); 1673 struct nlmsghdr *nlh; 1674 struct sk_buff *skb; 1675 1676 if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET)) 1677 return; 1678 1679 skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); 1680 if (!skb) 1681 return; 1682 1683 nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, MPTCP_EVENT_REMOVED); 1684 if (!nlh) 1685 goto nla_put_failure; 1686 1687 if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token)) 1688 goto nla_put_failure; 1689 1690 if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, id)) 1691 goto nla_put_failure; 1692 1693 genlmsg_end(skb, nlh); 1694 mptcp_nl_mcast_send(net, skb, GFP_ATOMIC); 1695 return; 1696 1697 nla_put_failure: 1698 kfree_skb(skb); 1699 } 1700 1701 void mptcp_event_addr_announced(const struct mptcp_sock *msk, 1702 const struct mptcp_addr_info *info) 1703 { 1704 struct net *net = sock_net((const struct sock *)msk); 1705 struct nlmsghdr *nlh; 1706 struct sk_buff *skb; 1707 1708 if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET)) 1709 return; 1710 1711 skb = nlmsg_new(NLMSG_DEFAULT_SIZE, GFP_ATOMIC); 1712 if (!skb) 1713 return; 1714 1715 nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, 1716 MPTCP_EVENT_ANNOUNCED); 1717 if (!nlh) 1718 goto nla_put_failure; 1719 1720 if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token)) 1721 goto nla_put_failure; 1722 1723 if (nla_put_u8(skb, MPTCP_ATTR_REM_ID, info->id)) 1724 goto nla_put_failure; 1725 1726 if (nla_put_be16(skb, MPTCP_ATTR_DPORT, info->port)) 1727 goto nla_put_failure; 1728 1729 switch (info->family) { 1730 case AF_INET: 1731 if (nla_put_in_addr(skb, MPTCP_ATTR_DADDR4, info->addr.s_addr)) 1732 goto nla_put_failure; 1733 break; 1734 #if IS_ENABLED(CONFIG_MPTCP_IPV6) 1735 case AF_INET6: 1736 if (nla_put_in6_addr(skb, MPTCP_ATTR_DADDR6, &info->addr6)) 1737 goto nla_put_failure; 1738 break; 1739 #endif 1740 default: 1741 WARN_ON_ONCE(1); 1742 goto nla_put_failure; 1743 } 1744 1745 genlmsg_end(skb, nlh); 1746 mptcp_nl_mcast_send(net, skb, GFP_ATOMIC); 1747 return; 1748 1749 nla_put_failure: 1750 kfree_skb(skb); 1751 } 1752 1753 void mptcp_event(enum mptcp_event_type type, const struct mptcp_sock *msk, 1754 const struct sock *ssk, gfp_t gfp) 1755 { 1756 struct net *net = sock_net((const struct sock *)msk); 1757 struct nlmsghdr *nlh; 1758 struct sk_buff *skb; 1759 1760 if (!genl_has_listeners(&mptcp_genl_family, net, MPTCP_PM_EV_GRP_OFFSET)) 1761 return; 1762 1763 skb = nlmsg_new(NLMSG_DEFAULT_SIZE, gfp); 1764 if (!skb) 1765 return; 1766 1767 nlh = genlmsg_put(skb, 0, 0, &mptcp_genl_family, 0, type); 1768 if (!nlh) 1769 goto nla_put_failure; 1770 1771 switch (type) { 1772 case MPTCP_EVENT_UNSPEC: 1773 WARN_ON_ONCE(1); 1774 break; 1775 case MPTCP_EVENT_CREATED: 1776 case MPTCP_EVENT_ESTABLISHED: 1777 if (mptcp_event_created(skb, msk, ssk) < 0) 1778 goto nla_put_failure; 1779 break; 1780 case MPTCP_EVENT_CLOSED: 1781 if (nla_put_u32(skb, MPTCP_ATTR_TOKEN, msk->token) < 0) 1782 goto nla_put_failure; 1783 break; 1784 case MPTCP_EVENT_ANNOUNCED: 1785 case MPTCP_EVENT_REMOVED: 1786 /* call mptcp_event_addr_announced()/removed instead */ 1787 WARN_ON_ONCE(1); 1788 break; 1789 case MPTCP_EVENT_SUB_ESTABLISHED: 1790 case MPTCP_EVENT_SUB_PRIORITY: 1791 if (mptcp_event_sub_established(skb, msk, ssk) < 0) 1792 goto nla_put_failure; 1793 break; 1794 case MPTCP_EVENT_SUB_CLOSED: 1795 if (mptcp_event_sub_closed(skb, msk, ssk) < 0) 1796 goto nla_put_failure; 1797 break; 1798 } 1799 1800 genlmsg_end(skb, nlh); 1801 mptcp_nl_mcast_send(net, skb, gfp); 1802 return; 1803 1804 nla_put_failure: 1805 kfree_skb(skb); 1806 } 1807 1808 static const struct genl_small_ops mptcp_pm_ops[] = { 1809 { 1810 .cmd = MPTCP_PM_CMD_ADD_ADDR, 1811 .doit = mptcp_nl_cmd_add_addr, 1812 .flags = GENL_ADMIN_PERM, 1813 }, 1814 { 1815 .cmd = MPTCP_PM_CMD_DEL_ADDR, 1816 .doit = mptcp_nl_cmd_del_addr, 1817 .flags = GENL_ADMIN_PERM, 1818 }, 1819 { 1820 .cmd = MPTCP_PM_CMD_FLUSH_ADDRS, 1821 .doit = mptcp_nl_cmd_flush_addrs, 1822 .flags = GENL_ADMIN_PERM, 1823 }, 1824 { 1825 .cmd = MPTCP_PM_CMD_GET_ADDR, 1826 .doit = mptcp_nl_cmd_get_addr, 1827 .dumpit = mptcp_nl_cmd_dump_addrs, 1828 }, 1829 { 1830 .cmd = MPTCP_PM_CMD_SET_LIMITS, 1831 .doit = mptcp_nl_cmd_set_limits, 1832 .flags = GENL_ADMIN_PERM, 1833 }, 1834 { 1835 .cmd = MPTCP_PM_CMD_GET_LIMITS, 1836 .doit = mptcp_nl_cmd_get_limits, 1837 }, 1838 { 1839 .cmd = MPTCP_PM_CMD_SET_FLAGS, 1840 .doit = mptcp_nl_cmd_set_flags, 1841 .flags = GENL_ADMIN_PERM, 1842 }, 1843 }; 1844 1845 static struct genl_family mptcp_genl_family __ro_after_init = { 1846 .name = MPTCP_PM_NAME, 1847 .version = MPTCP_PM_VER, 1848 .maxattr = MPTCP_PM_ATTR_MAX, 1849 .policy = mptcp_pm_policy, 1850 .netnsok = true, 1851 .module = THIS_MODULE, 1852 .small_ops = mptcp_pm_ops, 1853 .n_small_ops = ARRAY_SIZE(mptcp_pm_ops), 1854 .mcgrps = mptcp_pm_mcgrps, 1855 .n_mcgrps = ARRAY_SIZE(mptcp_pm_mcgrps), 1856 }; 1857 1858 static int __net_init pm_nl_init_net(struct net *net) 1859 { 1860 struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id); 1861 1862 INIT_LIST_HEAD_RCU(&pernet->local_addr_list); 1863 __reset_counters(pernet); 1864 pernet->next_id = 1; 1865 bitmap_zero(pernet->id_bitmap, MAX_ADDR_ID + 1); 1866 spin_lock_init(&pernet->lock); 1867 return 0; 1868 } 1869 1870 static void __net_exit pm_nl_exit_net(struct list_head *net_list) 1871 { 1872 struct net *net; 1873 1874 list_for_each_entry(net, net_list, exit_list) { 1875 struct pm_nl_pernet *pernet = net_generic(net, pm_nl_pernet_id); 1876 1877 /* net is removed from namespace list, can't race with 1878 * other modifiers 1879 */ 1880 __flush_addrs(&pernet->local_addr_list); 1881 } 1882 } 1883 1884 static struct pernet_operations mptcp_pm_pernet_ops = { 1885 .init = pm_nl_init_net, 1886 .exit_batch = pm_nl_exit_net, 1887 .id = &pm_nl_pernet_id, 1888 .size = sizeof(struct pm_nl_pernet), 1889 }; 1890 1891 void __init mptcp_pm_nl_init(void) 1892 { 1893 if (register_pernet_subsys(&mptcp_pm_pernet_ops) < 0) 1894 panic("Failed to register MPTCP PM pernet subsystem.\n"); 1895 1896 if (genl_register_family(&mptcp_genl_family)) 1897 panic("Failed to register MPTCP PM netlink family\n"); 1898 } 1899