// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/btf_ids.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>
#include <linux/sock_diag.h>
#include <net/udp.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK				\
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
				struct bpf_prog *old, u32 which);

static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size != 4 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER | __GFP_ACCOUNT);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (!stab->sks) {
		kfree(stab);
		return ERR_PTR(-ENOMEM);
	}

	return &stab->map;
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	if (attr->attach_flags || attr->replace_bpf_fd)
		return -EINVAL;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, NULL, attr->attach_type);
	fdput(f);
	return ret;
}

int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype)
{
	u32 ufd = attr->target_fd;
	struct bpf_prog *prog;
	struct bpf_map *map;
	struct fd f;
	int ret;

	if (attr->attach_flags || attr->replace_bpf_fd)
		return -EINVAL;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);

	prog = bpf_prog_get(attr->attach_bpf_fd);
	if (IS_ERR(prog)) {
		ret = PTR_ERR(prog);
		goto put_map;
	}

	if (prog->type != ptype) {
		ret = -EINVAL;
		goto put_prog;
	}

	ret = sock_map_prog_update(map, NULL, prog, attr->attach_type);
put_prog:
	bpf_prog_put(prog);
put_map:
	fdput(f);
	return ret;
}

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	bool strp_stop = false, verdict_stop = false;
	struct sk_psock_link *link, *tmp;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->saved_data_ready && stab->progs.stream_parser)
				strp_stop = true;
			if (psock->saved_data_ready && stab->progs.stream_verdict)
				verdict_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop || verdict_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		if (strp_stop)
			sk_psock_stop_strp(sk, psock);
		else
			sk_psock_stop_verdict(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}

static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
{
	struct proto *prot;

	switch (sk->sk_type) {
	case SOCK_STREAM:
		prot = tcp_bpf_get_proto(sk, psock);
		break;

	case SOCK_DGRAM:
		prot = udp_bpf_get_proto(sk, psock);
		break;

	default:
		return -EINVAL;
	}

	if (IS_ERR(prot))
		return PTR_ERR(prot);

	sk_psock_update_proto(sk, psock, prot);
	return 0;
}

static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock) {
		if (sk->sk_prot->close != sock_map_close) {
			psock = ERR_PTR(-EBUSY);
			goto out;
		}

		if (!refcount_inc_not_zero(&psock->refcnt))
			psock = ERR_PTR(-EBUSY);
	}
out:
	rcu_read_unlock();
	return psock;
}

static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *stream_parser, *stream_verdict;
	struct sk_psock *psock;
	int ret;

	stream_verdict = READ_ONCE(progs->stream_verdict);
	if (stream_verdict) {
		stream_verdict = bpf_prog_inc_not_zero(stream_verdict);
		if (IS_ERR(stream_verdict))
			return PTR_ERR(stream_verdict);
	}

	stream_parser = READ_ONCE(progs->stream_parser);
	if (stream_parser) {
		stream_parser = bpf_prog_inc_not_zero(stream_parser);
		if (IS_ERR(stream_parser)) {
			ret = PTR_ERR(stream_parser);
			goto out_put_stream_verdict;
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out_put_stream_parser;
		}
	}

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (stream_parser && READ_ONCE(psock->progs.stream_parser)) ||
		    (stream_verdict && READ_ONCE(psock->progs.stream_verdict))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (IS_ERR(psock)) {
			ret = PTR_ERR(psock);
			goto out_progs;
		}
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		goto out_drop;

	write_lock_bh(&sk->sk_callback_lock);
	if (stream_parser && stream_verdict && !psock->saved_data_ready) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret)
			goto out_unlock_drop;
		psock_set_prog(&psock->progs.stream_verdict, stream_verdict);
		psock_set_prog(&psock->progs.stream_parser, stream_parser);
		sk_psock_start_strp(sk, psock);
	} else if (!stream_parser && stream_verdict && !psock->saved_data_ready) {
		psock_set_prog(&psock->progs.stream_verdict, stream_verdict);
		sk_psock_start_verdict(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_unlock_drop:
	write_unlock_bh(&sk->sk_callback_lock);
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out_put_stream_parser:
	if (stream_parser)
		bpf_prog_put(stream_parser);
out_put_stream_verdict:
	if (stream_verdict)
		bpf_prog_put(stream_verdict);
	return ret;
}

static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
{
	struct sk_psock *psock;
	int ret;

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock))
		return PTR_ERR(psock);

	if (!psock) {
		psock = sk_psock_init(sk, map->numa_node);
		if (IS_ERR(psock))
			return PTR_ERR(psock);
	}

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		sk_psock_put(sk, psock);
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk) {
			lock_sock(sk);
			rcu_read_lock();
			sock_map_unref(sk, psk);
			rcu_read_unlock();
			release_sock(sk);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk)
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}

static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	__sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

static bool sock_map_redirect_allowed(const struct sock *sk);

static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	/* Only sockets we can redirect into/from in BPF need to hold
	 * refs to parser/verdict progs and have their sk_data_ready
	 * and sk_write_space callbacks overridden.
	 */
	if (sock_map_redirect_allowed(sk))
		ret = sock_map_link(map, &stab->progs, sk);
	else
		ret = sock_map_link_no_progs(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
}

static bool sk_is_tcp(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

static bool sk_is_udp(const struct sock *sk)
{
	return sk->sk_type == SOCK_DGRAM &&
	       sk->sk_protocol == IPPROTO_UDP;
}

static bool sock_map_redirect_allowed(const struct sock *sk)
{
	return sk_is_tcp(sk) && sk->sk_state != TCP_LISTEN;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk_is_tcp(sk) || sk_is_udp(sk);
}

static bool sock_map_sk_state_allowed(const struct sock *sk)
{
	if (sk_is_tcp(sk))
		return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
	else if (sk_is_udp(sk))
		return sk_hashed(sk);

	return false;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags);

int sock_map_update_elem_sys(struct bpf_map *map, void *key, void *value,
			     u64 flags)
{
	struct socket *sock;
	struct sock *sk;
	int ret;
	u64 ufd;

	if (map->value_size == sizeof(u64))
		ufd = *(u64 *)value;
	else
		ufd = *(u32 *)value;
	if (ufd > S32_MAX)
		return -EINVAL;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
		ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	sockfd_put(sock);
	return ret;
}

static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	struct sock *sk = (struct sock *)value;
	int ret;

	if (unlikely(!sk || !sk_fullsock(sk)))
		return -EINVAL;

	if (!sock_map_sk_is_suitable(sk))
		return -EOPNOTSUPP;

	local_bh_disable();
	bh_lock_sock(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
		ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	bh_unlock_sock(sk);
	local_bh_enable();
	return ret;
}

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func		= bpf_sock_map_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS);
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func		= bpf_sk_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func		= bpf_msg_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

struct sock_map_seq_info {
	struct bpf_map *map;
	struct sock *sk;
	u32 index;
};

struct bpf_iter__sockmap {
	__bpf_md_ptr(struct bpf_iter_meta *, meta);
	__bpf_md_ptr(struct bpf_map *, map);
	__bpf_md_ptr(void *, key);
	__bpf_md_ptr(struct sock *, sk);
};

DEFINE_BPF_ITER_FUNC(sockmap, struct bpf_iter_meta *meta,
		     struct bpf_map *map, void *key,
		     struct sock *sk)

static void *sock_map_seq_lookup_elem(struct sock_map_seq_info *info)
{
	if (unlikely(info->index >= info->map->max_entries))
		return NULL;

	info->sk = __sock_map_lookup_elem(info->map, info->index);

	/* can't return sk directly, since that might be NULL */
	return info;
}

static void *sock_map_seq_start(struct seq_file *seq, loff_t *pos)
	__acquires(rcu)
{
	struct sock_map_seq_info *info = seq->private;

	if (*pos == 0)
		++*pos;

	/* pairs with sock_map_seq_stop */
	rcu_read_lock();
	return sock_map_seq_lookup_elem(info);
}

static void *sock_map_seq_next(struct seq_file *seq, void *v, loff_t *pos)
	__must_hold(rcu)
{
	struct sock_map_seq_info *info = seq->private;

	++*pos;
	++info->index;

	return sock_map_seq_lookup_elem(info);
}

static int sock_map_seq_show(struct seq_file *seq, void *v)
	__must_hold(rcu)
{
	struct sock_map_seq_info *info = seq->private;
	struct bpf_iter__sockmap ctx = {};
	struct bpf_iter_meta meta;
	struct bpf_prog *prog;

	meta.seq = seq;
	prog = bpf_iter_get_info(&meta, !v);
	if (!prog)
		return 0;

	ctx.meta = &meta;
	ctx.map = info->map;
	if (v) {
		ctx.key = &info->index;
		ctx.sk = info->sk;
	}

	return bpf_iter_run_prog(prog, &ctx);
}

static void sock_map_seq_stop(struct seq_file *seq, void *v)
	__releases(rcu)
{
	if (!v)
		(void)sock_map_seq_show(seq, NULL);

	/* pairs with sock_map_seq_start */
	rcu_read_unlock();
}

static const struct seq_operations sock_map_seq_ops = {
	.start	= sock_map_seq_start,
	.next	= sock_map_seq_next,
	.stop	= sock_map_seq_stop,
	.show	= sock_map_seq_show,
};

static int sock_map_init_seq_private(void *priv_data,
				     struct bpf_iter_aux_info *aux)
{
	struct sock_map_seq_info *info = priv_data;

	info->map = aux->map;
	return 0;
}

static const struct bpf_iter_seq_info sock_map_iter_seq_info = {
	.seq_ops		= &sock_map_seq_ops,
	.init_seq_private	= sock_map_init_seq_private,
	.seq_priv_size		= sizeof(struct sock_map_seq_info),
};

static int sock_map_btf_id;
const struct bpf_map_ops sock_map_ops = {
	.map_meta_equal		= bpf_map_meta_equal,
	.map_alloc		= sock_map_alloc,
	.map_free		= sock_map_free,
	.map_get_next_key	= sock_map_get_next_key,
	.map_lookup_elem_sys_only = sock_map_lookup_sys,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_map_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_map_release_progs,
	.map_check_btf		= map_check_no_btf,
	.map_btf_name		= "bpf_stab",
	.map_btf_id		= &sock_map_btf_id,
	.iter_seq_info		= &sock_map_iter_seq_info,
};

struct bpf_shtab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[];
};

struct bpf_shtab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_shtab {
	struct bpf_map map;
	struct bpf_shtab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_shtab_bucket *sock_hash_select_bucket(struct bpf_shtab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_shtab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_shtab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_shtab *htab,
				struct bpf_shtab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_elem *elem_probe, *elem = link_raw;
	struct bpf_shtab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_shtab_elem *sock_hash_alloc_elem(struct bpf_shtab *htab,
						   void *key, u32 key_size,
						   u32 hash, struct sock *sk,
						   struct bpf_shtab_elem *old)
{
	struct bpf_shtab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = bpf_map_kmalloc_node(&htab->map, htab->elem_size,
				   GFP_ATOMIC | __GFP_NOWARN,
				   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_shtab_elem *elem, *elem_new;
	struct bpf_shtab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	/* Only sockets we can redirect into/from in BPF need to hold
	 * refs to parser/verdict progs and have their sk_data_ready
	 * and sk_write_space callbacks overridden.
	 */
	if (sock_map_redirect_allowed(sk))
		ret = sock_map_link(map, &htab->progs, sk);
	else
		ret = sock_map_link_no_progs(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&elem->node)),
				     struct bpf_shtab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference(hlist_first_rcu(head)),
					     struct bpf_shtab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_shtab *htab;
	int i, err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size == 0 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER | __GFP_ACCOUNT);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_shtab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_shtab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_shtab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_bucket *bucket;
	struct hlist_head unlink_list;
	struct bpf_shtab_elem *elem;
	struct hlist_node *node;
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);

		/* We are racing with sock_hash_delete_from_link to
		 * enter the spin-lock critical section. Every socket on
		 * the list is still linked to sockhash. Since link
		 * exists, psock exists and holds a ref to socket. That
		 * lets us to grab a socket ref too.
		 */
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry(elem, &bucket->head, node)
			sock_hold(elem->sk);
		hlist_move_list(&bucket->head, &unlink_list);
		raw_spin_unlock_bh(&bucket->lock);

		/* Process removed entries out of atomic context to
		 * block for socket lock before deleting the psock's
		 * link to sockhash.
		 */
		hlist_for_each_entry_safe(elem, node, &unlink_list, node) {
			hlist_del(&elem->node);
			lock_sock(elem->sk);
			rcu_read_lock();
			sock_map_unref(elem->sk, elem);
			rcu_read_unlock();
			release_sock(elem->sk);
			sock_put(elem->sk);
			sock_hash_free_elem(htab, elem);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	__sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static void *sock_hash_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk)
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_shtab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func		= bpf_sock_hash_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS);
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func		= bpf_sk_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func		= bpf_msg_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

struct sock_hash_seq_info {
	struct bpf_map *map;
	struct bpf_shtab *htab;
	u32 bucket_id;
};

static void *sock_hash_seq_find_next(struct sock_hash_seq_info *info,
				     struct bpf_shtab_elem *prev_elem)
{
	const struct bpf_shtab *htab = info->htab;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;
	struct hlist_node *node;

	/* try to find next elem in the same bucket */
	if (prev_elem) {
		node = rcu_dereference(hlist_next_rcu(&prev_elem->node));
		elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
		if (elem)
			return elem;

		/* no more elements, continue in the next bucket */
		info->bucket_id++;
	}

	for (; info->bucket_id < htab->buckets_num; info->bucket_id++) {
		bucket = &htab->buckets[info->bucket_id];
		node = rcu_dereference(hlist_first_rcu(&bucket->head));
		elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
		if (elem)
			return elem;
	}

	return NULL;
}

static void *sock_hash_seq_start(struct seq_file *seq, loff_t *pos)
	__acquires(rcu)
{
	struct sock_hash_seq_info *info = seq->private;

	if (*pos == 0)
		++*pos;

	/* pairs with sock_hash_seq_stop */
	rcu_read_lock();
	return sock_hash_seq_find_next(info, NULL);
}

static void *sock_hash_seq_next(struct seq_file *seq, void *v, loff_t *pos)
	__must_hold(rcu)
{
	struct sock_hash_seq_info *info = seq->private;

	++*pos;
	return sock_hash_seq_find_next(info, v);
}

static int sock_hash_seq_show(struct seq_file *seq, void *v)
	__must_hold(rcu)
{
	struct sock_hash_seq_info *info = seq->private;
	struct bpf_iter__sockmap ctx = {};
	struct bpf_shtab_elem *elem = v;
	struct bpf_iter_meta meta;
	struct bpf_prog *prog;

	meta.seq = seq;
	prog = bpf_iter_get_info(&meta, !elem);
	if (!prog)
		return 0;

	ctx.meta = &meta;
	ctx.map = info->map;
	if (elem) {
		ctx.key = elem->key;
		ctx.sk = elem->sk;
	}

	return bpf_iter_run_prog(prog, &ctx);
}

static void sock_hash_seq_stop(struct seq_file *seq, void *v)
	__releases(rcu)
{
	if (!v)
		(void)sock_hash_seq_show(seq, NULL);

	/* pairs with sock_hash_seq_start */
	rcu_read_unlock();
}

static const struct seq_operations sock_hash_seq_ops = {
	.start	= sock_hash_seq_start,
	.next	= sock_hash_seq_next,
	.stop	= sock_hash_seq_stop,
	.show	= sock_hash_seq_show,
};

static int sock_hash_init_seq_private(void *priv_data,
				      struct bpf_iter_aux_info *aux)
{
	struct sock_hash_seq_info *info = priv_data;

	info->map = aux->map;
	info->htab = container_of(aux->map, struct bpf_shtab, map);
	return 0;
}

static const struct bpf_iter_seq_info sock_hash_iter_seq_info = {
	.seq_ops		= &sock_hash_seq_ops,
	.init_seq_private	= sock_hash_init_seq_private,
	.seq_priv_size		= sizeof(struct sock_hash_seq_info),
};

static int sock_hash_map_btf_id;
const struct bpf_map_ops sock_hash_ops = {
	.map_meta_equal		= bpf_map_meta_equal,
	.map_alloc		= sock_hash_alloc,
	.map_free		= sock_hash_free,
	.map_get_next_key	= sock_hash_get_next_key,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_hash_delete_elem,
	.map_lookup_elem	= sock_hash_lookup,
	.map_lookup_elem_sys_only = sock_hash_lookup_sys,
	.map_release_uref	= sock_hash_release_progs,
	.map_check_btf		= map_check_no_btf,
	.map_btf_name		= "bpf_shtab",
	.map_btf_id		= &sock_hash_map_btf_id,
	.iter_seq_info		= &sock_hash_iter_seq_info,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_shtab, map)->progs;
	default:
		break;
	}

	return NULL;
}

static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
				struct bpf_prog *old, u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);
	struct bpf_prog **pprog;

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		pprog = &progs->msg_parser;
		break;
#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
	case BPF_SK_SKB_STREAM_PARSER:
		pprog = &progs->stream_parser;
		break;
#endif
	case BPF_SK_SKB_STREAM_VERDICT:
		pprog = &progs->stream_verdict;
		break;
	default:
		return -EOPNOTSUPP;
	}

	if (old)
		return psock_replace_prog(pprog, prog, old);

	psock_set_prog(pprog, prog);
	return 0;
}

static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}

static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	while ((link = sk_psock_link_pop(psock))) {
		sock_map_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

void sock_map_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}

void sock_map_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}

static int sock_map_iter_attach_target(struct bpf_prog *prog,
				       union bpf_iter_link_info *linfo,
				       struct bpf_iter_aux_info *aux)
{
	struct bpf_map *map;
	int err = -EINVAL;

	if (!linfo->map.map_fd)
		return -EBADF;

	map = bpf_map_get_with_uref(linfo->map.map_fd);
	if (IS_ERR(map))
		return PTR_ERR(map);

	if (map->map_type != BPF_MAP_TYPE_SOCKMAP &&
	    map->map_type != BPF_MAP_TYPE_SOCKHASH)
		goto put_map;

	if (prog->aux->max_rdonly_access > map->key_size) {
		err = -EACCES;
		goto put_map;
	}

	aux->map = map;
	return 0;

put_map:
	bpf_map_put_with_uref(map);
	return err;
}

static void sock_map_iter_detach_target(struct bpf_iter_aux_info *aux)
{
	bpf_map_put_with_uref(aux->map);
}

static struct bpf_iter_reg sock_map_iter_reg = {
	.target			= "sockmap",
	.attach_target		= sock_map_iter_attach_target,
	.detach_target		= sock_map_iter_detach_target,
	.show_fdinfo		= bpf_iter_map_show_fdinfo,
	.fill_link_info		= bpf_iter_map_fill_link_info,
	.ctx_arg_info_size	= 2,
	.ctx_arg_info		= {
		{ offsetof(struct bpf_iter__sockmap, key),
		  PTR_TO_RDONLY_BUF_OR_NULL },
		{ offsetof(struct bpf_iter__sockmap, sk),
		  PTR_TO_BTF_ID_OR_NULL },
	},
};

static int __init bpf_sockmap_iter_init(void)
{
	sock_map_iter_reg.ctx_arg_info[1].btf_id =
		btf_sock_ids[BTF_SOCK_TYPE_SOCK];
	return bpf_iter_reg_target(&sock_map_iter_reg);
}
late_initcall(bpf_sockmap_iter_init);