// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/btf_ids.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>
#include <linux/sock_diag.h>
#include <net/udp.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK				\
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
				struct bpf_prog *old, u32 which);
static struct sk_psock_progs *sock_map_progs(struct bpf_map *map);

static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size != 4 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER | __GFP_ACCOUNT);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	stab->sks = bpf_map_area_alloc((u64) stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (!stab->sks) {
		kfree(stab);
		return ERR_PTR(-ENOMEM);
	}

	return &stab->map;
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	if (attr->attach_flags || attr->replace_bpf_fd)
		return -EINVAL;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, NULL, attr->attach_type);
	fdput(f);
	return ret;
}

int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype)
{
	u32 ufd = attr->target_fd;
	struct bpf_prog *prog;
	struct bpf_map *map;
	struct fd f;
	int ret;

	if (attr->attach_flags || attr->replace_bpf_fd)
		return -EINVAL;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);

	prog = bpf_prog_get(attr->attach_bpf_fd);
	if (IS_ERR(prog)) {
		ret = PTR_ERR(prog);
		goto put_map;
	}

	if (prog->type != ptype) {
		ret = -EINVAL;
		goto put_prog;
	}

	ret = sock_map_prog_update(map, NULL, prog, attr->attach_type);
put_prog:
	bpf_prog_put(prog);
put_map:
	fdput(f);
	return ret;
}

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
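	/* Drop every psock link that points at this map slot. If the map
	 * being unlinked supplied the strparser or verdict program, the
	 * socket's saved data_ready callback is restored below.
	 */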
	bool strp_stop = false, verdict_stop = false;
	struct sk_psock_link *link, *tmp;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->saved_data_ready && stab->progs.stream_parser)
				strp_stop = true;
			if (psock->saved_data_ready && stab->progs.stream_verdict)
				verdict_stop = true;
			if (psock->saved_data_ready && stab->progs.skb_verdict)
				verdict_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop || verdict_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		if (strp_stop)
			sk_psock_stop_strp(sk, psock);
		else
			sk_psock_stop_verdict(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}

static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
{
	if (!sk->sk_prot->psock_update_sk_prot)
		return -EINVAL;
	psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
	return sk->sk_prot->psock_update_sk_prot(sk, psock, false);
}

static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock) {
		if (sk->sk_prot->close != sock_map_close) {
			psock = ERR_PTR(-EBUSY);
			goto out;
		}

		if (!refcount_inc_not_zero(&psock->refcnt))
			psock = ERR_PTR(-EBUSY);
	}
out:
	rcu_read_unlock();
	return psock;
}

static int sock_map_link(struct bpf_map *map, struct sock *sk)
{
	struct sk_psock_progs *progs = sock_map_progs(map);
	struct bpf_prog *stream_verdict = NULL;
	struct bpf_prog *stream_parser = NULL;
	struct bpf_prog *skb_verdict = NULL;
	struct bpf_prog *msg_parser = NULL;
	struct sk_psock *psock;
	int ret;

	stream_verdict = READ_ONCE(progs->stream_verdict);
	if (stream_verdict) {
		stream_verdict = bpf_prog_inc_not_zero(stream_verdict);
		if (IS_ERR(stream_verdict))
			return PTR_ERR(stream_verdict);
	}

	stream_parser = READ_ONCE(progs->stream_parser);
	if (stream_parser) {
		stream_parser = bpf_prog_inc_not_zero(stream_parser);
		if (IS_ERR(stream_parser)) {
			ret = PTR_ERR(stream_parser);
			goto out_put_stream_verdict;
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out_put_stream_parser;
		}
	}

	skb_verdict = READ_ONCE(progs->skb_verdict);
	if (skb_verdict) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict)) {
			ret = PTR_ERR(skb_verdict);
			goto out_put_msg_parser;
		}
	}

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (stream_parser && READ_ONCE(psock->progs.stream_parser)) ||
		    (skb_verdict && READ_ONCE(psock->progs.skb_verdict)) ||
		    (skb_verdict && READ_ONCE(psock->progs.stream_verdict)) ||
		    (stream_verdict && READ_ONCE(psock->progs.skb_verdict)) ||
		    (stream_verdict && READ_ONCE(psock->progs.stream_verdict))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (IS_ERR(psock)) {
			ret = PTR_ERR(psock);
			goto out_progs;
		}
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		goto out_drop;

	write_lock_bh(&sk->sk_callback_lock);
	if (stream_parser && stream_verdict && !psock->saved_data_ready) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret)
			goto out_unlock_drop;
		psock_set_prog(&psock->progs.stream_verdict, stream_verdict);
		psock_set_prog(&psock->progs.stream_parser, stream_parser);
		sk_psock_start_strp(sk, psock);
	} else if (!stream_parser && stream_verdict && !psock->saved_data_ready) {
		psock_set_prog(&psock->progs.stream_verdict, stream_verdict);
		sk_psock_start_verdict(sk, psock);
	} else if (!stream_verdict && skb_verdict && !psock->saved_data_ready) {
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		sk_psock_start_verdict(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_unlock_drop:
	write_unlock_bh(&sk->sk_callback_lock);
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (skb_verdict)
		bpf_prog_put(skb_verdict);
out_put_msg_parser:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out_put_stream_parser:
	if (stream_parser)
		bpf_prog_put(stream_parser);
out_put_stream_verdict:
	if (stream_verdict)
		bpf_prog_put(stream_verdict);
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk) {
			lock_sock(sk);
			rcu_read_lock();
			sock_map_unref(sk, psk);
			rcu_read_unlock();
			release_sock(sk);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk)
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}

static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	__sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
}

static bool sk_is_tcp(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

static bool sock_map_redirect_allowed(const struct sock *sk)
{
	if (sk_is_tcp(sk))
		return sk->sk_state != TCP_LISTEN;
	else
		return sk->sk_state == TCP_ESTABLISHED;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return !!sk->sk_prot->psock_update_sk_prot;
}

static bool sock_map_sk_state_allowed(const struct sock *sk)
{
	if (sk_is_tcp(sk))
		return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
	return true;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags);

int sock_map_update_elem_sys(struct bpf_map *map, void *key, void *value,
			     u64 flags)
{
	struct socket *sock;
	struct sock *sk;
	int ret;
	u64 ufd;

	if (map->value_size == sizeof(u64))
		ufd = *(u64 *)value;
	else
		ufd = *(u32 *)value;
	if (ufd > S32_MAX)
		return -EINVAL;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
		ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	sockfd_put(sock);
	return ret;
}

static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	struct sock *sk = (struct sock *)value;
	int ret;

	if (unlikely(!sk || !sk_fullsock(sk)))
		return -EINVAL;

	if (!sock_map_sk_is_suitable(sk))
		return -EOPNOTSUPP;

	local_bh_disable();
	bh_lock_sock(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
		ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	bh_unlock_sock(sk);
	local_bh_enable();
	return ret;
}

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func = bpf_sock_map_update,
	.gpl_only = false,
	.pkt_access = true,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS);
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func = bpf_sk_redirect_map,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_ANYTHING,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func = bpf_msg_redirect_map,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_ANYTHING,
	.arg4_type = ARG_ANYTHING,
};

struct sock_map_seq_info {
	struct bpf_map *map;
	struct sock *sk;
	u32 index;
};

struct bpf_iter__sockmap {
	__bpf_md_ptr(struct bpf_iter_meta *, meta);
	__bpf_md_ptr(struct bpf_map *, map);
	__bpf_md_ptr(void *, key);
	__bpf_md_ptr(struct sock *, sk);
};

DEFINE_BPF_ITER_FUNC(sockmap, struct bpf_iter_meta *meta,
		     struct bpf_map *map, void *key,
		     struct sock *sk)

static void *sock_map_seq_lookup_elem(struct sock_map_seq_info *info)
{
	if (unlikely(info->index >= info->map->max_entries))
		return NULL;

	info->sk = __sock_map_lookup_elem(info->map, info->index);

	/* can't return sk directly, since that might be NULL */
	return info;
}

static void *sock_map_seq_start(struct seq_file *seq, loff_t *pos)
	__acquires(rcu)
{
	struct sock_map_seq_info *info = seq->private;

	if (*pos == 0)
		++*pos;

	/* pairs with sock_map_seq_stop */
	rcu_read_lock();
	return sock_map_seq_lookup_elem(info);
}

static void *sock_map_seq_next(struct seq_file *seq, void *v, loff_t *pos)
	__must_hold(rcu)
{
	struct sock_map_seq_info *info = seq->private;

	++*pos;
	++info->index;

	return sock_map_seq_lookup_elem(info);
}

static int sock_map_seq_show(struct seq_file *seq, void *v)
	__must_hold(rcu)
{
	struct sock_map_seq_info *info = seq->private;
	struct bpf_iter__sockmap ctx = {};
	struct bpf_iter_meta meta;
	struct bpf_prog *prog;

	meta.seq = seq;
	prog = bpf_iter_get_info(&meta, !v);
	if (!prog)
		return 0;

	ctx.meta = &meta;
	ctx.map = info->map;
	if (v) {
		ctx.key = &info->index;
		ctx.sk = info->sk;
	}

	return bpf_iter_run_prog(prog, &ctx);
}

static void sock_map_seq_stop(struct seq_file *seq, void *v)
	__releases(rcu)
{
	if (!v)
		(void)sock_map_seq_show(seq, NULL);

	/* pairs with sock_map_seq_start */
	rcu_read_unlock();
}

static const struct seq_operations sock_map_seq_ops = {
	.start = sock_map_seq_start,
	.next = sock_map_seq_next,
	.stop = sock_map_seq_stop,
	.show = sock_map_seq_show,
};

static int sock_map_init_seq_private(void *priv_data,
				     struct bpf_iter_aux_info *aux)
{
	struct sock_map_seq_info *info = priv_data;

	info->map = aux->map;
	return 0;
}

static const struct bpf_iter_seq_info sock_map_iter_seq_info = {
	.seq_ops = &sock_map_seq_ops,
	.init_seq_private = sock_map_init_seq_private,
	.seq_priv_size = sizeof(struct sock_map_seq_info),
};

static int sock_map_btf_id;
const struct bpf_map_ops sock_map_ops = {
	.map_meta_equal = bpf_map_meta_equal,
	.map_alloc = sock_map_alloc,
	.map_free = sock_map_free,
	.map_get_next_key = sock_map_get_next_key,
	.map_lookup_elem_sys_only = sock_map_lookup_sys,
	.map_update_elem = sock_map_update_elem,
	.map_delete_elem = sock_map_delete_elem,
	.map_lookup_elem = sock_map_lookup,
	.map_release_uref = sock_map_release_progs,
	.map_check_btf = map_check_no_btf,
	.map_btf_name = "bpf_stab",
	.map_btf_id = &sock_map_btf_id,
	.iter_seq_info = &sock_map_iter_seq_info,
};

struct bpf_shtab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[];
};

struct bpf_shtab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_shtab {
	struct bpf_map map;
	struct bpf_shtab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_shtab_bucket *sock_hash_select_bucket(struct bpf_shtab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_shtab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_shtab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_shtab *htab,
				struct bpf_shtab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_elem *elem_probe, *elem = link_raw;
	struct bpf_shtab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_shtab_elem *sock_hash_alloc_elem(struct bpf_shtab *htab,
						   void *key, u32 key_size,
						   u32 hash, struct sock *sk,
						   struct bpf_shtab_elem *old)
{
	struct bpf_shtab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = bpf_map_kmalloc_node(&htab->map, htab->elem_size,
				   GFP_ATOMIC | __GFP_NOWARN,
				   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_shtab_elem *elem, *elem_new;
	struct bpf_shtab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
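	/* The bucket lock serializes this update against concurrent
	 * updates and deletes targeting the same hash bucket.
	 */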
	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&elem->node)),
				     struct bpf_shtab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference(hlist_first_rcu(head)),
					     struct bpf_shtab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_shtab *htab;
	int i, err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size == 0 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER | __GFP_ACCOUNT);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_shtab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_shtab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_shtab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_bucket *bucket;
	struct hlist_head unlink_list;
	struct bpf_shtab_elem *elem;
	struct hlist_node *node;
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);

		/* We are racing with sock_hash_delete_from_link to
		 * enter the spin-lock critical section. Every socket on
		 * the list is still linked to sockhash. Since link
		 * exists, psock exists and holds a ref to socket. That
		 * lets us grab a socket ref too.
		 */
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry(elem, &bucket->head, node)
			sock_hold(elem->sk);
		hlist_move_list(&bucket->head, &unlink_list);
		raw_spin_unlock_bh(&bucket->lock);

		/* Process removed entries out of atomic context to
		 * block for socket lock before deleting the psock's
		 * link to sockhash.
		 */
		hlist_for_each_entry_safe(elem, node, &unlink_list, node) {
			hlist_del(&elem->node);
			lock_sock(elem->sk);
			rcu_read_lock();
			sock_map_unref(elem->sk, elem);
			rcu_read_unlock();
			release_sock(elem->sk);
			sock_put(elem->sk);
			sock_hash_free_elem(htab, elem);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	__sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static void *sock_hash_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk)
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_shtab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func = bpf_sock_hash_update,
	.gpl_only = false,
	.pkt_access = true,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS);
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func = bpf_sk_redirect_hash,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func = bpf_msg_redirect_hash,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

struct sock_hash_seq_info {
	struct bpf_map *map;
	struct bpf_shtab *htab;
	u32 bucket_id;
};

static void *sock_hash_seq_find_next(struct sock_hash_seq_info *info,
				     struct bpf_shtab_elem *prev_elem)
{
	const struct bpf_shtab *htab = info->htab;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;
	struct hlist_node *node;

	/* try to find next elem in the same bucket */
	if (prev_elem) {
		node = rcu_dereference(hlist_next_rcu(&prev_elem->node));
		elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
		if (elem)
			return elem;

		/* no more elements, continue in the next bucket */
		info->bucket_id++;
	}

	for (; info->bucket_id < htab->buckets_num; info->bucket_id++) {
		bucket = &htab->buckets[info->bucket_id];
		node = rcu_dereference(hlist_first_rcu(&bucket->head));
		elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
		if (elem)
			return elem;
	}

	return NULL;
}

static void *sock_hash_seq_start(struct seq_file *seq, loff_t *pos)
	__acquires(rcu)
{
	struct sock_hash_seq_info *info = seq->private;

	if (*pos == 0)
		++*pos;

	/* pairs with sock_hash_seq_stop */
	rcu_read_lock();
	return sock_hash_seq_find_next(info, NULL);
}

static void *sock_hash_seq_next(struct seq_file *seq, void *v, loff_t *pos)
	__must_hold(rcu)
{
	struct sock_hash_seq_info *info = seq->private;

	++*pos;
	return sock_hash_seq_find_next(info, v);
}

static int sock_hash_seq_show(struct seq_file *seq, void *v)
	__must_hold(rcu)
{
	struct sock_hash_seq_info *info = seq->private;
	struct bpf_iter__sockmap ctx = {};
	struct bpf_shtab_elem *elem = v;
	struct bpf_iter_meta meta;
	struct bpf_prog *prog;

	meta.seq = seq;
	prog = bpf_iter_get_info(&meta, !elem);
	if (!prog)
		return 0;

	ctx.meta = &meta;
	ctx.map = info->map;
	if (elem) {
		ctx.key = elem->key;
		ctx.sk = elem->sk;
	}

	return bpf_iter_run_prog(prog, &ctx);
}

static void sock_hash_seq_stop(struct seq_file *seq, void *v)
	__releases(rcu)
{
	if (!v)
		(void)sock_hash_seq_show(seq, NULL);

	/* pairs with sock_hash_seq_start */
	rcu_read_unlock();
}

static const struct seq_operations sock_hash_seq_ops = {
	.start = sock_hash_seq_start,
	.next = sock_hash_seq_next,
	.stop = sock_hash_seq_stop,
	.show = sock_hash_seq_show,
};

static int sock_hash_init_seq_private(void *priv_data,
				      struct bpf_iter_aux_info *aux)
{
	struct sock_hash_seq_info *info = priv_data;

	info->map = aux->map;
	info->htab = container_of(aux->map, struct bpf_shtab, map);
	return 0;
}

static const struct bpf_iter_seq_info sock_hash_iter_seq_info = {
	.seq_ops = &sock_hash_seq_ops,
	.init_seq_private = sock_hash_init_seq_private,
	.seq_priv_size = sizeof(struct sock_hash_seq_info),
};

static int sock_hash_map_btf_id;
const struct bpf_map_ops sock_hash_ops = {
	.map_meta_equal = bpf_map_meta_equal,
	.map_alloc = sock_hash_alloc,
	.map_free = sock_hash_free,
	.map_get_next_key = sock_hash_get_next_key,
	.map_update_elem = sock_map_update_elem,
	.map_delete_elem = sock_hash_delete_elem,
	.map_lookup_elem = sock_hash_lookup,
	.map_lookup_elem_sys_only = sock_hash_lookup_sys,
	.map_release_uref = sock_hash_release_progs,
	.map_check_btf = map_check_no_btf,
	.map_btf_name = "bpf_shtab",
	.map_btf_id = &sock_hash_map_btf_id,
	.iter_seq_info = &sock_hash_iter_seq_info,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_shtab, map)->progs;
	default:
		break;
	}

	return NULL;
}

static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
				struct bpf_prog *old, u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);
	struct bpf_prog **pprog;

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		pprog = &progs->msg_parser;
		break;
#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
	case BPF_SK_SKB_STREAM_PARSER:
		pprog = &progs->stream_parser;
		break;
#endif
	case BPF_SK_SKB_STREAM_VERDICT:
		if (progs->skb_verdict)
			return -EBUSY;
		pprog = &progs->stream_verdict;
		break;
	case BPF_SK_SKB_VERDICT:
		if (progs->stream_verdict)
			return -EBUSY;
		pprog = &progs->skb_verdict;
		break;
	default:
		return -EOPNOTSUPP;
	}

	if (old)
		return psock_replace_prog(pprog, prog, old);

	psock_set_prog(pprog, prog);
	return 0;
}

static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}

static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	while ((link = sk_psock_link_pop(psock))) {
		sock_map_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

void sock_map_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

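	/* Look up the psock under RCU; without one, fall back to the
	 * protocol's original unhash. Otherwise drop all map links
	 * before calling the saved unhash callback.
	 */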
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}
EXPORT_SYMBOL_GPL(sock_map_unhash);

void sock_map_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock_get(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	sk_psock_stop(psock, true);
	sk_psock_put(sk, psock);
	release_sock(sk);
	saved_close(sk, timeout);
}
EXPORT_SYMBOL_GPL(sock_map_close);

static int sock_map_iter_attach_target(struct bpf_prog *prog,
				       union bpf_iter_link_info *linfo,
				       struct bpf_iter_aux_info *aux)
{
	struct bpf_map *map;
	int err = -EINVAL;

	if (!linfo->map.map_fd)
		return -EBADF;

	map = bpf_map_get_with_uref(linfo->map.map_fd);
	if (IS_ERR(map))
		return PTR_ERR(map);

	if (map->map_type != BPF_MAP_TYPE_SOCKMAP &&
	    map->map_type != BPF_MAP_TYPE_SOCKHASH)
		goto put_map;

	if (prog->aux->max_rdonly_access > map->key_size) {
		err = -EACCES;
		goto put_map;
	}

	aux->map = map;
	return 0;

put_map:
	bpf_map_put_with_uref(map);
	return err;
}

static void sock_map_iter_detach_target(struct bpf_iter_aux_info *aux)
{
	bpf_map_put_with_uref(aux->map);
}

static struct bpf_iter_reg sock_map_iter_reg = {
	.target = "sockmap",
	.attach_target = sock_map_iter_attach_target,
	.detach_target = sock_map_iter_detach_target,
	.show_fdinfo = bpf_iter_map_show_fdinfo,
	.fill_link_info = bpf_iter_map_fill_link_info,
	.ctx_arg_info_size = 2,
	.ctx_arg_info = {
		{ offsetof(struct bpf_iter__sockmap, key),
		  PTR_TO_RDONLY_BUF_OR_NULL },
		{ offsetof(struct bpf_iter__sockmap, sk),
		  PTR_TO_BTF_ID_OR_NULL },
	},
};

static int __init bpf_sockmap_iter_init(void)
{
	sock_map_iter_reg.ctx_arg_info[1].btf_id =
		btf_sock_ids[BTF_SOCK_TYPE_SOCK];
	return bpf_iter_reg_target(&sock_map_iter_reg);
}
late_initcall(bpf_sockmap_iter_init);