1 // SPDX-License-Identifier: GPL-2.0-or-later 2 /* 3 * INET An implementation of the TCP/IP protocol suite for the LINUX 4 * operating system. INET is implemented using the BSD Socket 5 * interface as the means of communication with the user level. 6 * 7 * Generic INET transport hashtables 8 * 9 * Authors: Lotsa people, from code originally in tcp 10 */ 11 12 #include <linux/module.h> 13 #include <linux/random.h> 14 #include <linux/sched.h> 15 #include <linux/slab.h> 16 #include <linux/wait.h> 17 #include <linux/vmalloc.h> 18 #include <linux/memblock.h> 19 20 #include <net/addrconf.h> 21 #include <net/inet_connection_sock.h> 22 #include <net/inet_hashtables.h> 23 #if IS_ENABLED(CONFIG_IPV6) 24 #include <net/inet6_hashtables.h> 25 #endif 26 #include <net/secure_seq.h> 27 #include <net/ip.h> 28 #include <net/tcp.h> 29 #include <net/sock_reuseport.h> 30 31 static u32 inet_ehashfn(const struct net *net, const __be32 laddr, 32 const __u16 lport, const __be32 faddr, 33 const __be16 fport) 34 { 35 static u32 inet_ehash_secret __read_mostly; 36 37 net_get_random_once(&inet_ehash_secret, sizeof(inet_ehash_secret)); 38 39 return __inet_ehashfn(laddr, lport, faddr, fport, 40 inet_ehash_secret + net_hash_mix(net)); 41 } 42 43 /* This function handles inet_sock, but also timewait and request sockets 44 * for IPv4/IPv6. 45 */ 46 static u32 sk_ehashfn(const struct sock *sk) 47 { 48 #if IS_ENABLED(CONFIG_IPV6) 49 if (sk->sk_family == AF_INET6 && 50 !ipv6_addr_v4mapped(&sk->sk_v6_daddr)) 51 return inet6_ehashfn(sock_net(sk), 52 &sk->sk_v6_rcv_saddr, sk->sk_num, 53 &sk->sk_v6_daddr, sk->sk_dport); 54 #endif 55 return inet_ehashfn(sock_net(sk), 56 sk->sk_rcv_saddr, sk->sk_num, 57 sk->sk_daddr, sk->sk_dport); 58 } 59 60 /* 61 * Allocate and initialize a new local port bind bucket. 62 * The bindhash mutex for snum's hash chain must be held here. 63 */ 64 struct inet_bind_bucket *inet_bind_bucket_create(struct kmem_cache *cachep, 65 struct net *net, 66 struct inet_bind_hashbucket *head, 67 const unsigned short snum, 68 int l3mdev) 69 { 70 struct inet_bind_bucket *tb = kmem_cache_alloc(cachep, GFP_ATOMIC); 71 72 if (tb) { 73 write_pnet(&tb->ib_net, net); 74 tb->l3mdev = l3mdev; 75 tb->port = snum; 76 tb->fastreuse = 0; 77 tb->fastreuseport = 0; 78 INIT_HLIST_HEAD(&tb->owners); 79 hlist_add_head(&tb->node, &head->chain); 80 } 81 return tb; 82 } 83 84 /* 85 * Caller must hold hashbucket lock for this tb with local BH disabled 86 */ 87 void inet_bind_bucket_destroy(struct kmem_cache *cachep, struct inet_bind_bucket *tb) 88 { 89 if (hlist_empty(&tb->owners)) { 90 __hlist_del(&tb->node); 91 kmem_cache_free(cachep, tb); 92 } 93 } 94 95 bool inet_bind_bucket_match(const struct inet_bind_bucket *tb, const struct net *net, 96 unsigned short port, int l3mdev) 97 { 98 return net_eq(ib_net(tb), net) && tb->port == port && 99 tb->l3mdev == l3mdev; 100 } 101 102 static void inet_bind2_bucket_init(struct inet_bind2_bucket *tb, 103 struct net *net, 104 struct inet_bind_hashbucket *head, 105 unsigned short port, int l3mdev, 106 const struct sock *sk) 107 { 108 write_pnet(&tb->ib_net, net); 109 tb->l3mdev = l3mdev; 110 tb->port = port; 111 #if IS_ENABLED(CONFIG_IPV6) 112 tb->family = sk->sk_family; 113 if (sk->sk_family == AF_INET6) 114 tb->v6_rcv_saddr = sk->sk_v6_rcv_saddr; 115 else 116 #endif 117 tb->rcv_saddr = sk->sk_rcv_saddr; 118 INIT_HLIST_HEAD(&tb->owners); 119 INIT_HLIST_HEAD(&tb->deathrow); 120 hlist_add_head(&tb->node, &head->chain); 121 } 122 123 struct inet_bind2_bucket *inet_bind2_bucket_create(struct kmem_cache *cachep, 124 struct net *net, 125 struct inet_bind_hashbucket *head, 126 unsigned short port, 127 int l3mdev, 128 const struct sock *sk) 129 { 130 struct inet_bind2_bucket *tb = kmem_cache_alloc(cachep, GFP_ATOMIC); 131 132 if (tb) 133 inet_bind2_bucket_init(tb, net, head, port, l3mdev, sk); 134 135 return tb; 136 } 137 138 /* Caller must hold hashbucket lock for this tb with local BH disabled */ 139 void inet_bind2_bucket_destroy(struct kmem_cache *cachep, struct inet_bind2_bucket *tb) 140 { 141 if (hlist_empty(&tb->owners) && hlist_empty(&tb->deathrow)) { 142 __hlist_del(&tb->node); 143 kmem_cache_free(cachep, tb); 144 } 145 } 146 147 static bool inet_bind2_bucket_addr_match(const struct inet_bind2_bucket *tb2, 148 const struct sock *sk) 149 { 150 #if IS_ENABLED(CONFIG_IPV6) 151 if (sk->sk_family != tb2->family) 152 return false; 153 154 if (sk->sk_family == AF_INET6) 155 return ipv6_addr_equal(&tb2->v6_rcv_saddr, 156 &sk->sk_v6_rcv_saddr); 157 #endif 158 return tb2->rcv_saddr == sk->sk_rcv_saddr; 159 } 160 161 void inet_bind_hash(struct sock *sk, struct inet_bind_bucket *tb, 162 struct inet_bind2_bucket *tb2, unsigned short port) 163 { 164 inet_sk(sk)->inet_num = port; 165 sk_add_bind_node(sk, &tb->owners); 166 inet_csk(sk)->icsk_bind_hash = tb; 167 sk_add_bind2_node(sk, &tb2->owners); 168 inet_csk(sk)->icsk_bind2_hash = tb2; 169 } 170 171 /* 172 * Get rid of any references to a local port held by the given sock. 173 */ 174 static void __inet_put_port(struct sock *sk) 175 { 176 struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); 177 struct inet_bind_hashbucket *head, *head2; 178 struct net *net = sock_net(sk); 179 struct inet_bind_bucket *tb; 180 int bhash; 181 182 bhash = inet_bhashfn(net, inet_sk(sk)->inet_num, hashinfo->bhash_size); 183 head = &hashinfo->bhash[bhash]; 184 head2 = inet_bhashfn_portaddr(hashinfo, sk, net, inet_sk(sk)->inet_num); 185 186 spin_lock(&head->lock); 187 tb = inet_csk(sk)->icsk_bind_hash; 188 __sk_del_bind_node(sk); 189 inet_csk(sk)->icsk_bind_hash = NULL; 190 inet_sk(sk)->inet_num = 0; 191 inet_bind_bucket_destroy(hashinfo->bind_bucket_cachep, tb); 192 193 spin_lock(&head2->lock); 194 if (inet_csk(sk)->icsk_bind2_hash) { 195 struct inet_bind2_bucket *tb2 = inet_csk(sk)->icsk_bind2_hash; 196 197 __sk_del_bind2_node(sk); 198 inet_csk(sk)->icsk_bind2_hash = NULL; 199 inet_bind2_bucket_destroy(hashinfo->bind2_bucket_cachep, tb2); 200 } 201 spin_unlock(&head2->lock); 202 203 spin_unlock(&head->lock); 204 } 205 206 void inet_put_port(struct sock *sk) 207 { 208 local_bh_disable(); 209 __inet_put_port(sk); 210 local_bh_enable(); 211 } 212 EXPORT_SYMBOL(inet_put_port); 213 214 int __inet_inherit_port(const struct sock *sk, struct sock *child) 215 { 216 struct inet_hashinfo *table = tcp_or_dccp_get_hashinfo(sk); 217 unsigned short port = inet_sk(child)->inet_num; 218 struct inet_bind_hashbucket *head, *head2; 219 bool created_inet_bind_bucket = false; 220 struct net *net = sock_net(sk); 221 bool update_fastreuse = false; 222 struct inet_bind2_bucket *tb2; 223 struct inet_bind_bucket *tb; 224 int bhash, l3mdev; 225 226 bhash = inet_bhashfn(net, port, table->bhash_size); 227 head = &table->bhash[bhash]; 228 head2 = inet_bhashfn_portaddr(table, child, net, port); 229 230 spin_lock(&head->lock); 231 spin_lock(&head2->lock); 232 tb = inet_csk(sk)->icsk_bind_hash; 233 tb2 = inet_csk(sk)->icsk_bind2_hash; 234 if (unlikely(!tb || !tb2)) { 235 spin_unlock(&head2->lock); 236 spin_unlock(&head->lock); 237 return -ENOENT; 238 } 239 if (tb->port != port) { 240 l3mdev = inet_sk_bound_l3mdev(sk); 241 242 /* NOTE: using tproxy and redirecting skbs to a proxy 243 * on a different listener port breaks the assumption 244 * that the listener socket's icsk_bind_hash is the same 245 * as that of the child socket. We have to look up or 246 * create a new bind bucket for the child here. */ 247 inet_bind_bucket_for_each(tb, &head->chain) { 248 if (inet_bind_bucket_match(tb, net, port, l3mdev)) 249 break; 250 } 251 if (!tb) { 252 tb = inet_bind_bucket_create(table->bind_bucket_cachep, 253 net, head, port, l3mdev); 254 if (!tb) { 255 spin_unlock(&head2->lock); 256 spin_unlock(&head->lock); 257 return -ENOMEM; 258 } 259 created_inet_bind_bucket = true; 260 } 261 update_fastreuse = true; 262 263 goto bhash2_find; 264 } else if (!inet_bind2_bucket_addr_match(tb2, child)) { 265 l3mdev = inet_sk_bound_l3mdev(sk); 266 267 bhash2_find: 268 tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, child); 269 if (!tb2) { 270 tb2 = inet_bind2_bucket_create(table->bind2_bucket_cachep, 271 net, head2, port, 272 l3mdev, child); 273 if (!tb2) 274 goto error; 275 } 276 } 277 if (update_fastreuse) 278 inet_csk_update_fastreuse(tb, child); 279 inet_bind_hash(child, tb, tb2, port); 280 spin_unlock(&head2->lock); 281 spin_unlock(&head->lock); 282 283 return 0; 284 285 error: 286 if (created_inet_bind_bucket) 287 inet_bind_bucket_destroy(table->bind_bucket_cachep, tb); 288 spin_unlock(&head2->lock); 289 spin_unlock(&head->lock); 290 return -ENOMEM; 291 } 292 EXPORT_SYMBOL_GPL(__inet_inherit_port); 293 294 static struct inet_listen_hashbucket * 295 inet_lhash2_bucket_sk(struct inet_hashinfo *h, struct sock *sk) 296 { 297 u32 hash; 298 299 #if IS_ENABLED(CONFIG_IPV6) 300 if (sk->sk_family == AF_INET6) 301 hash = ipv6_portaddr_hash(sock_net(sk), 302 &sk->sk_v6_rcv_saddr, 303 inet_sk(sk)->inet_num); 304 else 305 #endif 306 hash = ipv4_portaddr_hash(sock_net(sk), 307 inet_sk(sk)->inet_rcv_saddr, 308 inet_sk(sk)->inet_num); 309 return inet_lhash2_bucket(h, hash); 310 } 311 312 static inline int compute_score(struct sock *sk, struct net *net, 313 const unsigned short hnum, const __be32 daddr, 314 const int dif, const int sdif) 315 { 316 int score = -1; 317 318 if (net_eq(sock_net(sk), net) && sk->sk_num == hnum && 319 !ipv6_only_sock(sk)) { 320 if (sk->sk_rcv_saddr != daddr) 321 return -1; 322 323 if (!inet_sk_bound_dev_eq(net, sk->sk_bound_dev_if, dif, sdif)) 324 return -1; 325 score = sk->sk_bound_dev_if ? 2 : 1; 326 327 if (sk->sk_family == PF_INET) 328 score++; 329 if (READ_ONCE(sk->sk_incoming_cpu) == raw_smp_processor_id()) 330 score++; 331 } 332 return score; 333 } 334 335 static inline struct sock *lookup_reuseport(struct net *net, struct sock *sk, 336 struct sk_buff *skb, int doff, 337 __be32 saddr, __be16 sport, 338 __be32 daddr, unsigned short hnum) 339 { 340 struct sock *reuse_sk = NULL; 341 u32 phash; 342 343 if (sk->sk_reuseport) { 344 phash = inet_ehashfn(net, daddr, hnum, saddr, sport); 345 reuse_sk = reuseport_select_sock(sk, phash, skb, doff); 346 } 347 return reuse_sk; 348 } 349 350 /* 351 * Here are some nice properties to exploit here. The BSD API 352 * does not allow a listening sock to specify the remote port nor the 353 * remote address for the connection. So always assume those are both 354 * wildcarded during the search since they can never be otherwise. 355 */ 356 357 /* called with rcu_read_lock() : No refcount taken on the socket */ 358 static struct sock *inet_lhash2_lookup(struct net *net, 359 struct inet_listen_hashbucket *ilb2, 360 struct sk_buff *skb, int doff, 361 const __be32 saddr, __be16 sport, 362 const __be32 daddr, const unsigned short hnum, 363 const int dif, const int sdif) 364 { 365 struct sock *sk, *result = NULL; 366 struct hlist_nulls_node *node; 367 int score, hiscore = 0; 368 369 sk_nulls_for_each_rcu(sk, node, &ilb2->nulls_head) { 370 score = compute_score(sk, net, hnum, daddr, dif, sdif); 371 if (score > hiscore) { 372 result = lookup_reuseport(net, sk, skb, doff, 373 saddr, sport, daddr, hnum); 374 if (result) 375 return result; 376 377 result = sk; 378 hiscore = score; 379 } 380 } 381 382 return result; 383 } 384 385 static inline struct sock *inet_lookup_run_bpf(struct net *net, 386 struct inet_hashinfo *hashinfo, 387 struct sk_buff *skb, int doff, 388 __be32 saddr, __be16 sport, 389 __be32 daddr, u16 hnum, const int dif) 390 { 391 struct sock *sk, *reuse_sk; 392 bool no_reuseport; 393 394 if (hashinfo != net->ipv4.tcp_death_row.hashinfo) 395 return NULL; /* only TCP is supported */ 396 397 no_reuseport = bpf_sk_lookup_run_v4(net, IPPROTO_TCP, saddr, sport, 398 daddr, hnum, dif, &sk); 399 if (no_reuseport || IS_ERR_OR_NULL(sk)) 400 return sk; 401 402 reuse_sk = lookup_reuseport(net, sk, skb, doff, saddr, sport, daddr, hnum); 403 if (reuse_sk) 404 sk = reuse_sk; 405 return sk; 406 } 407 408 struct sock *__inet_lookup_listener(struct net *net, 409 struct inet_hashinfo *hashinfo, 410 struct sk_buff *skb, int doff, 411 const __be32 saddr, __be16 sport, 412 const __be32 daddr, const unsigned short hnum, 413 const int dif, const int sdif) 414 { 415 struct inet_listen_hashbucket *ilb2; 416 struct sock *result = NULL; 417 unsigned int hash2; 418 419 /* Lookup redirect from BPF */ 420 if (static_branch_unlikely(&bpf_sk_lookup_enabled)) { 421 result = inet_lookup_run_bpf(net, hashinfo, skb, doff, 422 saddr, sport, daddr, hnum, dif); 423 if (result) 424 goto done; 425 } 426 427 hash2 = ipv4_portaddr_hash(net, daddr, hnum); 428 ilb2 = inet_lhash2_bucket(hashinfo, hash2); 429 430 result = inet_lhash2_lookup(net, ilb2, skb, doff, 431 saddr, sport, daddr, hnum, 432 dif, sdif); 433 if (result) 434 goto done; 435 436 /* Lookup lhash2 with INADDR_ANY */ 437 hash2 = ipv4_portaddr_hash(net, htonl(INADDR_ANY), hnum); 438 ilb2 = inet_lhash2_bucket(hashinfo, hash2); 439 440 result = inet_lhash2_lookup(net, ilb2, skb, doff, 441 saddr, sport, htonl(INADDR_ANY), hnum, 442 dif, sdif); 443 done: 444 if (IS_ERR(result)) 445 return NULL; 446 return result; 447 } 448 EXPORT_SYMBOL_GPL(__inet_lookup_listener); 449 450 /* All sockets share common refcount, but have different destructors */ 451 void sock_gen_put(struct sock *sk) 452 { 453 if (!refcount_dec_and_test(&sk->sk_refcnt)) 454 return; 455 456 if (sk->sk_state == TCP_TIME_WAIT) 457 inet_twsk_free(inet_twsk(sk)); 458 else if (sk->sk_state == TCP_NEW_SYN_RECV) 459 reqsk_free(inet_reqsk(sk)); 460 else 461 sk_free(sk); 462 } 463 EXPORT_SYMBOL_GPL(sock_gen_put); 464 465 void sock_edemux(struct sk_buff *skb) 466 { 467 sock_gen_put(skb->sk); 468 } 469 EXPORT_SYMBOL(sock_edemux); 470 471 struct sock *__inet_lookup_established(struct net *net, 472 struct inet_hashinfo *hashinfo, 473 const __be32 saddr, const __be16 sport, 474 const __be32 daddr, const u16 hnum, 475 const int dif, const int sdif) 476 { 477 INET_ADDR_COOKIE(acookie, saddr, daddr); 478 const __portpair ports = INET_COMBINED_PORTS(sport, hnum); 479 struct sock *sk; 480 const struct hlist_nulls_node *node; 481 /* Optimize here for direct hit, only listening connections can 482 * have wildcards anyways. 483 */ 484 unsigned int hash = inet_ehashfn(net, daddr, hnum, saddr, sport); 485 unsigned int slot = hash & hashinfo->ehash_mask; 486 struct inet_ehash_bucket *head = &hashinfo->ehash[slot]; 487 488 begin: 489 sk_nulls_for_each_rcu(sk, node, &head->chain) { 490 if (sk->sk_hash != hash) 491 continue; 492 if (likely(inet_match(net, sk, acookie, ports, dif, sdif))) { 493 if (unlikely(!refcount_inc_not_zero(&sk->sk_refcnt))) 494 goto out; 495 if (unlikely(!inet_match(net, sk, acookie, 496 ports, dif, sdif))) { 497 sock_gen_put(sk); 498 goto begin; 499 } 500 goto found; 501 } 502 } 503 /* 504 * if the nulls value we got at the end of this lookup is 505 * not the expected one, we must restart lookup. 506 * We probably met an item that was moved to another chain. 507 */ 508 if (get_nulls_value(node) != slot) 509 goto begin; 510 out: 511 sk = NULL; 512 found: 513 return sk; 514 } 515 EXPORT_SYMBOL_GPL(__inet_lookup_established); 516 517 /* called with local bh disabled */ 518 static int __inet_check_established(struct inet_timewait_death_row *death_row, 519 struct sock *sk, __u16 lport, 520 struct inet_timewait_sock **twp) 521 { 522 struct inet_hashinfo *hinfo = death_row->hashinfo; 523 struct inet_sock *inet = inet_sk(sk); 524 __be32 daddr = inet->inet_rcv_saddr; 525 __be32 saddr = inet->inet_daddr; 526 int dif = sk->sk_bound_dev_if; 527 struct net *net = sock_net(sk); 528 int sdif = l3mdev_master_ifindex_by_index(net, dif); 529 INET_ADDR_COOKIE(acookie, saddr, daddr); 530 const __portpair ports = INET_COMBINED_PORTS(inet->inet_dport, lport); 531 unsigned int hash = inet_ehashfn(net, daddr, lport, 532 saddr, inet->inet_dport); 533 struct inet_ehash_bucket *head = inet_ehash_bucket(hinfo, hash); 534 spinlock_t *lock = inet_ehash_lockp(hinfo, hash); 535 struct sock *sk2; 536 const struct hlist_nulls_node *node; 537 struct inet_timewait_sock *tw = NULL; 538 539 spin_lock(lock); 540 541 sk_nulls_for_each(sk2, node, &head->chain) { 542 if (sk2->sk_hash != hash) 543 continue; 544 545 if (likely(inet_match(net, sk2, acookie, ports, dif, sdif))) { 546 if (sk2->sk_state == TCP_TIME_WAIT) { 547 tw = inet_twsk(sk2); 548 if (twsk_unique(sk, sk2, twp)) 549 break; 550 } 551 goto not_unique; 552 } 553 } 554 555 /* Must record num and sport now. Otherwise we will see 556 * in hash table socket with a funny identity. 557 */ 558 inet->inet_num = lport; 559 inet->inet_sport = htons(lport); 560 sk->sk_hash = hash; 561 WARN_ON(!sk_unhashed(sk)); 562 __sk_nulls_add_node_rcu(sk, &head->chain); 563 if (tw) { 564 sk_nulls_del_node_init_rcu((struct sock *)tw); 565 __NET_INC_STATS(net, LINUX_MIB_TIMEWAITRECYCLED); 566 } 567 spin_unlock(lock); 568 sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); 569 570 if (twp) { 571 *twp = tw; 572 } else if (tw) { 573 /* Silly. Should hash-dance instead... */ 574 inet_twsk_deschedule_put(tw); 575 } 576 return 0; 577 578 not_unique: 579 spin_unlock(lock); 580 return -EADDRNOTAVAIL; 581 } 582 583 static u64 inet_sk_port_offset(const struct sock *sk) 584 { 585 const struct inet_sock *inet = inet_sk(sk); 586 587 return secure_ipv4_port_ephemeral(inet->inet_rcv_saddr, 588 inet->inet_daddr, 589 inet->inet_dport); 590 } 591 592 /* Searches for an exsiting socket in the ehash bucket list. 593 * Returns true if found, false otherwise. 594 */ 595 static bool inet_ehash_lookup_by_sk(struct sock *sk, 596 struct hlist_nulls_head *list) 597 { 598 const __portpair ports = INET_COMBINED_PORTS(sk->sk_dport, sk->sk_num); 599 const int sdif = sk->sk_bound_dev_if; 600 const int dif = sk->sk_bound_dev_if; 601 const struct hlist_nulls_node *node; 602 struct net *net = sock_net(sk); 603 struct sock *esk; 604 605 INET_ADDR_COOKIE(acookie, sk->sk_daddr, sk->sk_rcv_saddr); 606 607 sk_nulls_for_each_rcu(esk, node, list) { 608 if (esk->sk_hash != sk->sk_hash) 609 continue; 610 if (sk->sk_family == AF_INET) { 611 if (unlikely(inet_match(net, esk, acookie, 612 ports, dif, sdif))) { 613 return true; 614 } 615 } 616 #if IS_ENABLED(CONFIG_IPV6) 617 else if (sk->sk_family == AF_INET6) { 618 if (unlikely(inet6_match(net, esk, 619 &sk->sk_v6_daddr, 620 &sk->sk_v6_rcv_saddr, 621 ports, dif, sdif))) { 622 return true; 623 } 624 } 625 #endif 626 } 627 return false; 628 } 629 630 /* Insert a socket into ehash, and eventually remove another one 631 * (The another one can be a SYN_RECV or TIMEWAIT) 632 * If an existing socket already exists, socket sk is not inserted, 633 * and sets found_dup_sk parameter to true. 634 */ 635 bool inet_ehash_insert(struct sock *sk, struct sock *osk, bool *found_dup_sk) 636 { 637 struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); 638 struct inet_ehash_bucket *head; 639 struct hlist_nulls_head *list; 640 spinlock_t *lock; 641 bool ret = true; 642 643 WARN_ON_ONCE(!sk_unhashed(sk)); 644 645 sk->sk_hash = sk_ehashfn(sk); 646 head = inet_ehash_bucket(hashinfo, sk->sk_hash); 647 list = &head->chain; 648 lock = inet_ehash_lockp(hashinfo, sk->sk_hash); 649 650 spin_lock(lock); 651 if (osk) { 652 WARN_ON_ONCE(sk->sk_hash != osk->sk_hash); 653 ret = sk_hashed(osk); 654 if (ret) { 655 /* Before deleting the node, we insert a new one to make 656 * sure that the look-up-sk process would not miss either 657 * of them and that at least one node would exist in ehash 658 * table all the time. Otherwise there's a tiny chance 659 * that lookup process could find nothing in ehash table. 660 */ 661 __sk_nulls_add_node_tail_rcu(sk, list); 662 sk_nulls_del_node_init_rcu(osk); 663 } 664 goto unlock; 665 } 666 if (found_dup_sk) { 667 *found_dup_sk = inet_ehash_lookup_by_sk(sk, list); 668 if (*found_dup_sk) 669 ret = false; 670 } 671 672 if (ret) 673 __sk_nulls_add_node_rcu(sk, list); 674 675 unlock: 676 spin_unlock(lock); 677 678 return ret; 679 } 680 681 bool inet_ehash_nolisten(struct sock *sk, struct sock *osk, bool *found_dup_sk) 682 { 683 bool ok = inet_ehash_insert(sk, osk, found_dup_sk); 684 685 if (ok) { 686 sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); 687 } else { 688 this_cpu_inc(*sk->sk_prot->orphan_count); 689 inet_sk_set_state(sk, TCP_CLOSE); 690 sock_set_flag(sk, SOCK_DEAD); 691 inet_csk_destroy_sock(sk); 692 } 693 return ok; 694 } 695 EXPORT_SYMBOL_GPL(inet_ehash_nolisten); 696 697 static int inet_reuseport_add_sock(struct sock *sk, 698 struct inet_listen_hashbucket *ilb) 699 { 700 struct inet_bind_bucket *tb = inet_csk(sk)->icsk_bind_hash; 701 const struct hlist_nulls_node *node; 702 struct sock *sk2; 703 kuid_t uid = sock_i_uid(sk); 704 705 sk_nulls_for_each_rcu(sk2, node, &ilb->nulls_head) { 706 if (sk2 != sk && 707 sk2->sk_family == sk->sk_family && 708 ipv6_only_sock(sk2) == ipv6_only_sock(sk) && 709 sk2->sk_bound_dev_if == sk->sk_bound_dev_if && 710 inet_csk(sk2)->icsk_bind_hash == tb && 711 sk2->sk_reuseport && uid_eq(uid, sock_i_uid(sk2)) && 712 inet_rcv_saddr_equal(sk, sk2, false)) 713 return reuseport_add_sock(sk, sk2, 714 inet_rcv_saddr_any(sk)); 715 } 716 717 return reuseport_alloc(sk, inet_rcv_saddr_any(sk)); 718 } 719 720 int __inet_hash(struct sock *sk, struct sock *osk) 721 { 722 struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); 723 struct inet_listen_hashbucket *ilb2; 724 int err = 0; 725 726 if (sk->sk_state != TCP_LISTEN) { 727 local_bh_disable(); 728 inet_ehash_nolisten(sk, osk, NULL); 729 local_bh_enable(); 730 return 0; 731 } 732 WARN_ON(!sk_unhashed(sk)); 733 ilb2 = inet_lhash2_bucket_sk(hashinfo, sk); 734 735 spin_lock(&ilb2->lock); 736 if (sk->sk_reuseport) { 737 err = inet_reuseport_add_sock(sk, ilb2); 738 if (err) 739 goto unlock; 740 } 741 if (IS_ENABLED(CONFIG_IPV6) && sk->sk_reuseport && 742 sk->sk_family == AF_INET6) 743 __sk_nulls_add_node_tail_rcu(sk, &ilb2->nulls_head); 744 else 745 __sk_nulls_add_node_rcu(sk, &ilb2->nulls_head); 746 sock_set_flag(sk, SOCK_RCU_FREE); 747 sock_prot_inuse_add(sock_net(sk), sk->sk_prot, 1); 748 unlock: 749 spin_unlock(&ilb2->lock); 750 751 return err; 752 } 753 EXPORT_SYMBOL(__inet_hash); 754 755 int inet_hash(struct sock *sk) 756 { 757 int err = 0; 758 759 if (sk->sk_state != TCP_CLOSE) 760 err = __inet_hash(sk, NULL); 761 762 return err; 763 } 764 EXPORT_SYMBOL_GPL(inet_hash); 765 766 void inet_unhash(struct sock *sk) 767 { 768 struct inet_hashinfo *hashinfo = tcp_or_dccp_get_hashinfo(sk); 769 770 if (sk_unhashed(sk)) 771 return; 772 773 if (sk->sk_state == TCP_LISTEN) { 774 struct inet_listen_hashbucket *ilb2; 775 776 ilb2 = inet_lhash2_bucket_sk(hashinfo, sk); 777 /* Don't disable bottom halves while acquiring the lock to 778 * avoid circular locking dependency on PREEMPT_RT. 779 */ 780 spin_lock(&ilb2->lock); 781 if (sk_unhashed(sk)) { 782 spin_unlock(&ilb2->lock); 783 return; 784 } 785 786 if (rcu_access_pointer(sk->sk_reuseport_cb)) 787 reuseport_stop_listen_sock(sk); 788 789 __sk_nulls_del_node_init_rcu(sk); 790 sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); 791 spin_unlock(&ilb2->lock); 792 } else { 793 spinlock_t *lock = inet_ehash_lockp(hashinfo, sk->sk_hash); 794 795 spin_lock_bh(lock); 796 if (sk_unhashed(sk)) { 797 spin_unlock_bh(lock); 798 return; 799 } 800 __sk_nulls_del_node_init_rcu(sk); 801 sock_prot_inuse_add(sock_net(sk), sk->sk_prot, -1); 802 spin_unlock_bh(lock); 803 } 804 } 805 EXPORT_SYMBOL_GPL(inet_unhash); 806 807 static bool inet_bind2_bucket_match(const struct inet_bind2_bucket *tb, 808 const struct net *net, unsigned short port, 809 int l3mdev, const struct sock *sk) 810 { 811 #if IS_ENABLED(CONFIG_IPV6) 812 if (sk->sk_family != tb->family) 813 return false; 814 815 if (sk->sk_family == AF_INET6) 816 return net_eq(ib2_net(tb), net) && tb->port == port && 817 tb->l3mdev == l3mdev && 818 ipv6_addr_equal(&tb->v6_rcv_saddr, &sk->sk_v6_rcv_saddr); 819 else 820 #endif 821 return net_eq(ib2_net(tb), net) && tb->port == port && 822 tb->l3mdev == l3mdev && tb->rcv_saddr == sk->sk_rcv_saddr; 823 } 824 825 bool inet_bind2_bucket_match_addr_any(const struct inet_bind2_bucket *tb, const struct net *net, 826 unsigned short port, int l3mdev, const struct sock *sk) 827 { 828 #if IS_ENABLED(CONFIG_IPV6) 829 struct in6_addr addr_any = {}; 830 831 if (sk->sk_family != tb->family) { 832 if (sk->sk_family == AF_INET) 833 return net_eq(ib2_net(tb), net) && tb->port == port && 834 tb->l3mdev == l3mdev && 835 ipv6_addr_equal(&tb->v6_rcv_saddr, &addr_any); 836 837 return false; 838 } 839 840 if (sk->sk_family == AF_INET6) 841 return net_eq(ib2_net(tb), net) && tb->port == port && 842 tb->l3mdev == l3mdev && 843 ipv6_addr_equal(&tb->v6_rcv_saddr, &addr_any); 844 else 845 #endif 846 return net_eq(ib2_net(tb), net) && tb->port == port && 847 tb->l3mdev == l3mdev && tb->rcv_saddr == 0; 848 } 849 850 /* The socket's bhash2 hashbucket spinlock must be held when this is called */ 851 struct inet_bind2_bucket * 852 inet_bind2_bucket_find(const struct inet_bind_hashbucket *head, const struct net *net, 853 unsigned short port, int l3mdev, const struct sock *sk) 854 { 855 struct inet_bind2_bucket *bhash2 = NULL; 856 857 inet_bind_bucket_for_each(bhash2, &head->chain) 858 if (inet_bind2_bucket_match(bhash2, net, port, l3mdev, sk)) 859 break; 860 861 return bhash2; 862 } 863 864 struct inet_bind_hashbucket * 865 inet_bhash2_addr_any_hashbucket(const struct sock *sk, const struct net *net, int port) 866 { 867 struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk); 868 u32 hash; 869 #if IS_ENABLED(CONFIG_IPV6) 870 struct in6_addr addr_any = {}; 871 872 if (sk->sk_family == AF_INET6) 873 hash = ipv6_portaddr_hash(net, &addr_any, port); 874 else 875 #endif 876 hash = ipv4_portaddr_hash(net, 0, port); 877 878 return &hinfo->bhash2[hash & (hinfo->bhash_size - 1)]; 879 } 880 881 static void inet_update_saddr(struct sock *sk, void *saddr, int family) 882 { 883 if (family == AF_INET) { 884 inet_sk(sk)->inet_saddr = *(__be32 *)saddr; 885 sk_rcv_saddr_set(sk, inet_sk(sk)->inet_saddr); 886 } 887 #if IS_ENABLED(CONFIG_IPV6) 888 else { 889 sk->sk_v6_rcv_saddr = *(struct in6_addr *)saddr; 890 } 891 #endif 892 } 893 894 static int __inet_bhash2_update_saddr(struct sock *sk, void *saddr, int family, bool reset) 895 { 896 struct inet_hashinfo *hinfo = tcp_or_dccp_get_hashinfo(sk); 897 struct inet_bind_hashbucket *head, *head2; 898 struct inet_bind2_bucket *tb2, *new_tb2; 899 int l3mdev = inet_sk_bound_l3mdev(sk); 900 int port = inet_sk(sk)->inet_num; 901 struct net *net = sock_net(sk); 902 int bhash; 903 904 if (!inet_csk(sk)->icsk_bind2_hash) { 905 /* Not bind()ed before. */ 906 if (reset) 907 inet_reset_saddr(sk); 908 else 909 inet_update_saddr(sk, saddr, family); 910 911 return 0; 912 } 913 914 /* Allocate a bind2 bucket ahead of time to avoid permanently putting 915 * the bhash2 table in an inconsistent state if a new tb2 bucket 916 * allocation fails. 917 */ 918 new_tb2 = kmem_cache_alloc(hinfo->bind2_bucket_cachep, GFP_ATOMIC); 919 if (!new_tb2) { 920 if (reset) { 921 /* The (INADDR_ANY, port) bucket might have already 922 * been freed, then we cannot fixup icsk_bind2_hash, 923 * so we give up and unlink sk from bhash/bhash2 not 924 * to leave inconsistency in bhash2. 925 */ 926 inet_put_port(sk); 927 inet_reset_saddr(sk); 928 } 929 930 return -ENOMEM; 931 } 932 933 bhash = inet_bhashfn(net, port, hinfo->bhash_size); 934 head = &hinfo->bhash[bhash]; 935 head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); 936 937 /* If we change saddr locklessly, another thread 938 * iterating over bhash might see corrupted address. 939 */ 940 spin_lock_bh(&head->lock); 941 942 spin_lock(&head2->lock); 943 __sk_del_bind2_node(sk); 944 inet_bind2_bucket_destroy(hinfo->bind2_bucket_cachep, inet_csk(sk)->icsk_bind2_hash); 945 spin_unlock(&head2->lock); 946 947 if (reset) 948 inet_reset_saddr(sk); 949 else 950 inet_update_saddr(sk, saddr, family); 951 952 head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); 953 954 spin_lock(&head2->lock); 955 tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk); 956 if (!tb2) { 957 tb2 = new_tb2; 958 inet_bind2_bucket_init(tb2, net, head2, port, l3mdev, sk); 959 } 960 sk_add_bind2_node(sk, &tb2->owners); 961 inet_csk(sk)->icsk_bind2_hash = tb2; 962 spin_unlock(&head2->lock); 963 964 spin_unlock_bh(&head->lock); 965 966 if (tb2 != new_tb2) 967 kmem_cache_free(hinfo->bind2_bucket_cachep, new_tb2); 968 969 return 0; 970 } 971 972 int inet_bhash2_update_saddr(struct sock *sk, void *saddr, int family) 973 { 974 return __inet_bhash2_update_saddr(sk, saddr, family, false); 975 } 976 EXPORT_SYMBOL_GPL(inet_bhash2_update_saddr); 977 978 void inet_bhash2_reset_saddr(struct sock *sk) 979 { 980 if (!(sk->sk_userlocks & SOCK_BINDADDR_LOCK)) 981 __inet_bhash2_update_saddr(sk, NULL, 0, true); 982 } 983 EXPORT_SYMBOL_GPL(inet_bhash2_reset_saddr); 984 985 /* RFC 6056 3.3.4. Algorithm 4: Double-Hash Port Selection Algorithm 986 * Note that we use 32bit integers (vs RFC 'short integers') 987 * because 2^16 is not a multiple of num_ephemeral and this 988 * property might be used by clever attacker. 989 * 990 * RFC claims using TABLE_LENGTH=10 buckets gives an improvement, though 991 * attacks were since demonstrated, thus we use 65536 by default instead 992 * to really give more isolation and privacy, at the expense of 256kB 993 * of kernel memory. 994 */ 995 #define INET_TABLE_PERTURB_SIZE (1 << CONFIG_INET_TABLE_PERTURB_ORDER) 996 static u32 *table_perturb; 997 998 int __inet_hash_connect(struct inet_timewait_death_row *death_row, 999 struct sock *sk, u64 port_offset, 1000 int (*check_established)(struct inet_timewait_death_row *, 1001 struct sock *, __u16, struct inet_timewait_sock **)) 1002 { 1003 struct inet_hashinfo *hinfo = death_row->hashinfo; 1004 struct inet_bind_hashbucket *head, *head2; 1005 struct inet_timewait_sock *tw = NULL; 1006 int port = inet_sk(sk)->inet_num; 1007 struct net *net = sock_net(sk); 1008 struct inet_bind2_bucket *tb2; 1009 struct inet_bind_bucket *tb; 1010 bool tb_created = false; 1011 u32 remaining, offset; 1012 int ret, i, low, high; 1013 int l3mdev; 1014 u32 index; 1015 1016 if (port) { 1017 local_bh_disable(); 1018 ret = check_established(death_row, sk, port, NULL); 1019 local_bh_enable(); 1020 return ret; 1021 } 1022 1023 l3mdev = inet_sk_bound_l3mdev(sk); 1024 1025 inet_sk_get_local_port_range(sk, &low, &high); 1026 high++; /* [32768, 60999] -> [32768, 61000[ */ 1027 remaining = high - low; 1028 if (likely(remaining > 1)) 1029 remaining &= ~1U; 1030 1031 get_random_sleepable_once(table_perturb, 1032 INET_TABLE_PERTURB_SIZE * sizeof(*table_perturb)); 1033 index = port_offset & (INET_TABLE_PERTURB_SIZE - 1); 1034 1035 offset = READ_ONCE(table_perturb[index]) + (port_offset >> 32); 1036 offset %= remaining; 1037 1038 /* In first pass we try ports of @low parity. 1039 * inet_csk_get_port() does the opposite choice. 1040 */ 1041 offset &= ~1U; 1042 other_parity_scan: 1043 port = low + offset; 1044 for (i = 0; i < remaining; i += 2, port += 2) { 1045 if (unlikely(port >= high)) 1046 port -= remaining; 1047 if (inet_is_local_reserved_port(net, port)) 1048 continue; 1049 head = &hinfo->bhash[inet_bhashfn(net, port, 1050 hinfo->bhash_size)]; 1051 spin_lock_bh(&head->lock); 1052 1053 /* Does not bother with rcv_saddr checks, because 1054 * the established check is already unique enough. 1055 */ 1056 inet_bind_bucket_for_each(tb, &head->chain) { 1057 if (inet_bind_bucket_match(tb, net, port, l3mdev)) { 1058 if (tb->fastreuse >= 0 || 1059 tb->fastreuseport >= 0) 1060 goto next_port; 1061 WARN_ON(hlist_empty(&tb->owners)); 1062 if (!check_established(death_row, sk, 1063 port, &tw)) 1064 goto ok; 1065 goto next_port; 1066 } 1067 } 1068 1069 tb = inet_bind_bucket_create(hinfo->bind_bucket_cachep, 1070 net, head, port, l3mdev); 1071 if (!tb) { 1072 spin_unlock_bh(&head->lock); 1073 return -ENOMEM; 1074 } 1075 tb_created = true; 1076 tb->fastreuse = -1; 1077 tb->fastreuseport = -1; 1078 goto ok; 1079 next_port: 1080 spin_unlock_bh(&head->lock); 1081 cond_resched(); 1082 } 1083 1084 offset++; 1085 if ((offset & 1) && remaining > 1) 1086 goto other_parity_scan; 1087 1088 return -EADDRNOTAVAIL; 1089 1090 ok: 1091 /* Find the corresponding tb2 bucket since we need to 1092 * add the socket to the bhash2 table as well 1093 */ 1094 head2 = inet_bhashfn_portaddr(hinfo, sk, net, port); 1095 spin_lock(&head2->lock); 1096 1097 tb2 = inet_bind2_bucket_find(head2, net, port, l3mdev, sk); 1098 if (!tb2) { 1099 tb2 = inet_bind2_bucket_create(hinfo->bind2_bucket_cachep, net, 1100 head2, port, l3mdev, sk); 1101 if (!tb2) 1102 goto error; 1103 } 1104 1105 /* Here we want to add a little bit of randomness to the next source 1106 * port that will be chosen. We use a max() with a random here so that 1107 * on low contention the randomness is maximal and on high contention 1108 * it may be inexistent. 1109 */ 1110 i = max_t(int, i, get_random_u32_below(8) * 2); 1111 WRITE_ONCE(table_perturb[index], READ_ONCE(table_perturb[index]) + i + 2); 1112 1113 /* Head lock still held and bh's disabled */ 1114 inet_bind_hash(sk, tb, tb2, port); 1115 1116 if (sk_unhashed(sk)) { 1117 inet_sk(sk)->inet_sport = htons(port); 1118 inet_ehash_nolisten(sk, (struct sock *)tw, NULL); 1119 } 1120 if (tw) 1121 inet_twsk_bind_unhash(tw, hinfo); 1122 1123 spin_unlock(&head2->lock); 1124 spin_unlock(&head->lock); 1125 1126 if (tw) 1127 inet_twsk_deschedule_put(tw); 1128 local_bh_enable(); 1129 return 0; 1130 1131 error: 1132 spin_unlock(&head2->lock); 1133 if (tb_created) 1134 inet_bind_bucket_destroy(hinfo->bind_bucket_cachep, tb); 1135 spin_unlock_bh(&head->lock); 1136 return -ENOMEM; 1137 } 1138 1139 /* 1140 * Bind a port for a connect operation and hash it. 1141 */ 1142 int inet_hash_connect(struct inet_timewait_death_row *death_row, 1143 struct sock *sk) 1144 { 1145 u64 port_offset = 0; 1146 1147 if (!inet_sk(sk)->inet_num) 1148 port_offset = inet_sk_port_offset(sk); 1149 return __inet_hash_connect(death_row, sk, port_offset, 1150 __inet_check_established); 1151 } 1152 EXPORT_SYMBOL_GPL(inet_hash_connect); 1153 1154 static void init_hashinfo_lhash2(struct inet_hashinfo *h) 1155 { 1156 int i; 1157 1158 for (i = 0; i <= h->lhash2_mask; i++) { 1159 spin_lock_init(&h->lhash2[i].lock); 1160 INIT_HLIST_NULLS_HEAD(&h->lhash2[i].nulls_head, 1161 i + LISTENING_NULLS_BASE); 1162 } 1163 } 1164 1165 void __init inet_hashinfo2_init(struct inet_hashinfo *h, const char *name, 1166 unsigned long numentries, int scale, 1167 unsigned long low_limit, 1168 unsigned long high_limit) 1169 { 1170 h->lhash2 = alloc_large_system_hash(name, 1171 sizeof(*h->lhash2), 1172 numentries, 1173 scale, 1174 0, 1175 NULL, 1176 &h->lhash2_mask, 1177 low_limit, 1178 high_limit); 1179 init_hashinfo_lhash2(h); 1180 1181 /* this one is used for source ports of outgoing connections */ 1182 table_perturb = alloc_large_system_hash("Table-perturb", 1183 sizeof(*table_perturb), 1184 INET_TABLE_PERTURB_SIZE, 1185 0, 0, NULL, NULL, 1186 INET_TABLE_PERTURB_SIZE, 1187 INET_TABLE_PERTURB_SIZE); 1188 } 1189 1190 int inet_hashinfo2_init_mod(struct inet_hashinfo *h) 1191 { 1192 h->lhash2 = kmalloc_array(INET_LHTABLE_SIZE, sizeof(*h->lhash2), GFP_KERNEL); 1193 if (!h->lhash2) 1194 return -ENOMEM; 1195 1196 h->lhash2_mask = INET_LHTABLE_SIZE - 1; 1197 /* INET_LHTABLE_SIZE must be a power of 2 */ 1198 BUG_ON(INET_LHTABLE_SIZE & h->lhash2_mask); 1199 1200 init_hashinfo_lhash2(h); 1201 return 0; 1202 } 1203 EXPORT_SYMBOL_GPL(inet_hashinfo2_init_mod); 1204 1205 int inet_ehash_locks_alloc(struct inet_hashinfo *hashinfo) 1206 { 1207 unsigned int locksz = sizeof(spinlock_t); 1208 unsigned int i, nblocks = 1; 1209 1210 if (locksz != 0) { 1211 /* allocate 2 cache lines or at least one spinlock per cpu */ 1212 nblocks = max(2U * L1_CACHE_BYTES / locksz, 1U); 1213 nblocks = roundup_pow_of_two(nblocks * num_possible_cpus()); 1214 1215 /* no more locks than number of hash buckets */ 1216 nblocks = min(nblocks, hashinfo->ehash_mask + 1); 1217 1218 hashinfo->ehash_locks = kvmalloc_array(nblocks, locksz, GFP_KERNEL); 1219 if (!hashinfo->ehash_locks) 1220 return -ENOMEM; 1221 1222 for (i = 0; i < nblocks; i++) 1223 spin_lock_init(&hashinfo->ehash_locks[i]); 1224 } 1225 hashinfo->ehash_locks_mask = nblocks - 1; 1226 return 0; 1227 } 1228 EXPORT_SYMBOL_GPL(inet_ehash_locks_alloc); 1229 1230 struct inet_hashinfo *inet_pernet_hashinfo_alloc(struct inet_hashinfo *hashinfo, 1231 unsigned int ehash_entries) 1232 { 1233 struct inet_hashinfo *new_hashinfo; 1234 int i; 1235 1236 new_hashinfo = kmemdup(hashinfo, sizeof(*hashinfo), GFP_KERNEL); 1237 if (!new_hashinfo) 1238 goto err; 1239 1240 new_hashinfo->ehash = vmalloc_huge(ehash_entries * sizeof(struct inet_ehash_bucket), 1241 GFP_KERNEL_ACCOUNT); 1242 if (!new_hashinfo->ehash) 1243 goto free_hashinfo; 1244 1245 new_hashinfo->ehash_mask = ehash_entries - 1; 1246 1247 if (inet_ehash_locks_alloc(new_hashinfo)) 1248 goto free_ehash; 1249 1250 for (i = 0; i < ehash_entries; i++) 1251 INIT_HLIST_NULLS_HEAD(&new_hashinfo->ehash[i].chain, i); 1252 1253 new_hashinfo->pernet = true; 1254 1255 return new_hashinfo; 1256 1257 free_ehash: 1258 vfree(new_hashinfo->ehash); 1259 free_hashinfo: 1260 kfree(new_hashinfo); 1261 err: 1262 return NULL; 1263 } 1264 EXPORT_SYMBOL_GPL(inet_pernet_hashinfo_alloc); 1265 1266 void inet_pernet_hashinfo_free(struct inet_hashinfo *hashinfo) 1267 { 1268 if (!hashinfo->pernet) 1269 return; 1270 1271 inet_ehash_locks_free(hashinfo); 1272 vfree(hashinfo->ehash); 1273 kfree(hashinfo); 1274 } 1275 EXPORT_SYMBOL_GPL(inet_pernet_hashinfo_free); 1276