// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>
#include <linux/sock_diag.h>
#include <net/udp.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK				\
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;
	u64 cost;
	int err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size != 4 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	/* Make sure page count doesn't overflow. */
	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
	err = bpf_map_charge_init(&stab->map.memory, cost);
	if (err)
		goto free_stab;

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (stab->sks)
		return &stab->map;
	err = -ENOMEM;
	bpf_map_charge_finish(&stab->map.memory);
free_stab:
	kfree(stab);
	return ERR_PTR(err);
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, attr->attach_type);
	fdput(f);
	return ret;
}

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	struct sk_psock_link *link, *tmp;
	bool strp_stop = false;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		sk_psock_stop_strp(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}
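
/* Undo a single map entry's hold on the socket: remove the psock's
 * link to that entry and drop the psock reference the entry held.
 */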

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}

static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
{
	struct proto *prot;

	sock_owned_by_me(sk);

	switch (sk->sk_type) {
	case SOCK_STREAM:
		prot = tcp_bpf_get_proto(sk, psock);
		break;

	case SOCK_DGRAM:
		prot = udp_bpf_get_proto(sk, psock);
		break;

	default:
		return -EINVAL;
	}

	if (IS_ERR(prot))
		return PTR_ERR(prot);

	sk_psock_update_proto(sk, psock, prot);
	return 0;
}

static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock) {
		if (sk->sk_prot->close != sock_map_close) {
			psock = ERR_PTR(-EBUSY);
			goto out;
		}

		if (!refcount_inc_not_zero(&psock->refcnt))
			psock = ERR_PTR(-EBUSY);
	}
out:
	rcu_read_unlock();
	return psock;
}

static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	struct sk_psock *psock;
	bool skb_progs;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	skb_parser = READ_ONCE(progs->skb_parser);
	skb_progs = skb_parser && skb_verdict;
	if (skb_progs) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			bpf_prog_put(skb_verdict);
			return PTR_ERR(skb_parser);
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out;
		}
	}

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_progs && READ_ONCE(psock->progs.skb_parser))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock) {
			ret = -ENOMEM;
			goto out_progs;
		}
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		goto out_drop;

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_progs && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			goto out_drop;
		}
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out:
	if (skb_progs) {
		bpf_prog_put(skb_verdict);
		bpf_prog_put(skb_parser);
	}
	return ret;
}
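
/* Used for sockets that cannot be a BPF redirect source or target
 * (e.g. listening TCP sockets): install the psock and the proto
 * callbacks, but do not take references on parser/verdict programs.
 */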

static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
{
	struct sk_psock *psock;
	int ret;

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock))
		return PTR_ERR(psock);

	if (!psock) {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock)
			return -ENOMEM;
	}

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		sk_psock_put(sk, psock);
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk) {
			lock_sock(sk);
			rcu_read_lock();
			sock_map_unref(sk, psk);
			rcu_read_unlock();
			release_sock(sk);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk || !sk_fullsock(sk))
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}

static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}
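
/* Iteration helper for the bpf(BPF_MAP_GET_NEXT_KEY) syscall: a NULL
 * or out-of-range key restarts the walk from index 0, and reaching
 * the last index reports -ENOENT to end it.
 */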

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

static bool sock_map_redirect_allowed(const struct sock *sk);

static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;
	if (inet_csk_has_ulp(sk))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	/* Only sockets we can redirect into/from in BPF need to hold
	 * refs to parser/verdict progs and have their sk_data_ready
	 * and sk_write_space callbacks overridden.
	 */
	if (sock_map_redirect_allowed(sk))
		ret = sock_map_link(map, &stab->progs, sk);
	else
		ret = sock_map_link_no_progs(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
}

static bool sk_is_tcp(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

static bool sk_is_udp(const struct sock *sk)
{
	return sk->sk_type == SOCK_DGRAM &&
	       sk->sk_protocol == IPPROTO_UDP;
}

static bool sock_map_redirect_allowed(const struct sock *sk)
{
	return sk_is_tcp(sk) && sk->sk_state != TCP_LISTEN;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk_is_tcp(sk) || sk_is_udp(sk);
}

static bool sock_map_sk_state_allowed(const struct sock *sk)
{
	if (sk_is_tcp(sk))
		return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
	else if (sk_is_udp(sk))
		return sk_hashed(sk);

	return false;
}
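
/* Syscall-side update path. For illustration only (not part of the
 * original file): from user space the value passed to
 * bpf_map_update_elem() is the file descriptor of the socket to add,
 * read as either a u32 or a u64 depending on the map's value_size.
 * A minimal sketch, assuming libbpf and an established TCP socket fd
 * named tcp_fd (both are assumptions of the example):
 *
 *	__u32 key = 0;
 *	__u64 value = tcp_fd;
 *	err = bpf_map_update_elem(map_fd, &key, &value, BPF_ANY);
 */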

static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	u32 idx = *(u32 *)key;
	struct socket *sock;
	struct sock *sk;
	int ret;
	u64 ufd;

	if (map->value_size == sizeof(u64))
		ufd = *(u64 *)value;
	else
		ufd = *(u32 *)value;
	if (ufd > S32_MAX)
		return -EINVAL;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else
		ret = sock_map_update_common(map, idx, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func		= bpf_sock_map_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func		= bpf_sk_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func		= bpf_msg_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};
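
/* Example of the BPF side (illustrative sketch, not part of this
 * file): an sk_skb stream verdict program that steers every skb to
 * the socket stored at index 0 of a BPF_MAP_TYPE_SOCKMAP via the
 * bpf_sk_redirect_map() helper above. Assumes libbpf's
 * <bpf/bpf_helpers.h>; the map and function names are made up for
 * the example.
 *
 *	struct {
 *		__uint(type, BPF_MAP_TYPE_SOCKMAP);
 *		__uint(max_entries, 2);
 *		__type(key, __u32);
 *		__type(value, __u64);
 *	} sock_map SEC(".maps");
 *
 *	SEC("sk_skb/stream_verdict")
 *	int prog_stream_verdict(struct __sk_buff *skb)
 *	{
 *		return bpf_sk_redirect_map(skb, &sock_map, 0, 0);
 *	}
 */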

const struct bpf_map_ops sock_map_ops = {
	.map_alloc		= sock_map_alloc,
	.map_free		= sock_map_free,
	.map_get_next_key	= sock_map_get_next_key,
	.map_lookup_elem_sys_only = sock_map_lookup_sys,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_map_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_map_release_progs,
	.map_check_btf		= map_check_no_btf,
};

struct bpf_htab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[];
};

struct bpf_htab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_htab {
	struct bpf_map map;
	struct bpf_htab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_htab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_htab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_htab *htab,
				struct bpf_htab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem_probe, *elem = link_raw;
	struct bpf_htab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}
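
/* When an existing element is being replaced (old != NULL), the
 * count may temporarily exceed max_entries; the caller frees the
 * old element under the same bucket lock, which brings it back down.
 */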

static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
						  void *key, u32 key_size,
						  u32 hash, struct sock *sk,
						  struct bpf_htab_elem *old)
{
	struct bpf_htab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
			   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_elem *elem, *elem_new;
	struct bpf_htab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (inet_csk_has_ulp(sk))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	/* Only sockets we can redirect into/from in BPF need to hold
	 * refs to parser/verdict progs and have their sk_data_ready
	 * and sk_write_space callbacks overridden.
	 */
	if (sock_map_redirect_allowed(sk))
		ret = sock_map_link(map, &htab->progs, sk);
	else
		ret = sock_map_link_no_progs(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_update_elem(struct bpf_map *map, void *key,
				 void *value, u64 flags)
{
	struct socket *sock;
	struct sock *sk;
	int ret;
	u64 ufd;

	if (map->value_size == sizeof(u64))
		ufd = *(u64 *)value;
	else
		ufd = *(u32 *)value;
	if (ufd > S32_MAX)
		return -EINVAL;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}
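
/* Walk order for BPF_MAP_GET_NEXT_KEY on a sockhash: continue along
 * the current bucket's chain, then take the first element of the
 * next non-empty bucket; -ENOENT ends the walk.
 */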

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
				     struct bpf_htab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
					     struct bpf_htab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}
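
/* buckets_num is rounded up to a power of two so that
 * sock_hash_select_bucket() can mask with (buckets_num - 1) instead
 * of taking a modulo.
 */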

static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_htab *htab;
	int i, err;
	u64 cost;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size == 0 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_htab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
	       (u64) htab->elem_size * htab->map.max_entries;
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_htab;
	}
	err = bpf_map_charge_init(&htab->map.memory, cost);
	if (err)
		goto free_htab;

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_htab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		bpf_map_charge_finish(&htab->map.memory);
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_bucket *bucket;
	struct hlist_head unlink_list;
	struct bpf_htab_elem *elem;
	struct hlist_node *node;
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);

		/* We are racing with sock_hash_delete_from_link to
		 * enter the spin-lock critical section. Every socket on
		 * the list is still linked to sockhash. Since link
		 * exists, psock exists and holds a ref to socket. That
		 * lets us grab a socket ref too.
		 */
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry(elem, &bucket->head, node)
			sock_hold(elem->sk);
		hlist_move_list(&bucket->head, &unlink_list);
		raw_spin_unlock_bh(&bucket->lock);

		/* Process removed entries out of atomic context to
		 * block for socket lock before deleting the psock's
		 * link to sockhash.
		 */
		hlist_for_each_entry_safe(elem, node, &unlink_list, node) {
			hlist_del(&elem->node);
			lock_sock(elem->sk);
			rcu_read_lock();
			sock_map_unref(elem->sk, elem);
			rcu_read_unlock();
			release_sock(elem->sk);
			sock_put(elem->sk);
			sock_hash_free_elem(htab, elem);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static void *sock_hash_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk || !sk_fullsock(sk))
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func		= bpf_sock_hash_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func		= bpf_sk_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func		= bpf_msg_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

const struct bpf_map_ops sock_hash_ops = {
	.map_alloc		= sock_hash_alloc,
	.map_free		= sock_hash_free,
	.map_get_next_key	= sock_hash_get_next_key,
	.map_update_elem	= sock_hash_update_elem,
	.map_delete_elem	= sock_hash_delete_elem,
	.map_lookup_elem	= sock_hash_lookup,
	.map_lookup_elem_sys_only = sock_hash_lookup_sys,
	.map_release_uref	= sock_hash_release_progs,
	.map_check_btf		= map_check_no_btf,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_htab, map)->progs;
	default:
		break;
	}

	return NULL;
}

int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		psock_set_prog(&progs->msg_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		psock_set_prog(&progs->skb_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		psock_set_prog(&progs->skb_verdict, prog);
		break;
	default:
		return -EOPNOTSUPP;
	}

	return 0;
}

static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}

static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	while ((link = sk_psock_link_pop(psock))) {
		sock_map_unlink(sk, link);
		sk_psock_free_link(link);
	}
}
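
/* sock_map_unhash() and sock_map_close() are installed in place of
 * the socket's original unhash/close callbacks while a psock is
 * attached; they remove every map link first and then defer to the
 * callbacks that were saved when the psock was set up.
 */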

void sock_map_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}

void sock_map_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}