// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>
#include <linux/sock_diag.h>
#include <net/udp.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK \
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;
	u64 cost;
	int err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size != 4 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	/* Make sure page count doesn't overflow. */
	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
	err = bpf_map_charge_init(&stab->map.memory, cost);
	if (err)
		goto free_stab;

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (stab->sks)
		return &stab->map;
	err = -ENOMEM;
	bpf_map_charge_finish(&stab->map.memory);
free_stab:
	kfree(stab);
	return ERR_PTR(err);
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, attr->attach_type);
	fdput(f);
	return ret;
}

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	struct sk_psock_link *link, *tmp;
	bool strp_stop = false;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		sk_psock_stop_strp(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}

static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
{
	struct proto *prot;

	sock_owned_by_me(sk);

	switch (sk->sk_type) {
	case SOCK_STREAM:
		prot = tcp_bpf_get_proto(sk, psock);
		break;

	case SOCK_DGRAM:
		prot = udp_bpf_get_proto(sk, psock);
		break;

	default:
		return -EINVAL;
	}

	if (IS_ERR(prot))
		return PTR_ERR(prot);

	sk_psock_update_proto(sk, psock, prot);
	return 0;
}

static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock) {
		if (sk->sk_prot->close != sock_map_close) {
			psock = ERR_PTR(-EBUSY);
			goto out;
		}

		if (!refcount_inc_not_zero(&psock->refcnt))
			psock = ERR_PTR(-EBUSY);
	}
out:
	rcu_read_unlock();
	return psock;
}

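/* Attach the map's BPF programs to sk's psock and switch the socket to
 * the BPF proto ops. The psock is created on first use or reused if the
 * socket already sits in another map; if one of the requested program
 * slots is already occupied on that psock, -EBUSY is returned.
 */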
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	struct sk_psock *psock;
	bool skb_progs;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	skb_parser = READ_ONCE(progs->skb_parser);
	skb_progs = skb_parser && skb_verdict;
	if (skb_progs) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			bpf_prog_put(skb_verdict);
			return PTR_ERR(skb_parser);
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out;
		}
	}

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_progs && READ_ONCE(psock->progs.skb_parser))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock) {
			ret = -ENOMEM;
			goto out_progs;
		}
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		goto out_drop;

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_progs && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			goto out_drop;
		}
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out:
	if (skb_progs) {
		bpf_prog_put(skb_verdict);
		bpf_prog_put(skb_parser);
	}
	return ret;
}

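/* Listening sockets are never redirect targets, so they only get a psock
 * and the BPF proto ops, without any parser/verdict programs (see the
 * comment in sock_map_update_common()).
 */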
static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
{
	struct sk_psock *psock;
	int ret;

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock))
		return PTR_ERR(psock);

	if (!psock) {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock)
			return -ENOMEM;
	}

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		sk_psock_put(sk, psock);
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk) {
			lock_sock(sk);
			rcu_read_lock();
			sock_map_unref(sk, psk);
			rcu_read_unlock();
			release_sock(sk);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk || !sk_fullsock(sk))
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}

static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

static bool sock_map_redirect_allowed(const struct sock *sk)
{
	return sk->sk_state != TCP_LISTEN;
}

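/* Update slot idx under stab->lock: the socket is first linked to a
 * psock via sock_map_link*(), then installed in the slot according to
 * the BPF_NOEXIST/BPF_EXIST/BPF_ANY flags. A socket being replaced
 * loses the map's reference via sock_map_unref().
 */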
static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;
	if (inet_csk_has_ulp(sk))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	/* Only sockets we can redirect into/from in BPF need to hold
	 * refs to parser/verdict progs and have their sk_data_ready
	 * and sk_write_space callbacks overridden.
	 */
	if (sock_map_redirect_allowed(sk))
		ret = sock_map_link(map, &stab->progs, sk);
	else
		ret = sock_map_link_no_progs(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
}

static bool sk_is_tcp(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

static bool sk_is_udp(const struct sock *sk)
{
	return sk->sk_type == SOCK_DGRAM &&
	       sk->sk_protocol == IPPROTO_UDP;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk_is_tcp(sk) || sk_is_udp(sk);
}

static bool sock_map_sk_state_allowed(const struct sock *sk)
{
	if (sk_is_tcp(sk))
		return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
	else if (sk_is_udp(sk))
		return sk_hashed(sk);

	return false;
}

static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	u32 idx = *(u32 *)key;
	struct socket *sock;
	struct sock *sk;
	int ret;
	u64 ufd;

	if (map->value_size == sizeof(u64))
		ufd = *(u64 *)value;
	else
		ufd = *(u32 *)value;
	if (ufd > S32_MAX)
		return -EINVAL;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else
		ret = sock_map_update_common(map, idx, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

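/* BPF helpers for sockmap. bpf_sock_map_update() may only be called from
 * a sock_ops program in one of the states accepted by sock_map_op_okay(),
 * e.g. (sketch only, map and function names are illustrative):
 *
 *	SEC("sockops")
 *	int add_to_map(struct bpf_sock_ops *skops)
 *	{
 *		__u32 idx = 0;
 *
 *		if (skops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB)
 *			bpf_sock_map_update(skops, &sock_map, &idx, BPF_ANY);
 *		return 1;
 *	}
 */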
BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func		= bpf_sock_map_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func		= bpf_sk_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func		= bpf_msg_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

const struct bpf_map_ops sock_map_ops = {
	.map_alloc		= sock_map_alloc,
	.map_free		= sock_map_free,
	.map_get_next_key	= sock_map_get_next_key,
	.map_lookup_elem_sys_only = sock_map_lookup_sys,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_map_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_map_release_progs,
	.map_check_btf		= map_check_no_btf,
};

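/* BPF_MAP_TYPE_SOCKHASH: same psock linking as the sockmap above, but
 * entries are keyed by an arbitrary user-supplied key hashed with jhash
 * into a power-of-two number of buckets, each guarded by its own lock.
 */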
struct bpf_htab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[];
};

struct bpf_htab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_htab {
	struct bpf_map map;
	struct bpf_htab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_htab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_htab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_htab *htab,
				struct bpf_htab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem_probe, *elem = link_raw;
	struct bpf_htab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
						  void *key, u32 key_size,
						  u32 hash, struct sock *sk,
						  struct bpf_htab_elem *old)
{
	struct bpf_htab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
			   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

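/* Hash-map counterpart of sock_map_update_common(): link the socket's
 * psock, then insert a freshly allocated element at the head of its
 * bucket under the bucket lock, so concurrent RCU lookups see the new
 * element before any old one being replaced.
 */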
static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_elem *elem, *elem_new;
	struct bpf_htab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (inet_csk_has_ulp(sk))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	/* Only sockets we can redirect into/from in BPF need to hold
	 * refs to parser/verdict progs and have their sk_data_ready
	 * and sk_write_space callbacks overridden.
	 */
	if (sock_map_redirect_allowed(sk))
		ret = sock_map_link(map, &htab->progs, sk);
	else
		ret = sock_map_link_no_progs(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_update_elem(struct bpf_map *map, void *key,
				 void *value, u64 flags)
{
	struct socket *sock;
	struct sock *sk;
	int ret;
	u64 ufd;

	if (map->value_size == sizeof(u64))
		ufd = *(u64 *)value;
	else
		ufd = *(u32 *)value;
	if (ufd > S32_MAX)
		return -EINVAL;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
				     struct bpf_htab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
					     struct bpf_htab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

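/* Bucket count is max_entries rounded up to a power of two so that
 * sock_hash_select_bucket() can mask instead of divide. Keys larger
 * than MAX_BPF_STACK are rejected: BPF programs initialize keys on
 * the stack, so they cannot exceed the maximum stack size.
 */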
static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_htab *htab;
	int i, err;
	u64 cost;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size == 0 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_htab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
	       (u64) htab->elem_size * htab->map.max_entries;
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_htab;
	}

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_htab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	struct hlist_node *node;
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);
		hlist_for_each_entry_safe(elem, node, &bucket->head, node) {
			hlist_del_rcu(&elem->node);
			lock_sock(elem->sk);
			rcu_read_lock();
			sock_map_unref(elem->sk, elem);
			rcu_read_unlock();
			release_sock(elem->sk);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

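/* Syscall-side lookup: user space gets the socket cookie, never a kernel
 * pointer, and must therefore create the map with an 8-byte value size.
 */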
static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static void *sock_hash_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk || !sk_fullsock(sk))
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func		= bpf_sock_hash_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func		= bpf_sk_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func		= bpf_msg_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

const struct bpf_map_ops sock_hash_ops = {
	.map_alloc		= sock_hash_alloc,
	.map_free		= sock_hash_free,
	.map_get_next_key	= sock_hash_get_next_key,
	.map_update_elem	= sock_hash_update_elem,
	.map_delete_elem	= sock_hash_delete_elem,
	.map_lookup_elem	= sock_hash_lookup,
	.map_lookup_elem_sys_only = sock_hash_lookup_sys,
	.map_release_uref	= sock_hash_release_progs,
	.map_check_btf		= map_check_no_btf,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_htab, map)->progs;
	default:
		break;
	}

	return NULL;
}

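/* Called from sock_map_get_from_fd() on program attach: route the
 * program into the map's msg_parser, skb_parser or skb_verdict slot
 * based on the requested attach type.
 */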
int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		psock_set_prog(&progs->msg_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		psock_set_prog(&progs->skb_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		psock_set_prog(&progs->skb_verdict, prog);
		break;
	default:
		return -EOPNOTSUPP;
	}

	return 0;
}

static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}

static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	while ((link = sk_psock_link_pop(psock))) {
		sock_map_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

void sock_map_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}

void sock_map_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}