// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK				\
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;
	u64 cost;
	int err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size != 4 ||
	    attr->value_size != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	/* Make sure page count doesn't overflow. */
	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
	err = bpf_map_charge_init(&stab->map.memory, cost);
	if (err)
		goto free_stab;

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (stab->sks)
		return &stab->map;
	err = -ENOMEM;
	bpf_map_charge_finish(&stab->map.memory);
free_stab:
	kfree(stab);
	return ERR_PTR(err);
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, attr->attach_type);
	fdput(f);
	return ret;
}

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	struct sk_psock_link *link, *tmp;
	bool strp_stop = false;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		sk_psock_stop_strp(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}
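
/* Attach the map's currently configured programs to sk's psock before the
 * socket is inserted into the map. Program references are taken up front
 * with bpf_prog_inc_not_zero() so a concurrent detach cannot free them
 * while they are being installed; the psock itself is created on first
 * use and the TCP proto ops / strparser are (re)initialized as needed.
 */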
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	bool skb_progs, sk_psock_is_new = false;
	struct sk_psock *psock;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	skb_parser = READ_ONCE(progs->skb_parser);
	skb_progs = skb_parser && skb_verdict;
	if (skb_progs) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			bpf_prog_put(skb_verdict);
			return PTR_ERR(skb_parser);
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out;
		}
	}

	psock = sk_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_progs && READ_ONCE(psock->progs.skb_parser))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock) {
			ret = -ENOMEM;
			goto out_progs;
		}
		sk_psock_is_new = true;
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);
	if (sk_psock_is_new) {
		ret = tcp_bpf_init(sk);
		if (ret < 0)
			goto out_drop;
	} else {
		tcp_bpf_reinit(sk);
	}

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_progs && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			goto out_drop;
		}
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out:
	if (skb_progs) {
		bpf_prog_put(skb_verdict);
		bpf_prog_put(skb_parser);
	}
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
239 */ 240 synchronize_rcu(); 241 for (i = 0; i < stab->map.max_entries; i++) { 242 struct sock **psk = &stab->sks[i]; 243 struct sock *sk; 244 245 sk = xchg(psk, NULL); 246 if (sk) { 247 lock_sock(sk); 248 rcu_read_lock(); 249 sock_map_unref(sk, psk); 250 rcu_read_unlock(); 251 release_sock(sk); 252 } 253 } 254 255 /* wait for psock readers accessing its map link */ 256 synchronize_rcu(); 257 258 bpf_map_area_free(stab->sks); 259 kfree(stab); 260 } 261 262 static void sock_map_release_progs(struct bpf_map *map) 263 { 264 psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs); 265 } 266 267 static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key) 268 { 269 struct bpf_stab *stab = container_of(map, struct bpf_stab, map); 270 271 WARN_ON_ONCE(!rcu_read_lock_held()); 272 273 if (unlikely(key >= map->max_entries)) 274 return NULL; 275 return READ_ONCE(stab->sks[key]); 276 } 277 278 static void *sock_map_lookup(struct bpf_map *map, void *key) 279 { 280 return ERR_PTR(-EOPNOTSUPP); 281 } 282 283 static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test, 284 struct sock **psk) 285 { 286 struct sock *sk; 287 int err = 0; 288 289 raw_spin_lock_bh(&stab->lock); 290 sk = *psk; 291 if (!sk_test || sk_test == sk) 292 sk = xchg(psk, NULL); 293 294 if (likely(sk)) 295 sock_map_unref(sk, psk); 296 else 297 err = -EINVAL; 298 299 raw_spin_unlock_bh(&stab->lock); 300 return err; 301 } 302 303 static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk, 304 void *link_raw) 305 { 306 struct bpf_stab *stab = container_of(map, struct bpf_stab, map); 307 308 __sock_map_delete(stab, sk, link_raw); 309 } 310 311 static int sock_map_delete_elem(struct bpf_map *map, void *key) 312 { 313 struct bpf_stab *stab = container_of(map, struct bpf_stab, map); 314 u32 i = *(u32 *)key; 315 struct sock **psk; 316 317 if (unlikely(i >= map->max_entries)) 318 return -EINVAL; 319 320 psk = &stab->sks[i]; 321 return __sock_map_delete(stab, NULL, psk); 322 } 323 324 static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next) 325 { 326 struct bpf_stab *stab = container_of(map, struct bpf_stab, map); 327 u32 i = key ? 
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;
	if (unlikely(rcu_access_pointer(icsk->icsk_ulp_data)))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &stab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	u32 idx = *(u32 *)key;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (sk->sk_state != TCP_ESTABLISHED)
		ret = -EOPNOTSUPP;
	else
		ret = sock_map_update_common(map, idx, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func = bpf_sock_map_update,
	.gpl_only = false,
	.pkt_access = true,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func = bpf_sk_redirect_map,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_ANYTHING,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_map_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func = bpf_msg_redirect_map,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_ANYTHING,
	.arg4_type = ARG_ANYTHING,
};

const struct bpf_map_ops sock_map_ops = {
	.map_alloc = sock_map_alloc,
	.map_free = sock_map_free,
	.map_get_next_key = sock_map_get_next_key,
	.map_update_elem = sock_map_update_elem,
	.map_delete_elem = sock_map_delete_elem,
	.map_lookup_elem = sock_map_lookup,
	.map_release_uref = sock_map_release_progs,
	.map_check_btf = map_check_no_btf,
};

struct bpf_htab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[0];
};

struct bpf_htab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_htab {
	struct bpf_map map;
	struct bpf_htab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_htab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_htab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_htab *htab,
				struct bpf_htab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem_probe, *elem = link_raw;
	struct bpf_htab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
						  void *key, u32 key_size,
						  u32 hash, struct sock *sk,
						  struct bpf_htab_elem *old)
{
	struct bpf_htab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
			   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_elem *elem, *elem_new;
	struct bpf_htab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(icsk->icsk_ulp_data))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &htab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_update_elem(struct bpf_map *map, void *key,
				 void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (sk->sk_state != TCP_ESTABLISHED)
		ret = -EOPNOTSUPP;
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
				     struct bpf_htab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
					     struct bpf_htab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_htab *htab;
	int i, err;
	u64 cost;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size == 0 ||
	    attr->value_size != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_htab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
	       (u64) htab->elem_size * htab->map.max_entries;
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_htab;
	}

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_htab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	struct hlist_node *node;
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);
		hlist_for_each_entry_safe(elem, node, &bucket->head, node) {
			hlist_del_rcu(&elem->node);
			lock_sock(elem->sk);
			rcu_read_lock();
			sock_map_unref(elem->sk, elem);
			rcu_read_unlock();
			release_sock(elem->sk);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func = bpf_sock_hash_update,
	.gpl_only = false,
	.pkt_access = true,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func = bpf_sk_redirect_hash,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_hash_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func = bpf_msg_redirect_hash,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

const struct bpf_map_ops sock_hash_ops = {
	.map_alloc = sock_hash_alloc,
	.map_free = sock_hash_free,
	.map_get_next_key = sock_hash_get_next_key,
	.map_update_elem = sock_hash_update_elem,
	.map_delete_elem = sock_hash_delete_elem,
	.map_lookup_elem = sock_map_lookup,
	.map_release_uref = sock_hash_release_progs,
	.map_check_btf = map_check_no_btf,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_htab, map)->progs;
	default:
		break;
	}

	return NULL;
}

int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		psock_set_prog(&progs->msg_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		psock_set_prog(&progs->skb_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		psock_set_prog(&progs->skb_verdict, prog);
		break;
	default:
		return -EOPNOTSUPP;
	}

	return 0;
}

void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}
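
/* Usage sketch (illustrative only, not part of this file): a minimal BPF
 * sk_skb stream verdict program that redirects every parsed record to the
 * socket stored at key 0 of a SOCKMAP. Map and program names, the section
 * names, and the key layout below are arbitrary example choices, assuming
 * a libbpf-style loader. Sockets are inserted from user space by passing a
 * TCP socket fd as the value to bpf_map_update_elem() (handled by
 * sock_map_update_elem() above), and the verdict program is attached to
 * the map with BPF_PROG_ATTACH / BPF_SK_SKB_STREAM_VERDICT, which ends up
 * in sock_map_prog_update().
 *
 *	#include <linux/bpf.h>
 *	#include <bpf/bpf_helpers.h>
 *
 *	struct bpf_map_def SEC("maps") sock_map = {
 *		.type        = BPF_MAP_TYPE_SOCKMAP,
 *		.key_size    = sizeof(__u32),
 *		.value_size  = sizeof(__u32),
 *		.max_entries = 2,
 *	};
 *
 *	SEC("sk_skb/stream_verdict")
 *	int prog_stream_verdict(struct __sk_buff *skb)
 *	{
 *		return bpf_sk_redirect_map(skb, &sock_map, 0, 0);
 *	}
 */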