// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK				\
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;
	u64 cost;
	int err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    != 4 ||
	    attr->value_size  != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	/* Make sure page count doesn't overflow. */
	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
	err = bpf_map_charge_init(&stab->map.memory, cost);
	if (err)
		goto free_stab;

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (stab->sks)
		return &stab->map;
	err = -ENOMEM;
	bpf_map_charge_finish(&stab->map.memory);
free_stab:
	kfree(stab);
	return ERR_PTR(err);
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, attr->attach_type);
	fdput(f);
	return ret;
}

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	struct sk_psock_link *link, *tmp;
	bool strp_stop = false;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		sk_psock_stop_strp(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}

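/* Attach the map's BPF programs to @sk. The skb parser/verdict pair is only
 * taken when both programs are present; the msg parser is taken if set. The
 * socket's psock is looked up or, if absent, newly allocated. On a freshly
 * allocated psock the TCP BPF proto callbacks are installed (tcp_bpf_init),
 * otherwise they are refreshed (tcp_bpf_reinit), and the strparser is
 * initialized and started under sk_callback_lock when skb programs are
 * attached for the first time. Returns -EBUSY if a conflicting program is
 * already attached to an existing psock.
 */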
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	bool skb_progs, sk_psock_is_new = false;
	struct sk_psock *psock;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	skb_parser = READ_ONCE(progs->skb_parser);
	skb_progs = skb_parser && skb_verdict;
	if (skb_progs) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			bpf_prog_put(skb_verdict);
			return PTR_ERR(skb_parser);
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out;
		}
	}

	psock = sk_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_progs  && READ_ONCE(psock->progs.skb_parser))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock) {
			ret = -ENOMEM;
			goto out_progs;
		}
		sk_psock_is_new = true;
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);
	if (sk_psock_is_new) {
		ret = tcp_bpf_init(sk);
		if (ret < 0)
			goto out_drop;
	} else {
		tcp_bpf_reinit(sk);
	}

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_progs && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			goto out_drop;
		}
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out:
	if (skb_progs) {
		bpf_prog_put(skb_verdict);
		bpf_prog_put(skb_parser);
	}
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	synchronize_rcu();
	rcu_read_lock();
	raw_spin_lock_bh(&stab->lock);
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk) {
			lock_sock(sk);
			sock_map_unref(sk, psk);
			release_sock(sk);
		}
	}
	raw_spin_unlock_bh(&stab->lock);
	rcu_read_unlock();

	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	return ERR_PTR(-EOPNOTSUPP);
}

static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;
	if (unlikely(rcu_access_pointer(icsk->icsk_ulp_data)))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &stab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	u32 idx = *(u32 *)key;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk) ||
	    sk->sk_state != TCP_ESTABLISHED) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	ret = sock_map_update_common(map, idx, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func		= bpf_sock_map_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func		= bpf_sk_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_map_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func		= bpf_msg_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

const struct bpf_map_ops sock_map_ops = {
	.map_alloc		= sock_map_alloc,
	.map_free		= sock_map_free,
	.map_get_next_key	= sock_map_get_next_key,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_map_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_map_release_progs,
	.map_check_btf		= map_check_no_btf,
};

struct bpf_htab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[0];
};

struct bpf_htab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_htab {
	struct bpf_map map;
	struct bpf_htab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_htab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_htab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

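/* Look up the socket stored under @key in a SOCKHASH map. Must run under
 * RCU: the key is hashed with jhash, the bucket is selected by masking with
 * the power-of-two bucket count, and the bucket's hlist is scanned for a
 * matching hash and key.
 */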
static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_htab *htab,
				struct bpf_htab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem_probe, *elem = link_raw;
	struct bpf_htab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
						  void *key, u32 key_size,
						  u32 hash, struct sock *sk,
						  struct bpf_htab_elem *old)
{
	struct bpf_htab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
			   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_elem *elem, *elem_new;
	struct bpf_htab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(icsk->icsk_ulp_data))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &htab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_update_elem(struct bpf_map *map, void *key,
				 void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk) ||
	    sk->sk_state != TCP_ESTABLISHED) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
				     struct bpf_htab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
					     struct bpf_htab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_htab *htab;
	int i, err;
	u64 cost;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    == 0 ||
	    attr->value_size  != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_htab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
	       (u64) htab->elem_size * htab->map.max_entries;
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_htab;
	}

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_htab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	struct hlist_node *node;
	int i;

	synchronize_rcu();
	rcu_read_lock();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry_safe(elem, node, &bucket->head, node) {
			hlist_del_rcu(&elem->node);
			lock_sock(elem->sk);
			sock_map_unref(elem->sk, elem);
			release_sock(elem->sk);
		}
		raw_spin_unlock_bh(&bucket->lock);
	}
	rcu_read_unlock();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func		= bpf_sock_hash_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func		= bpf_sk_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_hash_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func		= bpf_msg_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

const struct bpf_map_ops sock_hash_ops = {
	.map_alloc		= sock_hash_alloc,
	.map_free		= sock_hash_free,
	.map_get_next_key	= sock_hash_get_next_key,
	.map_update_elem	= sock_hash_update_elem,
	.map_delete_elem	= sock_hash_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_hash_release_progs,
	.map_check_btf		= map_check_no_btf,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_htab, map)->progs;
	default:
		break;
	}

	return NULL;
}

int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		psock_set_prog(&progs->msg_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		psock_set_prog(&progs->skb_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		psock_set_prog(&progs->skb_verdict, prog);
		break;
	default:
		return -EOPNOTSUPP;
	}

	return 0;
}

void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}