// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK				\
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;
	u64 cost;
	int err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    != 4 ||
	    attr->value_size  != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	/* Make sure page count doesn't overflow. */
	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
	err = bpf_map_charge_init(&stab->map.memory, cost);
	if (err)
		goto free_stab;

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (stab->sks)
		return &stab->map;
	err = -ENOMEM;
	bpf_map_charge_finish(&stab->map.memory);
free_stab:
	kfree(stab);
	return ERR_PTR(err);
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, attr->attach_type);
	fdput(f);
	return ret;
}

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	struct sk_psock_link *link, *tmp;
	bool strp_stop = false;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		sk_psock_stop_strp(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}
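
/* sock_map_link() takes references on the map's msg/skb programs,
 * attaches them to the socket's psock (allocating the psock and
 * installing the TCP BPF callbacks on first use) and, when both skb
 * programs are present, starts the strparser on the socket.
 */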
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	bool skb_progs, sk_psock_is_new = false;
	struct sk_psock *psock;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	skb_parser = READ_ONCE(progs->skb_parser);
	skb_progs = skb_parser && skb_verdict;
	if (skb_progs) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			bpf_prog_put(skb_verdict);
			return PTR_ERR(skb_parser);
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out;
		}
	}

	psock = sk_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_progs  && READ_ONCE(psock->progs.skb_parser))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock) {
			ret = -ENOMEM;
			goto out_progs;
		}
		sk_psock_is_new = true;
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);
	if (sk_psock_is_new) {
		ret = tcp_bpf_init(sk);
		if (ret < 0)
			goto out_drop;
	} else {
		tcp_bpf_reinit(sk);
	}

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_progs && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			goto out_drop;
		}
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out:
	if (skb_progs) {
		bpf_prog_put(skb_verdict);
		bpf_prog_put(skb_parser);
	}
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	synchronize_rcu();
	raw_spin_lock_bh(&stab->lock);
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk) {
			lock_sock(sk);
			rcu_read_lock();
			sock_map_unref(sk, psk);
			rcu_read_unlock();
			release_sock(sk);
		}
	}
	raw_spin_unlock_bh(&stab->lock);

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	return ERR_PTR(-EOPNOTSUPP);
}
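
/* Common delete path for sockmap, shared by the syscall-side
 * sock_map_delete_elem() and the psock link teardown in
 * sock_map_delete_from_link(). When sk_test is non-NULL the slot is
 * only cleared if it still holds that particular socket.
 */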
static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;
	if (unlikely(rcu_access_pointer(icsk->icsk_ulp_data)))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &stab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}
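
/* Syscall-side update: the value is the file descriptor of a TCP
 * stream socket, which must be in ESTABLISHED state. The socket is
 * resolved and locked before the common update path runs.
 */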
static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	u32 idx = *(u32 *)key;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (sk->sk_state != TCP_ESTABLISHED)
		ret = -EOPNOTSUPP;
	else
		ret = sock_map_update_common(map, idx, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func		= bpf_sock_map_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func		= bpf_sk_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_map_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func		= bpf_msg_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

const struct bpf_map_ops sock_map_ops = {
	.map_alloc		= sock_map_alloc,
	.map_free		= sock_map_free,
	.map_get_next_key	= sock_map_get_next_key,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_map_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_map_release_progs,
	.map_check_btf		= map_check_no_btf,
};

struct bpf_htab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[0];
};

struct bpf_htab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_htab {
	struct bpf_map map;
	struct bpf_htab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}
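
/* Walk the RCU-protected chain of one bucket looking for an exact
 * hash and key match; callers hold either the RCU read lock or the
 * bucket lock.
 */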
static struct bpf_htab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_htab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_htab *htab,
				struct bpf_htab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem_probe, *elem = link_raw;
	struct bpf_htab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
						  void *key, u32 key_size,
						  u32 hash, struct sock *sk,
						  struct bpf_htab_elem *old)
{
	struct bpf_htab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
			   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}
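
/* Same flow as sock_map_update_common() but keyed by an arbitrary
 * key and serialized per hash bucket instead of with the single
 * map-wide lock used by sockmap.
 */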
static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_elem *elem, *elem_new;
	struct bpf_htab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(icsk->icsk_ulp_data))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &htab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_update_elem(struct bpf_map *map, void *key,
				 void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (sk->sk_state != TCP_ESTABLISHED)
		ret = -EOPNOTSUPP;
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
				     struct bpf_htab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
					     struct bpf_htab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}
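
/* Map creation: buckets_num is rounded up to a power of two so that
 * sock_hash_select_bucket() can mask the hash value instead of
 * taking a modulo.
 */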
static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_htab *htab;
	int i, err;
	u64 cost;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    == 0 ||
	    attr->value_size  != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_htab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
	       (u64) htab->elem_size * htab->map.max_entries;
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_htab;
	}

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_htab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	struct hlist_node *node;
	int i;

	synchronize_rcu();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry_safe(elem, node, &bucket->head, node) {
			hlist_del_rcu(&elem->node);
			lock_sock(elem->sk);
			rcu_read_lock();
			sock_map_unref(elem->sk, elem);
			rcu_read_unlock();
			release_sock(elem->sk);
		}
		raw_spin_unlock_bh(&bucket->lock);
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
}
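
/* BPF-side helpers for sockhash. They mirror the sockmap helpers
 * above but take a pointer to an arbitrary-sized key instead of a
 * u32 index; e.g. a verdict program would return something like
 * (map and key names are illustrative):
 *
 *	return bpf_sk_redirect_hash(skb, &sock_hash, &key, BPF_F_INGRESS);
 */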
BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func		= bpf_sock_hash_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func		= bpf_sk_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_hash_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func		= bpf_msg_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

const struct bpf_map_ops sock_hash_ops = {
	.map_alloc		= sock_hash_alloc,
	.map_free		= sock_hash_free,
	.map_get_next_key	= sock_hash_get_next_key,
	.map_update_elem	= sock_hash_update_elem,
	.map_delete_elem	= sock_hash_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_hash_release_progs,
	.map_check_btf		= map_check_no_btf,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_htab, map)->progs;
	default:
		break;
	}

	return NULL;
}

int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		psock_set_prog(&progs->msg_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		psock_set_prog(&progs->skb_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		psock_set_prog(&progs->skb_verdict, prog);
		break;
	default:
		return -EOPNOTSUPP;
	}

	return 0;
}

void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}