xref: /openbmc/linux/net/core/sock_map.c (revision 6197e5b7)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */
3 
4 #include <linux/bpf.h>
5 #include <linux/btf_ids.h>
6 #include <linux/filter.h>
7 #include <linux/errno.h>
8 #include <linux/file.h>
9 #include <linux/net.h>
10 #include <linux/workqueue.h>
11 #include <linux/skmsg.h>
12 #include <linux/list.h>
13 #include <linux/jhash.h>
14 #include <linux/sock_diag.h>
15 #include <net/udp.h>
16 
17 struct bpf_stab {
18 	struct bpf_map map;
19 	struct sock **sks;
20 	struct sk_psock_progs progs;
21 	raw_spinlock_t lock;
22 };
23 
24 #define SOCK_CREATE_FLAG_MASK				\
25 	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
26 
27 static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
28 				struct bpf_prog *old, u32 which);
29 
30 static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
31 {
32 	struct bpf_stab *stab;
33 
34 	if (!capable(CAP_NET_ADMIN))
35 		return ERR_PTR(-EPERM);
36 	if (attr->max_entries == 0 ||
37 	    attr->key_size    != 4 ||
38 	    (attr->value_size != sizeof(u32) &&
39 	     attr->value_size != sizeof(u64)) ||
40 	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
41 		return ERR_PTR(-EINVAL);
42 
43 	stab = kzalloc(sizeof(*stab), GFP_USER | __GFP_ACCOUNT);
44 	if (!stab)
45 		return ERR_PTR(-ENOMEM);
46 
47 	bpf_map_init_from_attr(&stab->map, attr);
48 	raw_spin_lock_init(&stab->lock);
49 
50 	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
51 				       sizeof(struct sock *),
52 				       stab->map.numa_node);
53 	if (!stab->sks) {
54 		kfree(stab);
55 		return ERR_PTR(-ENOMEM);
56 	}
57 
58 	return &stab->map;
59 }
60 
61 int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
62 {
63 	u32 ufd = attr->target_fd;
64 	struct bpf_map *map;
65 	struct fd f;
66 	int ret;
67 
68 	if (attr->attach_flags || attr->replace_bpf_fd)
69 		return -EINVAL;
70 
71 	f = fdget(ufd);
72 	map = __bpf_map_get(f);
73 	if (IS_ERR(map))
74 		return PTR_ERR(map);
75 	ret = sock_map_prog_update(map, prog, NULL, attr->attach_type);
76 	fdput(f);
77 	return ret;
78 }
79 
80 int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype)
81 {
82 	u32 ufd = attr->target_fd;
83 	struct bpf_prog *prog;
84 	struct bpf_map *map;
85 	struct fd f;
86 	int ret;
87 
88 	if (attr->attach_flags || attr->replace_bpf_fd)
89 		return -EINVAL;
90 
91 	f = fdget(ufd);
92 	map = __bpf_map_get(f);
93 	if (IS_ERR(map))
94 		return PTR_ERR(map);
95 
96 	prog = bpf_prog_get(attr->attach_bpf_fd);
97 	if (IS_ERR(prog)) {
98 		ret = PTR_ERR(prog);
99 		goto put_map;
100 	}
101 
102 	if (prog->type != ptype) {
103 		ret = -EINVAL;
104 		goto put_prog;
105 	}
106 
107 	ret = sock_map_prog_update(map, NULL, prog, attr->attach_type);
108 put_prog:
109 	bpf_prog_put(prog);
110 put_map:
111 	fdput(f);
112 	return ret;
113 }
114 
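/* Hold the socket lock and pin RCU + preemption around a syscall-side map
 * update; paired with sock_map_sk_release() below.
 */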
115 static void sock_map_sk_acquire(struct sock *sk)
116 	__acquires(&sk->sk_lock.slock)
117 {
118 	lock_sock(sk);
119 	preempt_disable();
120 	rcu_read_lock();
121 }
122 
123 static void sock_map_sk_release(struct sock *sk)
124 	__releases(&sk->sk_lock.slock)
125 {
126 	rcu_read_unlock();
127 	preempt_enable();
128 	release_sock(sk);
129 }
130 
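/* Each psock keeps a list of links back to the map entries that hold it
 * (link_raw points at the sockmap slot or sockhash element), so the socket
 * can be removed from every map on unhash/close.
 */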
131 static void sock_map_add_link(struct sk_psock *psock,
132 			      struct sk_psock_link *link,
133 			      struct bpf_map *map, void *link_raw)
134 {
135 	link->link_raw = link_raw;
136 	link->map = map;
137 	spin_lock_bh(&psock->link_lock);
138 	list_add_tail(&link->list, &psock->link);
139 	spin_unlock_bh(&psock->link_lock);
140 }
141 
142 static void sock_map_del_link(struct sock *sk,
143 			      struct sk_psock *psock, void *link_raw)
144 {
145 	bool strp_stop = false, verdict_stop = false;
146 	struct sk_psock_link *link, *tmp;
147 
148 	spin_lock_bh(&psock->link_lock);
149 	list_for_each_entry_safe(link, tmp, &psock->link, list) {
150 		if (link->link_raw == link_raw) {
151 			struct bpf_map *map = link->map;
152 			struct bpf_stab *stab = container_of(map, struct bpf_stab,
153 							     map);
154 			if (psock->saved_data_ready && stab->progs.stream_parser)
155 				strp_stop = true;
156 			if (psock->saved_data_ready && stab->progs.stream_verdict)
157 				verdict_stop = true;
158 			list_del(&link->list);
159 			sk_psock_free_link(link);
160 		}
161 	}
162 	spin_unlock_bh(&psock->link_lock);
163 	if (strp_stop || verdict_stop) {
164 		write_lock_bh(&sk->sk_callback_lock);
165 		if (strp_stop)
166 			sk_psock_stop_strp(sk, psock);
167 		else
168 			sk_psock_stop_verdict(sk, psock);
169 		write_unlock_bh(&sk->sk_callback_lock);
170 	}
171 }
172 
173 static void sock_map_unref(struct sock *sk, void *link_raw)
174 {
175 	struct sk_psock *psock = sk_psock(sk);
176 
177 	if (likely(psock)) {
178 		sock_map_del_link(sk, psock, link_raw);
179 		sk_psock_put(sk, psock);
180 	}
181 }
182 
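/* Pick the sockmap-aware proto ops for the socket type (TCP or UDP) and
 * install them via sk_psock_update_proto().
 */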
183 static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
184 {
185 	struct proto *prot;
186 
187 	switch (sk->sk_type) {
188 	case SOCK_STREAM:
189 		prot = tcp_bpf_get_proto(sk, psock);
190 		break;
191 
192 	case SOCK_DGRAM:
193 		prot = udp_bpf_get_proto(sk, psock);
194 		break;
195 
196 	default:
197 		return -EINVAL;
198 	}
199 
200 	if (IS_ERR(prot))
201 		return PTR_ERR(prot);
202 
203 	sk_psock_update_proto(sk, psock, prot);
204 	return 0;
205 }
206 
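/* Return the socket's existing psock with a reference taken, NULL if it has
 * none, or -EBUSY if the psock is owned by something other than sockmap
 * (close hook is not sock_map_close) or is already being torn down.
 */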
207 static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
208 {
209 	struct sk_psock *psock;
210 
211 	rcu_read_lock();
212 	psock = sk_psock(sk);
213 	if (psock) {
214 		if (sk->sk_prot->close != sock_map_close) {
215 			psock = ERR_PTR(-EBUSY);
216 			goto out;
217 		}
218 
219 		if (!refcount_inc_not_zero(&psock->refcnt))
220 			psock = ERR_PTR(-EBUSY);
221 	}
222 out:
223 	rcu_read_unlock();
224 	return psock;
225 }
226 
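/* Attach the map's programs to sk: take references on the progs, create a
 * psock if the socket does not have one yet, switch the proto ops, and hook
 * sk_data_ready through the strparser or verdict-only path when the
 * respective programs are present.
 */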
227 static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
228 			 struct sock *sk)
229 {
230 	struct bpf_prog *msg_parser, *stream_parser, *stream_verdict;
231 	struct sk_psock *psock;
232 	int ret;
233 
234 	stream_verdict = READ_ONCE(progs->stream_verdict);
235 	if (stream_verdict) {
236 		stream_verdict = bpf_prog_inc_not_zero(stream_verdict);
237 		if (IS_ERR(stream_verdict))
238 			return PTR_ERR(stream_verdict);
239 	}
240 
241 	stream_parser = READ_ONCE(progs->stream_parser);
242 	if (stream_parser) {
243 		stream_parser = bpf_prog_inc_not_zero(stream_parser);
244 		if (IS_ERR(stream_parser)) {
245 			ret = PTR_ERR(stream_parser);
246 			goto out_put_stream_verdict;
247 		}
248 	}
249 
250 	msg_parser = READ_ONCE(progs->msg_parser);
251 	if (msg_parser) {
252 		msg_parser = bpf_prog_inc_not_zero(msg_parser);
253 		if (IS_ERR(msg_parser)) {
254 			ret = PTR_ERR(msg_parser);
255 			goto out_put_stream_parser;
256 		}
257 	}
258 
259 	psock = sock_map_psock_get_checked(sk);
260 	if (IS_ERR(psock)) {
261 		ret = PTR_ERR(psock);
262 		goto out_progs;
263 	}
264 
265 	if (psock) {
266 		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
267 		    (stream_parser  && READ_ONCE(psock->progs.stream_parser)) ||
268 		    (stream_verdict && READ_ONCE(psock->progs.stream_verdict))) {
269 			sk_psock_put(sk, psock);
270 			ret = -EBUSY;
271 			goto out_progs;
272 		}
273 	} else {
274 		psock = sk_psock_init(sk, map->numa_node);
275 		if (IS_ERR(psock)) {
276 			ret = PTR_ERR(psock);
277 			goto out_progs;
278 		}
279 	}
280 
281 	if (msg_parser)
282 		psock_set_prog(&psock->progs.msg_parser, msg_parser);
283 
284 	ret = sock_map_init_proto(sk, psock);
285 	if (ret < 0)
286 		goto out_drop;
287 
288 	write_lock_bh(&sk->sk_callback_lock);
289 	if (stream_parser && stream_verdict && !psock->saved_data_ready) {
290 		ret = sk_psock_init_strp(sk, psock);
291 		if (ret)
292 			goto out_unlock_drop;
293 		psock_set_prog(&psock->progs.stream_verdict, stream_verdict);
294 		psock_set_prog(&psock->progs.stream_parser, stream_parser);
295 		sk_psock_start_strp(sk, psock);
296 	} else if (!stream_parser && stream_verdict && !psock->saved_data_ready) {
297 		psock_set_prog(&psock->progs.stream_verdict, stream_verdict);
298 		sk_psock_start_verdict(sk, psock);
299 	}
300 	write_unlock_bh(&sk->sk_callback_lock);
301 	return 0;
302 out_unlock_drop:
303 	write_unlock_bh(&sk->sk_callback_lock);
304 out_drop:
305 	sk_psock_put(sk, psock);
306 out_progs:
307 	if (msg_parser)
308 		bpf_prog_put(msg_parser);
309 out_put_stream_parser:
310 	if (stream_parser)
311 		bpf_prog_put(stream_parser);
312 out_put_stream_verdict:
313 	if (stream_verdict)
314 		bpf_prog_put(stream_verdict);
315 	return ret;
316 }
317 
318 static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
319 {
320 	struct sk_psock *psock;
321 	int ret;
322 
323 	psock = sock_map_psock_get_checked(sk);
324 	if (IS_ERR(psock))
325 		return PTR_ERR(psock);
326 
327 	if (!psock) {
328 		psock = sk_psock_init(sk, map->numa_node);
329 		if (IS_ERR(psock))
330 			return PTR_ERR(psock);
331 	}
332 
333 	ret = sock_map_init_proto(sk, psock);
334 	if (ret < 0)
335 		sk_psock_put(sk, psock);
336 	return ret;
337 }
338 
339 static void sock_map_free(struct bpf_map *map)
340 {
341 	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
342 	int i;
343 
344 	/* After the sync no updates or deletes will be in-flight, so it
345 	 * is safe to walk the map and remove entries without risking a
346 	 * race in the EEXIST update case.
347 	 */
348 	synchronize_rcu();
349 	for (i = 0; i < stab->map.max_entries; i++) {
350 		struct sock **psk = &stab->sks[i];
351 		struct sock *sk;
352 
353 		sk = xchg(psk, NULL);
354 		if (sk) {
355 			lock_sock(sk);
356 			rcu_read_lock();
357 			sock_map_unref(sk, psk);
358 			rcu_read_unlock();
359 			release_sock(sk);
360 		}
361 	}
362 
363 	/* wait for psock readers accessing its map link */
364 	synchronize_rcu();
365 
366 	bpf_map_area_free(stab->sks);
367 	kfree(stab);
368 }
369 
370 static void sock_map_release_progs(struct bpf_map *map)
371 {
372 	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
373 }
374 
375 static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
376 {
377 	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
378 
379 	WARN_ON_ONCE(!rcu_read_lock_held());
380 
381 	if (unlikely(key >= map->max_entries))
382 		return NULL;
383 	return READ_ONCE(stab->sks[key]);
384 }
385 
386 static void *sock_map_lookup(struct bpf_map *map, void *key)
387 {
388 	struct sock *sk;
389 
390 	sk = __sock_map_lookup_elem(map, *(u32 *)key);
391 	if (!sk)
392 		return NULL;
393 	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
394 		return NULL;
395 	return sk;
396 }
397 
398 static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
399 {
400 	struct sock *sk;
401 
402 	if (map->value_size != sizeof(u64))
403 		return ERR_PTR(-ENOSPC);
404 
405 	sk = __sock_map_lookup_elem(map, *(u32 *)key);
406 	if (!sk)
407 		return ERR_PTR(-ENOENT);
408 
409 	__sock_gen_cookie(sk);
410 	return &sk->sk_cookie;
411 }
412 
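/* Clear the slot at *psk and drop the psock link for that socket. When
 * sk_test is non-NULL (delete driven by a psock link), only clear the slot
 * if it still holds that socket.
 */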
413 static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
414 			     struct sock **psk)
415 {
416 	struct sock *sk;
417 	int err = 0;
418 
419 	raw_spin_lock_bh(&stab->lock);
420 	sk = *psk;
421 	if (!sk_test || sk_test == sk)
422 		sk = xchg(psk, NULL);
423 
424 	if (likely(sk))
425 		sock_map_unref(sk, psk);
426 	else
427 		err = -EINVAL;
428 
429 	raw_spin_unlock_bh(&stab->lock);
430 	return err;
431 }
432 
433 static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
434 				      void *link_raw)
435 {
436 	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
437 
438 	__sock_map_delete(stab, sk, link_raw);
439 }
440 
441 static int sock_map_delete_elem(struct bpf_map *map, void *key)
442 {
443 	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
444 	u32 i = *(u32 *)key;
445 	struct sock **psk;
446 
447 	if (unlikely(i >= map->max_entries))
448 		return -EINVAL;
449 
450 	psk = &stab->sks[i];
451 	return __sock_map_delete(stab, NULL, psk);
452 }
453 
454 static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
455 {
456 	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
457 	u32 i = key ? *(u32 *)key : U32_MAX;
458 	u32 *key_next = next;
459 
460 	if (i == stab->map.max_entries - 1)
461 		return -ENOENT;
462 	if (i >= stab->map.max_entries)
463 		*key_next = 0;
464 	else
465 		*key_next = i + 1;
466 	return 0;
467 }
468 
469 static bool sock_map_redirect_allowed(const struct sock *sk);
470 
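/* Insert sk at index idx, honoring BPF_NOEXIST/BPF_EXIST, linking the
 * psock to the slot and releasing any socket previously stored there.
 */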
471 static int sock_map_update_common(struct bpf_map *map, u32 idx,
472 				  struct sock *sk, u64 flags)
473 {
474 	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
475 	struct sk_psock_link *link;
476 	struct sk_psock *psock;
477 	struct sock *osk;
478 	int ret;
479 
480 	WARN_ON_ONCE(!rcu_read_lock_held());
481 	if (unlikely(flags > BPF_EXIST))
482 		return -EINVAL;
483 	if (unlikely(idx >= map->max_entries))
484 		return -E2BIG;
485 
486 	link = sk_psock_init_link();
487 	if (!link)
488 		return -ENOMEM;
489 
490 	/* Only sockets we can redirect into/from in BPF need to hold
491 	 * refs to parser/verdict progs and have their sk_data_ready
492 	 * and sk_write_space callbacks overridden.
493 	 */
494 	if (sock_map_redirect_allowed(sk))
495 		ret = sock_map_link(map, &stab->progs, sk);
496 	else
497 		ret = sock_map_link_no_progs(map, sk);
498 	if (ret < 0)
499 		goto out_free;
500 
501 	psock = sk_psock(sk);
502 	WARN_ON_ONCE(!psock);
503 
504 	raw_spin_lock_bh(&stab->lock);
505 	osk = stab->sks[idx];
506 	if (osk && flags == BPF_NOEXIST) {
507 		ret = -EEXIST;
508 		goto out_unlock;
509 	} else if (!osk && flags == BPF_EXIST) {
510 		ret = -ENOENT;
511 		goto out_unlock;
512 	}
513 
514 	sock_map_add_link(psock, link, map, &stab->sks[idx]);
515 	stab->sks[idx] = sk;
516 	if (osk)
517 		sock_map_unref(osk, &stab->sks[idx]);
518 	raw_spin_unlock_bh(&stab->lock);
519 	return 0;
520 out_unlock:
521 	raw_spin_unlock_bh(&stab->lock);
522 	if (psock)
523 		sk_psock_put(sk, psock);
524 out_free:
525 	sk_psock_free_link(link);
526 	return ret;
527 }
528 
529 static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
530 {
531 	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
532 	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
533 	       ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
534 }
535 
536 static bool sk_is_tcp(const struct sock *sk)
537 {
538 	return sk->sk_type == SOCK_STREAM &&
539 	       sk->sk_protocol == IPPROTO_TCP;
540 }
541 
542 static bool sk_is_udp(const struct sock *sk)
543 {
544 	return sk->sk_type == SOCK_DGRAM &&
545 	       sk->sk_protocol == IPPROTO_UDP;
546 }
547 
548 static bool sock_map_redirect_allowed(const struct sock *sk)
549 {
550 	return sk_is_tcp(sk) && sk->sk_state != TCP_LISTEN;
551 }
552 
553 static bool sock_map_sk_is_suitable(const struct sock *sk)
554 {
555 	return sk_is_tcp(sk) || sk_is_udp(sk);
556 }
557 
558 static bool sock_map_sk_state_allowed(const struct sock *sk)
559 {
560 	if (sk_is_tcp(sk))
561 		return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
562 	else if (sk_is_udp(sk))
563 		return sk_hashed(sk);
564 
565 	return false;
566 }
567 
568 static int sock_hash_update_common(struct bpf_map *map, void *key,
569 				   struct sock *sk, u64 flags);
570 
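/* Syscall-side update: the value is a socket FD (u32 or u64 depending on
 * value_size). Resolve it, check the socket type and state, then update
 * the sockmap or sockhash under the socket lock.
 */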
571 int sock_map_update_elem_sys(struct bpf_map *map, void *key, void *value,
572 			     u64 flags)
573 {
574 	struct socket *sock;
575 	struct sock *sk;
576 	int ret;
577 	u64 ufd;
578 
579 	if (map->value_size == sizeof(u64))
580 		ufd = *(u64 *)value;
581 	else
582 		ufd = *(u32 *)value;
583 	if (ufd > S32_MAX)
584 		return -EINVAL;
585 
586 	sock = sockfd_lookup(ufd, &ret);
587 	if (!sock)
588 		return ret;
589 	sk = sock->sk;
590 	if (!sk) {
591 		ret = -EINVAL;
592 		goto out;
593 	}
594 	if (!sock_map_sk_is_suitable(sk)) {
595 		ret = -EOPNOTSUPP;
596 		goto out;
597 	}
598 
599 	sock_map_sk_acquire(sk);
600 	if (!sock_map_sk_state_allowed(sk))
601 		ret = -EOPNOTSUPP;
602 	else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
603 		ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
604 	else
605 		ret = sock_hash_update_common(map, key, sk, flags);
606 	sock_map_sk_release(sk);
607 out:
608 	sockfd_put(sock);
609 	return ret;
610 }
611 
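/* BPF-program-side update: the value is a struct sock pointer rather than
 * an FD; state checks run under bh_lock_sock() since we may be in softirq
 * context.
 */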
612 static int sock_map_update_elem(struct bpf_map *map, void *key,
613 				void *value, u64 flags)
614 {
615 	struct sock *sk = (struct sock *)value;
616 	int ret;
617 
618 	if (unlikely(!sk || !sk_fullsock(sk)))
619 		return -EINVAL;
620 
621 	if (!sock_map_sk_is_suitable(sk))
622 		return -EOPNOTSUPP;
623 
624 	local_bh_disable();
625 	bh_lock_sock(sk);
626 	if (!sock_map_sk_state_allowed(sk))
627 		ret = -EOPNOTSUPP;
628 	else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
629 		ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
630 	else
631 		ret = sock_hash_update_common(map, key, sk, flags);
632 	bh_unlock_sock(sk);
633 	local_bh_enable();
634 	return ret;
635 }
636 
637 BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
638 	   struct bpf_map *, map, void *, key, u64, flags)
639 {
640 	WARN_ON_ONCE(!rcu_read_lock_held());
641 
642 	if (likely(sock_map_sk_is_suitable(sops->sk) &&
643 		   sock_map_op_okay(sops)))
644 		return sock_map_update_common(map, *(u32 *)key, sops->sk,
645 					      flags);
646 	return -EOPNOTSUPP;
647 }
648 
649 const struct bpf_func_proto bpf_sock_map_update_proto = {
650 	.func		= bpf_sock_map_update,
651 	.gpl_only	= false,
652 	.pkt_access	= true,
653 	.ret_type	= RET_INTEGER,
654 	.arg1_type	= ARG_PTR_TO_CTX,
655 	.arg2_type	= ARG_CONST_MAP_PTR,
656 	.arg3_type	= ARG_PTR_TO_MAP_KEY,
657 	.arg4_type	= ARG_ANYTHING,
658 };
659 
660 BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
661 	   struct bpf_map *, map, u32, key, u64, flags)
662 {
663 	struct sock *sk;
664 
665 	if (unlikely(flags & ~(BPF_F_INGRESS)))
666 		return SK_DROP;
667 
668 	sk = __sock_map_lookup_elem(map, key);
669 	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
670 		return SK_DROP;
671 
672 	skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS);
673 	return SK_PASS;
674 }
675 
676 const struct bpf_func_proto bpf_sk_redirect_map_proto = {
677 	.func           = bpf_sk_redirect_map,
678 	.gpl_only       = false,
679 	.ret_type       = RET_INTEGER,
680 	.arg1_type	= ARG_PTR_TO_CTX,
681 	.arg2_type      = ARG_CONST_MAP_PTR,
682 	.arg3_type      = ARG_ANYTHING,
683 	.arg4_type      = ARG_ANYTHING,
684 };
685 
686 BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
687 	   struct bpf_map *, map, u32, key, u64, flags)
688 {
689 	struct sock *sk;
690 
691 	if (unlikely(flags & ~(BPF_F_INGRESS)))
692 		return SK_DROP;
693 
694 	sk = __sock_map_lookup_elem(map, key);
695 	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
696 		return SK_DROP;
697 
698 	msg->flags = flags;
699 	msg->sk_redir = sk;
700 	return SK_PASS;
701 }
702 
703 const struct bpf_func_proto bpf_msg_redirect_map_proto = {
704 	.func           = bpf_msg_redirect_map,
705 	.gpl_only       = false,
706 	.ret_type       = RET_INTEGER,
707 	.arg1_type	= ARG_PTR_TO_CTX,
708 	.arg2_type      = ARG_CONST_MAP_PTR,
709 	.arg3_type      = ARG_ANYTHING,
710 	.arg4_type      = ARG_ANYTHING,
711 };
712 
713 struct sock_map_seq_info {
714 	struct bpf_map *map;
715 	struct sock *sk;
716 	u32 index;
717 };
718 
719 struct bpf_iter__sockmap {
720 	__bpf_md_ptr(struct bpf_iter_meta *, meta);
721 	__bpf_md_ptr(struct bpf_map *, map);
722 	__bpf_md_ptr(void *, key);
723 	__bpf_md_ptr(struct sock *, sk);
724 };
725 
726 DEFINE_BPF_ITER_FUNC(sockmap, struct bpf_iter_meta *meta,
727 		     struct bpf_map *map, void *key,
728 		     struct sock *sk)
729 
730 static void *sock_map_seq_lookup_elem(struct sock_map_seq_info *info)
731 {
732 	if (unlikely(info->index >= info->map->max_entries))
733 		return NULL;
734 
735 	info->sk = __sock_map_lookup_elem(info->map, info->index);
736 
737 	/* can't return sk directly, since that might be NULL */
738 	return info;
739 }
740 
741 static void *sock_map_seq_start(struct seq_file *seq, loff_t *pos)
742 	__acquires(rcu)
743 {
744 	struct sock_map_seq_info *info = seq->private;
745 
746 	if (*pos == 0)
747 		++*pos;
748 
749 	/* pairs with sock_map_seq_stop */
750 	rcu_read_lock();
751 	return sock_map_seq_lookup_elem(info);
752 }
753 
754 static void *sock_map_seq_next(struct seq_file *seq, void *v, loff_t *pos)
755 	__must_hold(rcu)
756 {
757 	struct sock_map_seq_info *info = seq->private;
758 
759 	++*pos;
760 	++info->index;
761 
762 	return sock_map_seq_lookup_elem(info);
763 }
764 
765 static int sock_map_seq_show(struct seq_file *seq, void *v)
766 	__must_hold(rcu)
767 {
768 	struct sock_map_seq_info *info = seq->private;
769 	struct bpf_iter__sockmap ctx = {};
770 	struct bpf_iter_meta meta;
771 	struct bpf_prog *prog;
772 
773 	meta.seq = seq;
774 	prog = bpf_iter_get_info(&meta, !v);
775 	if (!prog)
776 		return 0;
777 
778 	ctx.meta = &meta;
779 	ctx.map = info->map;
780 	if (v) {
781 		ctx.key = &info->index;
782 		ctx.sk = info->sk;
783 	}
784 
785 	return bpf_iter_run_prog(prog, &ctx);
786 }
787 
788 static void sock_map_seq_stop(struct seq_file *seq, void *v)
789 	__releases(rcu)
790 {
791 	if (!v)
792 		(void)sock_map_seq_show(seq, NULL);
793 
794 	/* pairs with sock_map_seq_start */
795 	rcu_read_unlock();
796 }
797 
798 static const struct seq_operations sock_map_seq_ops = {
799 	.start	= sock_map_seq_start,
800 	.next	= sock_map_seq_next,
801 	.stop	= sock_map_seq_stop,
802 	.show	= sock_map_seq_show,
803 };
804 
805 static int sock_map_init_seq_private(void *priv_data,
806 				     struct bpf_iter_aux_info *aux)
807 {
808 	struct sock_map_seq_info *info = priv_data;
809 
810 	info->map = aux->map;
811 	return 0;
812 }
813 
814 static const struct bpf_iter_seq_info sock_map_iter_seq_info = {
815 	.seq_ops		= &sock_map_seq_ops,
816 	.init_seq_private	= sock_map_init_seq_private,
817 	.seq_priv_size		= sizeof(struct sock_map_seq_info),
818 };
819 
820 static int sock_map_btf_id;
821 const struct bpf_map_ops sock_map_ops = {
822 	.map_meta_equal		= bpf_map_meta_equal,
823 	.map_alloc		= sock_map_alloc,
824 	.map_free		= sock_map_free,
825 	.map_get_next_key	= sock_map_get_next_key,
826 	.map_lookup_elem_sys_only = sock_map_lookup_sys,
827 	.map_update_elem	= sock_map_update_elem,
828 	.map_delete_elem	= sock_map_delete_elem,
829 	.map_lookup_elem	= sock_map_lookup,
830 	.map_release_uref	= sock_map_release_progs,
831 	.map_check_btf		= map_check_no_btf,
832 	.map_btf_name		= "bpf_stab",
833 	.map_btf_id		= &sock_map_btf_id,
834 	.iter_seq_info		= &sock_map_iter_seq_info,
835 };
836 
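/* Sockhash (BPF_MAP_TYPE_SOCKHASH): hash-table variant keyed by an
 * arbitrary fixed-size key, sharing the psock/link machinery above.
 */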
837 struct bpf_shtab_elem {
838 	struct rcu_head rcu;
839 	u32 hash;
840 	struct sock *sk;
841 	struct hlist_node node;
842 	u8 key[];
843 };
844 
845 struct bpf_shtab_bucket {
846 	struct hlist_head head;
847 	raw_spinlock_t lock;
848 };
849 
850 struct bpf_shtab {
851 	struct bpf_map map;
852 	struct bpf_shtab_bucket *buckets;
853 	u32 buckets_num;
854 	u32 elem_size;
855 	struct sk_psock_progs progs;
856 	atomic_t count;
857 };
858 
859 static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
860 {
861 	return jhash(key, len, 0);
862 }
863 
864 static struct bpf_shtab_bucket *sock_hash_select_bucket(struct bpf_shtab *htab,
865 							u32 hash)
866 {
867 	return &htab->buckets[hash & (htab->buckets_num - 1)];
868 }
869 
870 static struct bpf_shtab_elem *
871 sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
872 			  u32 key_size)
873 {
874 	struct bpf_shtab_elem *elem;
875 
876 	hlist_for_each_entry_rcu(elem, head, node) {
877 		if (elem->hash == hash &&
878 		    !memcmp(&elem->key, key, key_size))
879 			return elem;
880 	}
881 
882 	return NULL;
883 }
884 
885 static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
886 {
887 	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
888 	u32 key_size = map->key_size, hash;
889 	struct bpf_shtab_bucket *bucket;
890 	struct bpf_shtab_elem *elem;
891 
892 	WARN_ON_ONCE(!rcu_read_lock_held());
893 
894 	hash = sock_hash_bucket_hash(key, key_size);
895 	bucket = sock_hash_select_bucket(htab, hash);
896 	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
897 
898 	return elem ? elem->sk : NULL;
899 }
900 
901 static void sock_hash_free_elem(struct bpf_shtab *htab,
902 				struct bpf_shtab_elem *elem)
903 {
904 	atomic_dec(&htab->count);
905 	kfree_rcu(elem, rcu);
906 }
907 
908 static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
909 				       void *link_raw)
910 {
911 	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
912 	struct bpf_shtab_elem *elem_probe, *elem = link_raw;
913 	struct bpf_shtab_bucket *bucket;
914 
915 	WARN_ON_ONCE(!rcu_read_lock_held());
916 	bucket = sock_hash_select_bucket(htab, elem->hash);
917 
918 	/* elem may be deleted from the map in parallel, but accessing it
919 	 * here is okay since it only goes away after an RCU grace period.
920 	 * However, we need to check whether it is still present.
921 	 */
922 	raw_spin_lock_bh(&bucket->lock);
923 	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
924 					       elem->key, map->key_size);
925 	if (elem_probe && elem_probe == elem) {
926 		hlist_del_rcu(&elem->node);
927 		sock_map_unref(elem->sk, elem);
928 		sock_hash_free_elem(htab, elem);
929 	}
930 	raw_spin_unlock_bh(&bucket->lock);
931 }
932 
933 static int sock_hash_delete_elem(struct bpf_map *map, void *key)
934 {
935 	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
936 	u32 hash, key_size = map->key_size;
937 	struct bpf_shtab_bucket *bucket;
938 	struct bpf_shtab_elem *elem;
939 	int ret = -ENOENT;
940 
941 	hash = sock_hash_bucket_hash(key, key_size);
942 	bucket = sock_hash_select_bucket(htab, hash);
943 
944 	raw_spin_lock_bh(&bucket->lock);
945 	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
946 	if (elem) {
947 		hlist_del_rcu(&elem->node);
948 		sock_map_unref(elem->sk, elem);
949 		sock_hash_free_elem(htab, elem);
950 		ret = 0;
951 	}
952 	raw_spin_unlock_bh(&bucket->lock);
953 	return ret;
954 }
955 
956 static struct bpf_shtab_elem *sock_hash_alloc_elem(struct bpf_shtab *htab,
957 						   void *key, u32 key_size,
958 						   u32 hash, struct sock *sk,
959 						   struct bpf_shtab_elem *old)
960 {
961 	struct bpf_shtab_elem *new;
962 
963 	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
964 		if (!old) {
965 			atomic_dec(&htab->count);
966 			return ERR_PTR(-E2BIG);
967 		}
968 	}
969 
970 	new = bpf_map_kmalloc_node(&htab->map, htab->elem_size,
971 				   GFP_ATOMIC | __GFP_NOWARN,
972 				   htab->map.numa_node);
973 	if (!new) {
974 		atomic_dec(&htab->count);
975 		return ERR_PTR(-ENOMEM);
976 	}
977 	memcpy(new->key, key, key_size);
978 	new->sk = sk;
979 	new->hash = hash;
980 	return new;
981 }
982 
983 static int sock_hash_update_common(struct bpf_map *map, void *key,
984 				   struct sock *sk, u64 flags)
985 {
986 	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
987 	u32 key_size = map->key_size, hash;
988 	struct bpf_shtab_elem *elem, *elem_new;
989 	struct bpf_shtab_bucket *bucket;
990 	struct sk_psock_link *link;
991 	struct sk_psock *psock;
992 	int ret;
993 
994 	WARN_ON_ONCE(!rcu_read_lock_held());
995 	if (unlikely(flags > BPF_EXIST))
996 		return -EINVAL;
997 
998 	link = sk_psock_init_link();
999 	if (!link)
1000 		return -ENOMEM;
1001 
1002 	/* Only sockets we can redirect into/from in BPF need to hold
1003 	 * refs to parser/verdict progs and have their sk_data_ready
1004 	 * and sk_write_space callbacks overridden.
1005 	 */
1006 	if (sock_map_redirect_allowed(sk))
1007 		ret = sock_map_link(map, &htab->progs, sk);
1008 	else
1009 		ret = sock_map_link_no_progs(map, sk);
1010 	if (ret < 0)
1011 		goto out_free;
1012 
1013 	psock = sk_psock(sk);
1014 	WARN_ON_ONCE(!psock);
1015 
1016 	hash = sock_hash_bucket_hash(key, key_size);
1017 	bucket = sock_hash_select_bucket(htab, hash);
1018 
1019 	raw_spin_lock_bh(&bucket->lock);
1020 	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
1021 	if (elem && flags == BPF_NOEXIST) {
1022 		ret = -EEXIST;
1023 		goto out_unlock;
1024 	} else if (!elem && flags == BPF_EXIST) {
1025 		ret = -ENOENT;
1026 		goto out_unlock;
1027 	}
1028 
1029 	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
1030 	if (IS_ERR(elem_new)) {
1031 		ret = PTR_ERR(elem_new);
1032 		goto out_unlock;
1033 	}
1034 
1035 	sock_map_add_link(psock, link, map, elem_new);
1036 	/* Add the new element to the head of the list, so that a
1037 	 * concurrent search will find it before the old element.
1038 	 */
1039 	hlist_add_head_rcu(&elem_new->node, &bucket->head);
1040 	if (elem) {
1041 		hlist_del_rcu(&elem->node);
1042 		sock_map_unref(elem->sk, elem);
1043 		sock_hash_free_elem(htab, elem);
1044 	}
1045 	raw_spin_unlock_bh(&bucket->lock);
1046 	return 0;
1047 out_unlock:
1048 	raw_spin_unlock_bh(&bucket->lock);
1049 	sk_psock_put(sk, psock);
1050 out_free:
1051 	sk_psock_free_link(link);
1052 	return ret;
1053 }
1054 
1055 static int sock_hash_get_next_key(struct bpf_map *map, void *key,
1056 				  void *key_next)
1057 {
1058 	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
1059 	struct bpf_shtab_elem *elem, *elem_next;
1060 	u32 hash, key_size = map->key_size;
1061 	struct hlist_head *head;
1062 	int i = 0;
1063 
1064 	if (!key)
1065 		goto find_first_elem;
1066 	hash = sock_hash_bucket_hash(key, key_size);
1067 	head = &sock_hash_select_bucket(htab, hash)->head;
1068 	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
1069 	if (!elem)
1070 		goto find_first_elem;
1071 
1072 	elem_next = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&elem->node)),
1073 				     struct bpf_shtab_elem, node);
1074 	if (elem_next) {
1075 		memcpy(key_next, elem_next->key, key_size);
1076 		return 0;
1077 	}
1078 
1079 	i = hash & (htab->buckets_num - 1);
1080 	i++;
1081 find_first_elem:
1082 	for (; i < htab->buckets_num; i++) {
1083 		head = &sock_hash_select_bucket(htab, i)->head;
1084 		elem_next = hlist_entry_safe(rcu_dereference(hlist_first_rcu(head)),
1085 					     struct bpf_shtab_elem, node);
1086 		if (elem_next) {
1087 			memcpy(key_next, elem_next->key, key_size);
1088 			return 0;
1089 		}
1090 	}
1091 
1092 	return -ENOENT;
1093 }
1094 
1095 static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
1096 {
1097 	struct bpf_shtab *htab;
1098 	int i, err;
1099 
1100 	if (!capable(CAP_NET_ADMIN))
1101 		return ERR_PTR(-EPERM);
1102 	if (attr->max_entries == 0 ||
1103 	    attr->key_size    == 0 ||
1104 	    (attr->value_size != sizeof(u32) &&
1105 	     attr->value_size != sizeof(u64)) ||
1106 	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
1107 		return ERR_PTR(-EINVAL);
1108 	if (attr->key_size > MAX_BPF_STACK)
1109 		return ERR_PTR(-E2BIG);
1110 
1111 	htab = kzalloc(sizeof(*htab), GFP_USER | __GFP_ACCOUNT);
1112 	if (!htab)
1113 		return ERR_PTR(-ENOMEM);
1114 
1115 	bpf_map_init_from_attr(&htab->map, attr);
1116 
1117 	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
1118 	htab->elem_size = sizeof(struct bpf_shtab_elem) +
1119 			  round_up(htab->map.key_size, 8);
1120 	if (htab->buckets_num == 0 ||
1121 	    htab->buckets_num > U32_MAX / sizeof(struct bpf_shtab_bucket)) {
1122 		err = -EINVAL;
1123 		goto free_htab;
1124 	}
1125 
1126 	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
1127 					   sizeof(struct bpf_shtab_bucket),
1128 					   htab->map.numa_node);
1129 	if (!htab->buckets) {
1130 		err = -ENOMEM;
1131 		goto free_htab;
1132 	}
1133 
1134 	for (i = 0; i < htab->buckets_num; i++) {
1135 		INIT_HLIST_HEAD(&htab->buckets[i].head);
1136 		raw_spin_lock_init(&htab->buckets[i].lock);
1137 	}
1138 
1139 	return &htab->map;
1140 free_htab:
1141 	kfree(htab);
1142 	return ERR_PTR(err);
1143 }
1144 
1145 static void sock_hash_free(struct bpf_map *map)
1146 {
1147 	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
1148 	struct bpf_shtab_bucket *bucket;
1149 	struct hlist_head unlink_list;
1150 	struct bpf_shtab_elem *elem;
1151 	struct hlist_node *node;
1152 	int i;
1153 
1154 	/* After the sync no updates or deletes will be in-flight, so it
1155 	 * is safe to walk the map and remove entries without risking a
1156 	 * race in the EEXIST update case.
1157 	 */
1158 	synchronize_rcu();
1159 	for (i = 0; i < htab->buckets_num; i++) {
1160 		bucket = sock_hash_select_bucket(htab, i);
1161 
1162 		/* We are racing with sock_hash_delete_from_link to
1163 		 * enter the spin-lock critical section. Every socket on
1164 		 * the list is still linked to the sockhash. Since the
1165 		 * link exists, the psock exists and holds a ref to the
1166 		 * socket. That lets us grab a socket ref too.
1167 		 */
1168 		raw_spin_lock_bh(&bucket->lock);
1169 		hlist_for_each_entry(elem, &bucket->head, node)
1170 			sock_hold(elem->sk);
1171 		hlist_move_list(&bucket->head, &unlink_list);
1172 		raw_spin_unlock_bh(&bucket->lock);
1173 
1174 		/* Process removed entries out of atomic context so we
1175 		 * can block on the socket lock before deleting the
1176 		 * psock's link to the sockhash.
1177 		 */
1178 		hlist_for_each_entry_safe(elem, node, &unlink_list, node) {
1179 			hlist_del(&elem->node);
1180 			lock_sock(elem->sk);
1181 			rcu_read_lock();
1182 			sock_map_unref(elem->sk, elem);
1183 			rcu_read_unlock();
1184 			release_sock(elem->sk);
1185 			sock_put(elem->sk);
1186 			sock_hash_free_elem(htab, elem);
1187 		}
1188 	}
1189 
1190 	/* wait for psock readers accessing its map link */
1191 	synchronize_rcu();
1192 
1193 	bpf_map_area_free(htab->buckets);
1194 	kfree(htab);
1195 }
1196 
1197 static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
1198 {
1199 	struct sock *sk;
1200 
1201 	if (map->value_size != sizeof(u64))
1202 		return ERR_PTR(-ENOSPC);
1203 
1204 	sk = __sock_hash_lookup_elem(map, key);
1205 	if (!sk)
1206 		return ERR_PTR(-ENOENT);
1207 
1208 	__sock_gen_cookie(sk);
1209 	return &sk->sk_cookie;
1210 }
1211 
1212 static void *sock_hash_lookup(struct bpf_map *map, void *key)
1213 {
1214 	struct sock *sk;
1215 
1216 	sk = __sock_hash_lookup_elem(map, key);
1217 	if (!sk)
1218 		return NULL;
1219 	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
1220 		return NULL;
1221 	return sk;
1222 }
1223 
1224 static void sock_hash_release_progs(struct bpf_map *map)
1225 {
1226 	psock_progs_drop(&container_of(map, struct bpf_shtab, map)->progs);
1227 }
1228 
1229 BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
1230 	   struct bpf_map *, map, void *, key, u64, flags)
1231 {
1232 	WARN_ON_ONCE(!rcu_read_lock_held());
1233 
1234 	if (likely(sock_map_sk_is_suitable(sops->sk) &&
1235 		   sock_map_op_okay(sops)))
1236 		return sock_hash_update_common(map, key, sops->sk, flags);
1237 	return -EOPNOTSUPP;
1238 }
1239 
1240 const struct bpf_func_proto bpf_sock_hash_update_proto = {
1241 	.func		= bpf_sock_hash_update,
1242 	.gpl_only	= false,
1243 	.pkt_access	= true,
1244 	.ret_type	= RET_INTEGER,
1245 	.arg1_type	= ARG_PTR_TO_CTX,
1246 	.arg2_type	= ARG_CONST_MAP_PTR,
1247 	.arg3_type	= ARG_PTR_TO_MAP_KEY,
1248 	.arg4_type	= ARG_ANYTHING,
1249 };
1250 
1251 BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
1252 	   struct bpf_map *, map, void *, key, u64, flags)
1253 {
1254 	struct sock *sk;
1255 
1256 	if (unlikely(flags & ~(BPF_F_INGRESS)))
1257 		return SK_DROP;
1258 
1259 	sk = __sock_hash_lookup_elem(map, key);
1260 	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
1261 		return SK_DROP;
1262 
1263 	skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS);
1264 	return SK_PASS;
1265 }
1266 
1267 const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
1268 	.func           = bpf_sk_redirect_hash,
1269 	.gpl_only       = false,
1270 	.ret_type       = RET_INTEGER,
1271 	.arg1_type	= ARG_PTR_TO_CTX,
1272 	.arg2_type      = ARG_CONST_MAP_PTR,
1273 	.arg3_type      = ARG_PTR_TO_MAP_KEY,
1274 	.arg4_type      = ARG_ANYTHING,
1275 };
1276 
1277 BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
1278 	   struct bpf_map *, map, void *, key, u64, flags)
1279 {
1280 	struct sock *sk;
1281 
1282 	if (unlikely(flags & ~(BPF_F_INGRESS)))
1283 		return SK_DROP;
1284 
1285 	sk = __sock_hash_lookup_elem(map, key);
1286 	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
1287 		return SK_DROP;
1288 
1289 	msg->flags = flags;
1290 	msg->sk_redir = sk;
1291 	return SK_PASS;
1292 }
1293 
1294 const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
1295 	.func           = bpf_msg_redirect_hash,
1296 	.gpl_only       = false,
1297 	.ret_type       = RET_INTEGER,
1298 	.arg1_type	= ARG_PTR_TO_CTX,
1299 	.arg2_type      = ARG_CONST_MAP_PTR,
1300 	.arg3_type      = ARG_PTR_TO_MAP_KEY,
1301 	.arg4_type      = ARG_ANYTHING,
1302 };
1303 
1304 struct sock_hash_seq_info {
1305 	struct bpf_map *map;
1306 	struct bpf_shtab *htab;
1307 	u32 bucket_id;
1308 };
1309 
1310 static void *sock_hash_seq_find_next(struct sock_hash_seq_info *info,
1311 				     struct bpf_shtab_elem *prev_elem)
1312 {
1313 	const struct bpf_shtab *htab = info->htab;
1314 	struct bpf_shtab_bucket *bucket;
1315 	struct bpf_shtab_elem *elem;
1316 	struct hlist_node *node;
1317 
1318 	/* try to find next elem in the same bucket */
1319 	if (prev_elem) {
1320 		node = rcu_dereference(hlist_next_rcu(&prev_elem->node));
1321 		elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
1322 		if (elem)
1323 			return elem;
1324 
1325 		/* no more elements, continue in the next bucket */
1326 		info->bucket_id++;
1327 	}
1328 
1329 	for (; info->bucket_id < htab->buckets_num; info->bucket_id++) {
1330 		bucket = &htab->buckets[info->bucket_id];
1331 		node = rcu_dereference(hlist_first_rcu(&bucket->head));
1332 		elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
1333 		if (elem)
1334 			return elem;
1335 	}
1336 
1337 	return NULL;
1338 }
1339 
1340 static void *sock_hash_seq_start(struct seq_file *seq, loff_t *pos)
1341 	__acquires(rcu)
1342 {
1343 	struct sock_hash_seq_info *info = seq->private;
1344 
1345 	if (*pos == 0)
1346 		++*pos;
1347 
1348 	/* pairs with sock_hash_seq_stop */
1349 	rcu_read_lock();
1350 	return sock_hash_seq_find_next(info, NULL);
1351 }
1352 
1353 static void *sock_hash_seq_next(struct seq_file *seq, void *v, loff_t *pos)
1354 	__must_hold(rcu)
1355 {
1356 	struct sock_hash_seq_info *info = seq->private;
1357 
1358 	++*pos;
1359 	return sock_hash_seq_find_next(info, v);
1360 }
1361 
1362 static int sock_hash_seq_show(struct seq_file *seq, void *v)
1363 	__must_hold(rcu)
1364 {
1365 	struct sock_hash_seq_info *info = seq->private;
1366 	struct bpf_iter__sockmap ctx = {};
1367 	struct bpf_shtab_elem *elem = v;
1368 	struct bpf_iter_meta meta;
1369 	struct bpf_prog *prog;
1370 
1371 	meta.seq = seq;
1372 	prog = bpf_iter_get_info(&meta, !elem);
1373 	if (!prog)
1374 		return 0;
1375 
1376 	ctx.meta = &meta;
1377 	ctx.map = info->map;
1378 	if (elem) {
1379 		ctx.key = elem->key;
1380 		ctx.sk = elem->sk;
1381 	}
1382 
1383 	return bpf_iter_run_prog(prog, &ctx);
1384 }
1385 
1386 static void sock_hash_seq_stop(struct seq_file *seq, void *v)
1387 	__releases(rcu)
1388 {
1389 	if (!v)
1390 		(void)sock_hash_seq_show(seq, NULL);
1391 
1392 	/* pairs with sock_hash_seq_start */
1393 	rcu_read_unlock();
1394 }
1395 
1396 static const struct seq_operations sock_hash_seq_ops = {
1397 	.start	= sock_hash_seq_start,
1398 	.next	= sock_hash_seq_next,
1399 	.stop	= sock_hash_seq_stop,
1400 	.show	= sock_hash_seq_show,
1401 };
1402 
1403 static int sock_hash_init_seq_private(void *priv_data,
1404 				     struct bpf_iter_aux_info *aux)
1405 {
1406 	struct sock_hash_seq_info *info = priv_data;
1407 
1408 	info->map = aux->map;
1409 	info->htab = container_of(aux->map, struct bpf_shtab, map);
1410 	return 0;
1411 }
1412 
1413 static const struct bpf_iter_seq_info sock_hash_iter_seq_info = {
1414 	.seq_ops		= &sock_hash_seq_ops,
1415 	.init_seq_private	= sock_hash_init_seq_private,
1416 	.seq_priv_size		= sizeof(struct sock_hash_seq_info),
1417 };
1418 
1419 static int sock_hash_map_btf_id;
1420 const struct bpf_map_ops sock_hash_ops = {
1421 	.map_meta_equal		= bpf_map_meta_equal,
1422 	.map_alloc		= sock_hash_alloc,
1423 	.map_free		= sock_hash_free,
1424 	.map_get_next_key	= sock_hash_get_next_key,
1425 	.map_update_elem	= sock_map_update_elem,
1426 	.map_delete_elem	= sock_hash_delete_elem,
1427 	.map_lookup_elem	= sock_hash_lookup,
1428 	.map_lookup_elem_sys_only = sock_hash_lookup_sys,
1429 	.map_release_uref	= sock_hash_release_progs,
1430 	.map_check_btf		= map_check_no_btf,
1431 	.map_btf_name		= "bpf_shtab",
1432 	.map_btf_id		= &sock_hash_map_btf_id,
1433 	.iter_seq_info		= &sock_hash_iter_seq_info,
1434 };
1435 
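/* Resolve the sk_psock_progs for either map flavor; used by program
 * attach/detach (sock_map_prog_update).
 */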
1436 static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
1437 {
1438 	switch (map->map_type) {
1439 	case BPF_MAP_TYPE_SOCKMAP:
1440 		return &container_of(map, struct bpf_stab, map)->progs;
1441 	case BPF_MAP_TYPE_SOCKHASH:
1442 		return &container_of(map, struct bpf_shtab, map)->progs;
1443 	default:
1444 		break;
1445 	}
1446 
1447 	return NULL;
1448 }
1449 
1450 static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
1451 				struct bpf_prog *old, u32 which)
1452 {
1453 	struct sk_psock_progs *progs = sock_map_progs(map);
1454 	struct bpf_prog **pprog;
1455 
1456 	if (!progs)
1457 		return -EOPNOTSUPP;
1458 
1459 	switch (which) {
1460 	case BPF_SK_MSG_VERDICT:
1461 		pprog = &progs->msg_parser;
1462 		break;
1463 #if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
1464 	case BPF_SK_SKB_STREAM_PARSER:
1465 		pprog = &progs->stream_parser;
1466 		break;
1467 #endif
1468 	case BPF_SK_SKB_STREAM_VERDICT:
1469 		pprog = &progs->stream_verdict;
1470 		break;
1471 	default:
1472 		return -EOPNOTSUPP;
1473 	}
1474 
1475 	if (old)
1476 		return psock_replace_prog(pprog, prog, old);
1477 
1478 	psock_set_prog(pprog, prog);
1479 	return 0;
1480 }
1481 
1482 static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
1483 {
1484 	switch (link->map->map_type) {
1485 	case BPF_MAP_TYPE_SOCKMAP:
1486 		return sock_map_delete_from_link(link->map, sk,
1487 						 link->link_raw);
1488 	case BPF_MAP_TYPE_SOCKHASH:
1489 		return sock_hash_delete_from_link(link->map, sk,
1490 						  link->link_raw);
1491 	default:
1492 		break;
1493 	}
1494 }
1495 
1496 static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
1497 {
1498 	struct sk_psock_link *link;
1499 
1500 	while ((link = sk_psock_link_pop(psock))) {
1501 		sock_map_unlink(sk, link);
1502 		sk_psock_free_link(link);
1503 	}
1504 }
1505 
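/* sock_map_unhash() and sock_map_close() are installed as the socket's
 * proto hooks while it is held in a map: drop all map links first, then
 * fall back to the saved original handler.
 */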
1506 void sock_map_unhash(struct sock *sk)
1507 {
1508 	void (*saved_unhash)(struct sock *sk);
1509 	struct sk_psock *psock;
1510 
1511 	rcu_read_lock();
1512 	psock = sk_psock(sk);
1513 	if (unlikely(!psock)) {
1514 		rcu_read_unlock();
1515 		if (sk->sk_prot->unhash)
1516 			sk->sk_prot->unhash(sk);
1517 		return;
1518 	}
1519 
1520 	saved_unhash = psock->saved_unhash;
1521 	sock_map_remove_links(sk, psock);
1522 	rcu_read_unlock();
1523 	saved_unhash(sk);
1524 }
1525 
1526 void sock_map_close(struct sock *sk, long timeout)
1527 {
1528 	void (*saved_close)(struct sock *sk, long timeout);
1529 	struct sk_psock *psock;
1530 
1531 	lock_sock(sk);
1532 	rcu_read_lock();
1533 	psock = sk_psock(sk);
1534 	if (unlikely(!psock)) {
1535 		rcu_read_unlock();
1536 		release_sock(sk);
1537 		return sk->sk_prot->close(sk, timeout);
1538 	}
1539 
1540 	saved_close = psock->saved_close;
1541 	sock_map_remove_links(sk, psock);
1542 	rcu_read_unlock();
1543 	release_sock(sk);
1544 	saved_close(sk, timeout);
1545 }
1546 
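/* bpf_iter support: the iterator target pins the map with a uref and
 * rejects programs whose read-only key access exceeds the map's key_size.
 */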
1547 static int sock_map_iter_attach_target(struct bpf_prog *prog,
1548 				       union bpf_iter_link_info *linfo,
1549 				       struct bpf_iter_aux_info *aux)
1550 {
1551 	struct bpf_map *map;
1552 	int err = -EINVAL;
1553 
1554 	if (!linfo->map.map_fd)
1555 		return -EBADF;
1556 
1557 	map = bpf_map_get_with_uref(linfo->map.map_fd);
1558 	if (IS_ERR(map))
1559 		return PTR_ERR(map);
1560 
1561 	if (map->map_type != BPF_MAP_TYPE_SOCKMAP &&
1562 	    map->map_type != BPF_MAP_TYPE_SOCKHASH)
1563 		goto put_map;
1564 
1565 	if (prog->aux->max_rdonly_access > map->key_size) {
1566 		err = -EACCES;
1567 		goto put_map;
1568 	}
1569 
1570 	aux->map = map;
1571 	return 0;
1572 
1573 put_map:
1574 	bpf_map_put_with_uref(map);
1575 	return err;
1576 }
1577 
1578 static void sock_map_iter_detach_target(struct bpf_iter_aux_info *aux)
1579 {
1580 	bpf_map_put_with_uref(aux->map);
1581 }
1582 
1583 static struct bpf_iter_reg sock_map_iter_reg = {
1584 	.target			= "sockmap",
1585 	.attach_target		= sock_map_iter_attach_target,
1586 	.detach_target		= sock_map_iter_detach_target,
1587 	.show_fdinfo		= bpf_iter_map_show_fdinfo,
1588 	.fill_link_info		= bpf_iter_map_fill_link_info,
1589 	.ctx_arg_info_size	= 2,
1590 	.ctx_arg_info		= {
1591 		{ offsetof(struct bpf_iter__sockmap, key),
1592 		  PTR_TO_RDONLY_BUF_OR_NULL },
1593 		{ offsetof(struct bpf_iter__sockmap, sk),
1594 		  PTR_TO_BTF_ID_OR_NULL },
1595 	},
1596 };
1597 
1598 static int __init bpf_sockmap_iter_init(void)
1599 {
1600 	sock_map_iter_reg.ctx_arg_info[1].btf_id =
1601 		btf_sock_ids[BTF_SOCK_TYPE_SOCK];
1602 	return bpf_iter_reg_target(&sock_map_iter_reg);
1603 }
1604 late_initcall(bpf_sockmap_iter_init);
1605