// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/btf_ids.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>
#include <linux/sock_diag.h>
#include <net/udp.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK				\
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    != 4 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER | __GFP_ACCOUNT);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (!stab->sks) {
		kfree(stab);
		return ERR_PTR(-ENOMEM);
	}

	return &stab->map;
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	if (attr->attach_flags || attr->replace_bpf_fd)
		return -EINVAL;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, NULL, attr->attach_type);
	fdput(f);
	return ret;
}

int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype)
{
	u32 ufd = attr->target_fd;
	struct bpf_prog *prog;
	struct bpf_map *map;
	struct fd f;
	int ret;

	if (attr->attach_flags || attr->replace_bpf_fd)
		return -EINVAL;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);

	prog = bpf_prog_get(attr->attach_bpf_fd);
	if (IS_ERR(prog)) {
		ret = PTR_ERR(prog);
		goto put_map;
	}

	if (prog->type != ptype) {
		ret = -EINVAL;
		goto put_prog;
	}

	ret = sock_map_prog_update(map, NULL, prog, attr->attach_type);
put_prog:
	bpf_prog_put(prog);
put_map:
	fdput(f);
	return ret;
}

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	bool strp_stop = false, verdict_stop = false;
	struct sk_psock_link *link, *tmp;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			if (psock->parser.enabled && stab->progs.skb_verdict)
				verdict_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop || verdict_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		if (strp_stop)
			sk_psock_stop_strp(sk, psock);
		else
			sk_psock_stop_verdict(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}

static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
{
	struct proto *prot;

	switch (sk->sk_type) {
	case SOCK_STREAM:
		prot = tcp_bpf_get_proto(sk, psock);
		break;

	case SOCK_DGRAM:
		prot = udp_bpf_get_proto(sk, psock);
		break;

	default:
		return -EINVAL;
	}

	if (IS_ERR(prot))
		return PTR_ERR(prot);

	sk_psock_update_proto(sk, psock, prot);
	return 0;
}

static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock) {
		if (sk->sk_prot->close != sock_map_close) {
			psock = ERR_PTR(-EBUSY);
			goto out;
		}

		if (!refcount_inc_not_zero(&psock->refcnt))
			psock = ERR_PTR(-EBUSY);
	}
out:
	rcu_read_unlock();
	return psock;
}

static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	struct sk_psock *psock;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	if (skb_verdict) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
	}

	skb_parser = READ_ONCE(progs->skb_parser);
	if (skb_parser) {
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			ret = PTR_ERR(skb_parser);
			goto out_put_skb_verdict;
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out_put_skb_parser;
		}
	}

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_parser && READ_ONCE(psock->progs.skb_parser)) ||
		    (skb_verdict && READ_ONCE(psock->progs.skb_verdict))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (IS_ERR(psock)) {
			ret = PTR_ERR(psock);
			goto out_progs;
		}
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		goto out_drop;

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_parser && skb_verdict && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret)
			goto out_unlock_drop;
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	} else if (!skb_parser && skb_verdict && !psock->parser.enabled) {
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		sk_psock_start_verdict(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_unlock_drop:
	write_unlock_bh(&sk->sk_callback_lock);
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out_put_skb_parser:
	if (skb_parser)
		bpf_prog_put(skb_parser);
out_put_skb_verdict:
	if (skb_verdict)
		bpf_prog_put(skb_verdict);
	return ret;
}

static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
{
	struct sk_psock *psock;
	int ret;

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock))
		return PTR_ERR(psock);

	if (!psock) {
		psock = sk_psock_init(sk, map->numa_node);
		if (IS_ERR(psock))
			return PTR_ERR(psock);
	}

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		sk_psock_put(sk, psock);
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk the map and remove entries without risking a
	 * race in the EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk) {
			lock_sock(sk);
			rcu_read_lock();
			sock_map_unref(sk, psk);
			rcu_read_unlock();
			release_sock(sk);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk)
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}

static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	__sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

static bool sock_map_redirect_allowed(const struct sock *sk);

static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	/* Only sockets we can redirect into/from in BPF need to hold
	 * refs to parser/verdict progs and have their sk_data_ready
	 * and sk_write_space callbacks overridden.
	 */
	if (sock_map_redirect_allowed(sk))
		ret = sock_map_link(map, &stab->progs, sk);
	else
		ret = sock_map_link_no_progs(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
}

static bool sk_is_tcp(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

static bool sk_is_udp(const struct sock *sk)
{
	return sk->sk_type == SOCK_DGRAM &&
	       sk->sk_protocol == IPPROTO_UDP;
}

static bool sock_map_redirect_allowed(const struct sock *sk)
{
	return sk_is_tcp(sk) && sk->sk_state != TCP_LISTEN;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk_is_tcp(sk) || sk_is_udp(sk);
}

static bool sock_map_sk_state_allowed(const struct sock *sk)
{
	if (sk_is_tcp(sk))
		return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
	else if (sk_is_udp(sk))
		return sk_hashed(sk);

	return false;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags);

int sock_map_update_elem_sys(struct bpf_map *map, void *key, void *value,
			     u64 flags)
{
	struct socket *sock;
	struct sock *sk;
	int ret;
	u64 ufd;

	if (map->value_size == sizeof(u64))
		ufd = *(u64 *)value;
	else
		ufd = *(u32 *)value;
	if (ufd > S32_MAX)
		return -EINVAL;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
		ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	sockfd_put(sock);
	return ret;
}

static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	struct sock *sk = (struct sock *)value;
	int ret;

	if (unlikely(!sk || !sk_fullsock(sk)))
		return -EINVAL;

	if (!sock_map_sk_is_suitable(sk))
		return -EOPNOTSUPP;

	local_bh_disable();
	bh_lock_sock(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
		ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	bh_unlock_sock(sk);
	local_bh_enable();
	return ret;
}

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func		= bpf_sock_map_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func		= bpf_sk_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func		= bpf_msg_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

struct sock_map_seq_info {
	struct bpf_map *map;
	struct sock *sk;
	u32 index;
};

struct bpf_iter__sockmap {
	__bpf_md_ptr(struct bpf_iter_meta *, meta);
	__bpf_md_ptr(struct bpf_map *, map);
	__bpf_md_ptr(void *, key);
	__bpf_md_ptr(struct sock *, sk);
};

DEFINE_BPF_ITER_FUNC(sockmap, struct bpf_iter_meta *meta,
		     struct bpf_map *map, void *key,
		     struct sock *sk)

static void *sock_map_seq_lookup_elem(struct sock_map_seq_info *info)
{
	if (unlikely(info->index >= info->map->max_entries))
		return NULL;

	info->sk = __sock_map_lookup_elem(info->map, info->index);

	/* can't return sk directly, since that might be NULL */
	return info;
}

static void *sock_map_seq_start(struct seq_file *seq, loff_t *pos)
	__acquires(rcu)
{
	struct sock_map_seq_info *info = seq->private;

	if (*pos == 0)
		++*pos;

	/* pairs with sock_map_seq_stop */
	rcu_read_lock();
	return sock_map_seq_lookup_elem(info);
}

static void *sock_map_seq_next(struct seq_file *seq, void *v, loff_t *pos)
	__must_hold(rcu)
{
	struct sock_map_seq_info *info = seq->private;

	++*pos;
	++info->index;

	return sock_map_seq_lookup_elem(info);
}

static int sock_map_seq_show(struct seq_file *seq, void *v)
	__must_hold(rcu)
{
	struct sock_map_seq_info *info = seq->private;
	struct bpf_iter__sockmap ctx = {};
	struct bpf_iter_meta meta;
	struct bpf_prog *prog;

	meta.seq = seq;
	prog = bpf_iter_get_info(&meta, !v);
	if (!prog)
		return 0;

	ctx.meta = &meta;
	ctx.map = info->map;
	if (v) {
		ctx.key = &info->index;
		ctx.sk = info->sk;
	}

	return bpf_iter_run_prog(prog, &ctx);
}

static void sock_map_seq_stop(struct seq_file *seq, void *v)
	__releases(rcu)
{
	if (!v)
		(void)sock_map_seq_show(seq, NULL);

	/* pairs with sock_map_seq_start */
	rcu_read_unlock();
}

static const struct seq_operations sock_map_seq_ops = {
	.start	= sock_map_seq_start,
	.next	= sock_map_seq_next,
	.stop	= sock_map_seq_stop,
	.show	= sock_map_seq_show,
};

static int sock_map_init_seq_private(void *priv_data,
				     struct bpf_iter_aux_info *aux)
{
	struct sock_map_seq_info *info = priv_data;

	info->map = aux->map;
	return 0;
}

static const struct bpf_iter_seq_info sock_map_iter_seq_info = {
	.seq_ops		= &sock_map_seq_ops,
	.init_seq_private	= sock_map_init_seq_private,
	.seq_priv_size		= sizeof(struct sock_map_seq_info),
};

static int sock_map_btf_id;
const struct bpf_map_ops sock_map_ops = {
	.map_meta_equal		= bpf_map_meta_equal,
	.map_alloc		= sock_map_alloc,
	.map_free		= sock_map_free,
	.map_get_next_key	= sock_map_get_next_key,
	.map_lookup_elem_sys_only = sock_map_lookup_sys,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_map_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_map_release_progs,
	.map_check_btf		= map_check_no_btf,
	.map_btf_name		= "bpf_stab",
	.map_btf_id		= &sock_map_btf_id,
	.iter_seq_info		= &sock_map_iter_seq_info,
};

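/* Example (illustrative only, not compiled as part of this file): from
 * user space a BPF_MAP_TYPE_SOCKMAP entry is updated with a socket
 * *file descriptor* as the value; sock_map_update_elem_sys() above
 * resolves the fd and accepts only TCP sockets in ESTABLISHED/LISTEN
 * state or hashed UDP sockets. A minimal libbpf sketch, where "obj",
 * "sock_map" and "sock_fd" are hypothetical names and the map is
 * assumed to have value_size == 8:
 *
 *	int map_fd = bpf_object__find_map_fd_by_name(obj, "sock_map");
 *	__u32 key = 0;
 *	__u64 value = sock_fd;	// connected TCP socket fd
 *
 *	if (bpf_map_update_elem(map_fd, &key, &value, BPF_ANY))
 *		perror("bpf_map_update_elem");
 */
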
struct bpf_shtab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[];
};

struct bpf_shtab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_shtab {
	struct bpf_map map;
	struct bpf_shtab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_shtab_bucket *sock_hash_select_bucket(struct bpf_shtab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_shtab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_shtab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_shtab *htab,
				struct bpf_shtab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_elem *elem_probe, *elem = link_raw;
	struct bpf_shtab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_shtab_elem *sock_hash_alloc_elem(struct bpf_shtab *htab,
						   void *key, u32 key_size,
						   u32 hash, struct sock *sk,
						   struct bpf_shtab_elem *old)
{
	struct bpf_shtab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = bpf_map_kmalloc_node(&htab->map, htab->elem_size,
				   GFP_ATOMIC | __GFP_NOWARN,
				   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_shtab_elem *elem, *elem_new;
	struct bpf_shtab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	/* Only sockets we can redirect into/from in BPF need to hold
	 * refs to parser/verdict progs and have their sk_data_ready
	 * and sk_write_space callbacks overridden.
	 */
	if (sock_map_redirect_allowed(sk))
		ret = sock_map_link(map, &htab->progs, sk);
	else
		ret = sock_map_link_no_progs(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&elem->node)),
				     struct bpf_shtab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference(hlist_first_rcu(head)),
					     struct bpf_shtab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_shtab *htab;
	int i, err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    == 0 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER | __GFP_ACCOUNT);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_shtab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_shtab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_shtab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_bucket *bucket;
	struct hlist_head unlink_list;
	struct bpf_shtab_elem *elem;
	struct hlist_node *node;
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk the map and remove entries without risking a
	 * race in the EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);

		/* We are racing with sock_hash_delete_from_link to
		 * enter the spin-lock critical section. Every socket on
		 * the list is still linked to sockhash. Since link
		 * exists, psock exists and holds a ref to socket. That
		 * lets us grab a socket ref too.
		 */
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry(elem, &bucket->head, node)
			sock_hold(elem->sk);
		hlist_move_list(&bucket->head, &unlink_list);
		raw_spin_unlock_bh(&bucket->lock);

		/* Process removed entries out of atomic context to
		 * block for socket lock before deleting the psock's
		 * link to sockhash.
		 */
		hlist_for_each_entry_safe(elem, node, &unlink_list, node) {
			hlist_del(&elem->node);
			lock_sock(elem->sk);
			rcu_read_lock();
			sock_map_unref(elem->sk, elem);
			rcu_read_unlock();
			release_sock(elem->sk);
			sock_put(elem->sk);
			sock_hash_free_elem(htab, elem);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	__sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static void *sock_hash_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk)
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_shtab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func		= bpf_sock_hash_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func		= bpf_sk_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

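/* Example (illustrative only, not compiled as part of this file): the
 * redirect helpers above are meant to be called from sk_skb verdict
 * programs. A minimal BPF-side sketch, where the map name "sock_hash",
 * its sizing and the 4-byte key are hypothetical:
 *
 *	struct {
 *		__uint(type, BPF_MAP_TYPE_SOCKHASH);
 *		__uint(max_entries, 64);
 *		__type(key, __u32);
 *		__type(value, __u64);
 *	} sock_hash SEC(".maps");
 *
 *	SEC("sk_skb/stream_verdict")
 *	int prog_stream_verdict(struct __sk_buff *skb)
 *	{
 *		__u32 key = 0;
 *
 *		// returns SK_PASS on success, SK_DROP otherwise;
 *		// pass BPF_F_INGRESS to redirect to the ingress path
 *		return bpf_sk_redirect_hash(skb, &sock_hash, &key, 0);
 *	}
 */
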
BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func		= bpf_msg_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

struct sock_hash_seq_info {
	struct bpf_map *map;
	struct bpf_shtab *htab;
	u32 bucket_id;
};

static void *sock_hash_seq_find_next(struct sock_hash_seq_info *info,
				     struct bpf_shtab_elem *prev_elem)
{
	const struct bpf_shtab *htab = info->htab;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;
	struct hlist_node *node;

	/* try to find next elem in the same bucket */
	if (prev_elem) {
		node = rcu_dereference(hlist_next_rcu(&prev_elem->node));
		elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
		if (elem)
			return elem;

		/* no more elements, continue in the next bucket */
		info->bucket_id++;
	}

	for (; info->bucket_id < htab->buckets_num; info->bucket_id++) {
		bucket = &htab->buckets[info->bucket_id];
		node = rcu_dereference(hlist_first_rcu(&bucket->head));
		elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
		if (elem)
			return elem;
	}

	return NULL;
}

static void *sock_hash_seq_start(struct seq_file *seq, loff_t *pos)
	__acquires(rcu)
{
	struct sock_hash_seq_info *info = seq->private;

	if (*pos == 0)
		++*pos;

	/* pairs with sock_hash_seq_stop */
	rcu_read_lock();
	return sock_hash_seq_find_next(info, NULL);
}

static void *sock_hash_seq_next(struct seq_file *seq, void *v, loff_t *pos)
	__must_hold(rcu)
{
	struct sock_hash_seq_info *info = seq->private;

	++*pos;
	return sock_hash_seq_find_next(info, v);
}

static int sock_hash_seq_show(struct seq_file *seq, void *v)
	__must_hold(rcu)
{
	struct sock_hash_seq_info *info = seq->private;
	struct bpf_iter__sockmap ctx = {};
	struct bpf_shtab_elem *elem = v;
	struct bpf_iter_meta meta;
	struct bpf_prog *prog;

	meta.seq = seq;
	prog = bpf_iter_get_info(&meta, !elem);
	if (!prog)
		return 0;

	ctx.meta = &meta;
	ctx.map = info->map;
	if (elem) {
		ctx.key = elem->key;
		ctx.sk = elem->sk;
	}

	return bpf_iter_run_prog(prog, &ctx);
}

static void sock_hash_seq_stop(struct seq_file *seq, void *v)
	__releases(rcu)
{
	if (!v)
		(void)sock_hash_seq_show(seq, NULL);

	/* pairs with sock_hash_seq_start */
	rcu_read_unlock();
}

static const struct seq_operations sock_hash_seq_ops = {
	.start	= sock_hash_seq_start,
	.next	= sock_hash_seq_next,
	.stop	= sock_hash_seq_stop,
	.show	= sock_hash_seq_show,
};

static int sock_hash_init_seq_private(void *priv_data,
				      struct bpf_iter_aux_info *aux)
{
	struct sock_hash_seq_info *info = priv_data;

	info->map = aux->map;
	info->htab = container_of(aux->map, struct bpf_shtab, map);
	return 0;
}

static const struct bpf_iter_seq_info sock_hash_iter_seq_info = {
	.seq_ops		= &sock_hash_seq_ops,
	.init_seq_private	= sock_hash_init_seq_private,
	.seq_priv_size		= sizeof(struct sock_hash_seq_info),
};

static int sock_hash_map_btf_id;
const struct bpf_map_ops sock_hash_ops = {
	.map_meta_equal		= bpf_map_meta_equal,
	.map_alloc		= sock_hash_alloc,
	.map_free		= sock_hash_free,
	.map_get_next_key	= sock_hash_get_next_key,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_hash_delete_elem,
	.map_lookup_elem	= sock_hash_lookup,
	.map_lookup_elem_sys_only = sock_hash_lookup_sys,
	.map_release_uref	= sock_hash_release_progs,
	.map_check_btf		= map_check_no_btf,
	.map_btf_name		= "bpf_shtab",
	.map_btf_id		= &sock_hash_map_btf_id,
	.iter_seq_info		= &sock_hash_iter_seq_info,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_shtab, map)->progs;
	default:
		break;
	}

	return NULL;
}

int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 struct bpf_prog *old, u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);
	struct bpf_prog **pprog;

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		pprog = &progs->msg_parser;
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		pprog = &progs->skb_parser;
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		pprog = &progs->skb_verdict;
		break;
	default:
		return -EOPNOTSUPP;
	}

	if (old)
		return psock_replace_prog(pprog, prog, old);

	psock_set_prog(pprog, prog);
	return 0;
}

static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}

static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	while ((link = sk_psock_link_pop(psock))) {
		sock_map_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

void sock_map_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}

void sock_map_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}

static int sock_map_iter_attach_target(struct bpf_prog *prog,
				       union bpf_iter_link_info *linfo,
				       struct bpf_iter_aux_info *aux)
{
	struct bpf_map *map;
	int err = -EINVAL;

	if (!linfo->map.map_fd)
		return -EBADF;

	map = bpf_map_get_with_uref(linfo->map.map_fd);
	if (IS_ERR(map))
		return PTR_ERR(map);

	if (map->map_type != BPF_MAP_TYPE_SOCKMAP &&
	    map->map_type != BPF_MAP_TYPE_SOCKHASH)
		goto put_map;

	if (prog->aux->max_rdonly_access > map->key_size) {
		err = -EACCES;
		goto put_map;
	}

	aux->map = map;
	return 0;

put_map:
	bpf_map_put_with_uref(map);
	return err;
}

static void sock_map_iter_detach_target(struct bpf_iter_aux_info *aux)
{
	bpf_map_put_with_uref(aux->map);
}

static struct bpf_iter_reg sock_map_iter_reg = {
	.target			= "sockmap",
	.attach_target		= sock_map_iter_attach_target,
	.detach_target		= sock_map_iter_detach_target,
	.show_fdinfo		= bpf_iter_map_show_fdinfo,
	.fill_link_info		= bpf_iter_map_fill_link_info,
	.ctx_arg_info_size	= 2,
	.ctx_arg_info		= {
		{ offsetof(struct bpf_iter__sockmap, key),
		  PTR_TO_RDONLY_BUF_OR_NULL },
		{ offsetof(struct bpf_iter__sockmap, sk),
		  PTR_TO_BTF_ID_OR_NULL },
	},
};

static int __init bpf_sockmap_iter_init(void)
{
	sock_map_iter_reg.ctx_arg_info[1].btf_id =
		btf_sock_ids[BTF_SOCK_TYPE_SOCK];
	return bpf_iter_reg_target(&sock_map_iter_reg);
}
late_initcall(bpf_sockmap_iter_init);
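/* Example (illustrative only, not compiled as part of this file):
 * parser/verdict and msg programs are attached to the map itself;
 * sock_map_get_from_fd(), sock_map_prog_detach() and
 * sock_map_prog_update() above handle BPF_SK_SKB_STREAM_PARSER,
 * BPF_SK_SKB_STREAM_VERDICT and BPF_SK_MSG_VERDICT. A minimal libbpf
 * sketch, where prog_fd and map_fd are assumed to be valid descriptors:
 *
 *	if (bpf_prog_attach(prog_fd, map_fd, BPF_SK_SKB_STREAM_VERDICT, 0))
 *		perror("bpf_prog_attach");
 *
 *	// later, detach the same program from the map:
 *	bpf_prog_detach2(prog_fd, map_fd, BPF_SK_SKB_STREAM_VERDICT);
 */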