// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK				\
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)

static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;
	u64 cost;
	int err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size != 4 ||
	    attr->value_size != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	/* Make sure page count doesn't overflow. */
	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
	err = bpf_map_charge_init(&stab->map.memory, cost);
	if (err)
		goto free_stab;

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (stab->sks)
		return &stab->map;
	err = -ENOMEM;
	bpf_map_charge_finish(&stab->map.memory);
free_stab:
	kfree(stab);
	return ERR_PTR(err);
}

int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, attr->attach_type);
	fdput(f);
	return ret;
}

static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	struct sk_psock_link *link, *tmp;
	bool strp_stop = false;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		sk_psock_stop_strp(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}
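
/* Attach the map's currently configured programs to @sk's psock,
 * allocating the psock if the socket does not have one yet. Returns 0 on
 * success or a negative errno, e.g. -EBUSY when a conflicting parser or
 * msg program is already attached to this socket.
 */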
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	bool skb_progs, sk_psock_is_new = false;
	struct sk_psock *psock;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	skb_parser = READ_ONCE(progs->skb_parser);
	skb_progs = skb_parser && skb_verdict;
	if (skb_progs) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			bpf_prog_put(skb_verdict);
			return PTR_ERR(skb_parser);
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out;
		}
	}

	psock = sk_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_progs && READ_ONCE(psock->progs.skb_parser))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (!psock) {
			ret = -ENOMEM;
			goto out_progs;
		}
		sk_psock_is_new = true;
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);
	if (sk_psock_is_new) {
		ret = tcp_bpf_init(sk);
		if (ret < 0)
			goto out_drop;
	} else {
		tcp_bpf_reinit(sk);
	}

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_progs && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			goto out_drop;
		}
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out:
	if (skb_progs) {
		bpf_prog_put(skb_verdict);
		bpf_prog_put(skb_parser);
	}
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	synchronize_rcu();
	rcu_read_lock();
	raw_spin_lock_bh(&stab->lock);
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk)
			sock_map_unref(sk, psk);
	}
	raw_spin_unlock_bh(&stab->lock);
	rcu_read_unlock();

	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	return ERR_PTR(-EOPNOTSUPP);
}
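
/* Clear the slot at *psk. If @sk_test is non-NULL the entry is only
 * removed while it still points at @sk_test, so that deletion driven by
 * psock link teardown skips slots that have already been replaced.
 */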
static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct inet_connection_sock *icsk = inet_csk(sk);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;
	if (unlikely(icsk->icsk_ulp_data))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &stab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}
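
/* Updates from user space pass the socket as a file descriptor in the
 * map value. A minimal user-space sketch (illustrative only, assuming
 * libbpf, a map fd in map_fd and an established TCP socket in sock_fd):
 *
 *	__u32 key = 0, val = sock_fd;
 *
 *	bpf_map_update_elem(map_fd, &key, &val, BPF_ANY);
 *
 * The fd is resolved with sockfd_lookup() below; sockets that are not
 * established TCP sockets are rejected with -EOPNOTSUPP.
 */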
static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	u32 idx = *(u32 *)key;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk) ||
	    sk->sk_state != TCP_ESTABLISHED) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	ret = sock_map_update_common(map, idx, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func = bpf_sock_map_update,
	.gpl_only = false,
	.pkt_access = true,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};
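
/* bpf_sk_redirect_map() is typically called from an sk_skb stream verdict
 * program to steer an skb to the socket stored under @key. Illustrative
 * BPF-side sketch (assumes a sockmap definition named sock_map):
 *
 *	int stream_verdict(struct __sk_buff *skb)
 *	{
 *		return bpf_sk_redirect_map(skb, &sock_map, 0, 0);
 *	}
 *
 * Only BPF_F_INGRESS is a valid flag; any other bit results in SK_DROP.
 */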
BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_map_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func = bpf_sk_redirect_map,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_ANYTHING,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_map_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func = bpf_msg_redirect_map,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_ANYTHING,
	.arg4_type = ARG_ANYTHING,
};

const struct bpf_map_ops sock_map_ops = {
	.map_alloc = sock_map_alloc,
	.map_free = sock_map_free,
	.map_get_next_key = sock_map_get_next_key,
	.map_update_elem = sock_map_update_elem,
	.map_delete_elem = sock_map_delete_elem,
	.map_lookup_elem = sock_map_lookup,
	.map_release_uref = sock_map_release_progs,
	.map_check_btf = map_check_no_btf,
};

struct bpf_htab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[0];
};

struct bpf_htab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_htab {
	struct bpf_map map;
	struct bpf_htab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_htab_bucket *sock_hash_select_bucket(struct bpf_htab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_htab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_htab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_htab *htab,
				struct bpf_htab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem_probe, *elem = link_raw;
	struct bpf_htab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_htab_elem *sock_hash_alloc_elem(struct bpf_htab *htab,
						  void *key, u32 key_size,
						  u32 hash, struct sock *sk,
						  struct bpf_htab_elem *old)
{
	struct bpf_htab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
			   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}
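
/* Insert or replace the element for @key. The socket is first linked to
 * the map's programs via sock_map_link(); the new element is then added
 * at the head of its bucket so concurrent RCU lookups find it before the
 * element it replaces, and the old element (if any) is unlinked and freed.
 */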
static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_htab_elem *elem, *elem_new;
	struct bpf_htab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	ret = sock_map_link(map, &htab->progs, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_update_elem(struct bpf_map *map, void *key,
				 void *value, u64 flags)
{
	u32 ufd = *(u32 *)value;
	struct socket *sock;
	struct sock *sk;
	int ret;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk) ||
	    sk->sk_state != TCP_ESTABLISHED) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
				     struct bpf_htab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
					     struct bpf_htab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}
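
/* The bucket array is sized to the next power of two of max_entries so
 * sock_hash_select_bucket() can mask the hash instead of dividing; each
 * element additionally carries a copy of its key, rounded up to 8 bytes.
 */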
static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_htab *htab;
	int i, err;
	u64 cost;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size == 0 ||
	    attr->value_size != 4 ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_htab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_htab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	cost = (u64) htab->buckets_num * sizeof(struct bpf_htab_bucket) +
	       (u64) htab->elem_size * htab->map.max_entries;
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_htab;
	}

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_htab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}

static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_htab *htab = container_of(map, struct bpf_htab, map);
	struct bpf_htab_bucket *bucket;
	struct bpf_htab_elem *elem;
	struct hlist_node *node;
	int i;

	synchronize_rcu();
	rcu_read_lock();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry_safe(elem, node, &bucket->head, node) {
			hlist_del_rcu(&elem->node);
			sock_map_unref(elem->sk, elem);
		}
		raw_spin_unlock_bh(&bucket->lock);
	}
	rcu_read_unlock();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_htab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func = bpf_sock_hash_update,
	.gpl_only = false,
	.pkt_access = true,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = __sock_hash_lookup_elem(map, key);
	if (!tcb->bpf.sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func = bpf_sk_redirect_hash,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;
	msg->flags = flags;
	msg->sk_redir = __sock_hash_lookup_elem(map, key);
	if (!msg->sk_redir)
		return SK_DROP;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func = bpf_msg_redirect_hash,
	.gpl_only = false,
	.ret_type = RET_INTEGER,
	.arg1_type = ARG_PTR_TO_CTX,
	.arg2_type = ARG_CONST_MAP_PTR,
	.arg3_type = ARG_PTR_TO_MAP_KEY,
	.arg4_type = ARG_ANYTHING,
};

const struct bpf_map_ops sock_hash_ops = {
	.map_alloc = sock_hash_alloc,
	.map_free = sock_hash_free,
	.map_get_next_key = sock_hash_get_next_key,
	.map_update_elem = sock_hash_update_elem,
	.map_delete_elem = sock_hash_delete_elem,
	.map_lookup_elem = sock_map_lookup,
	.map_release_uref = sock_hash_release_progs,
	.map_check_btf = map_check_no_btf,
};

static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_htab, map)->progs;
	default:
		break;
	}

	return NULL;
}

int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		psock_set_prog(&progs->msg_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		psock_set_prog(&progs->skb_parser, prog);
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		psock_set_prog(&progs->skb_verdict, prog);
		break;
	default:
		return -EOPNOTSUPP;
	}

	return 0;
}

void sk_psock_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}