xref: /openbmc/linux/net/core/bpf_sk_storage.c (revision 4bb1eb3c)
1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2019 Facebook  */
#include <linux/rculist.h>
#include <linux/list.h>
#include <linux/hash.h>
#include <linux/types.h>
#include <linux/spinlock.h>
#include <linux/bpf.h>
#include <linux/btf_ids.h>
#include <linux/overflow.h>
#include <net/bpf_sk_storage.h>
#include <net/sock.h>
#include <uapi/linux/sock_diag.h>
#include <uapi/linux/btf.h>
14 
/* map_flags accepted at map creation time; any other bit is rejected
 * by bpf_sk_storage_map_alloc_check().
 */
#define SK_STORAGE_CREATE_FLAG_MASK					\
	(BPF_F_NO_PREALLOC | BPF_F_CLONE)

/* One hash bucket of a bpf_sk_storage_map: the selems hashed here
 * (linked through selem->map_node) and the raw spinlock that guards
 * linking/unlinking them.
 */
struct bucket {
	struct hlist_head list;
	raw_spinlock_t lock;
};
22 
/* The map is not the primary owner of a bpf_sk_storage_elem.
 * Instead, the sk->sk_bpf_storage is.
 *
 * The map (bpf_sk_storage_map) serves two purposes:
 * 1. Define the size of the "sk local storage".  It is
 *    the map's value_size.
 *
 * 2. Maintain a list to keep track of all elems such
 *    that they can be cleaned up during the map destruction.
 *
 * When a bpf local storage is being looked up for a
 * particular sk, the "bpf_map" pointer is actually used
 * as the "key" to search in the list of elems in
 * sk->sk_bpf_storage.
 *
 * Hence, consider sk->sk_bpf_storage to be a mini-map
 * with the "bpf_map" pointer as the search key.
 */
40  */
/* Per-map state of a BPF_MAP_TYPE_SK_STORAGE map.
 *
 * elem_size is sizeof(struct bpf_sk_storage_elem) + value_size (set in
 * bpf_sk_storage_map_alloc()).  cache_idx is the slot of
 * bpf_sk_storage->cache[] used by this map; slots may be shared by
 * several maps (see cache_idx_get()).
 */
struct bpf_sk_storage_map {
	struct bpf_map map;
	/* Lookup elem does not require accessing the map.
	 *
	 * Updating/Deleting requires a bucket lock to
	 * link/unlink the elem from the map.  Having
	 * multiple buckets reduces lock contention.
	 */
	struct bucket *buckets;
	u32 bucket_log;		/* log2 of the number of buckets */
	u16 elem_size;
	u16 cache_idx;
};
54 
/* The map-visible part of an elem: the owning map plus the value bytes. */
struct bpf_sk_storage_data {
	/* smap is used as the searching key when looking up
	 * from sk->sk_bpf_storage.
	 *
	 * Put it in the same cacheline as the data to minimize
	 * the number of cachelines accessed during the cache hit case.
	 */
	struct bpf_sk_storage_map __rcu *smap;
	/* map.value_size bytes of value (see selem_alloc()) */
	u8 data[] __aligned(8);
};

/* Linked to bpf_sk_storage and bpf_sk_storage_map */
struct bpf_sk_storage_elem {
	struct hlist_node map_node;	/* Linked to bpf_sk_storage_map */
	struct hlist_node snode;	/* Linked to bpf_sk_storage */
	struct bpf_sk_storage __rcu *sk_storage;
	struct rcu_head rcu;
	/* 8 bytes hole */
	/* The data is stored in another cacheline to minimize
	 * the number of cachelines accessed during a cache hit.
	 */
	struct bpf_sk_storage_data sdata ____cacheline_aligned;
};
78 
/* Convert between an elem and the sdata embedded in it. */
#define SELEM(_SDATA) container_of((_SDATA), struct bpf_sk_storage_elem, sdata)
#define SDATA(_SELEM) (&(_SELEM)->sdata)
/* Number of per-sk cache slots shared among all sk_storage maps. */
#define BPF_SK_STORAGE_CACHE_SIZE	16

/* cache_idx_usage_counts[i] counts the live maps using cache slot i;
 * it is protected by cache_idx_lock (see cache_idx_get()/cache_idx_free()).
 */
static DEFINE_SPINLOCK(cache_idx_lock);
static u64 cache_idx_usage_counts[BPF_SK_STORAGE_CACHE_SIZE];

/* Per-socket storage, published through sk->sk_bpf_storage. */
struct bpf_sk_storage {
	struct bpf_sk_storage_data __rcu *cache[BPF_SK_STORAGE_CACHE_SIZE];
	struct hlist_head list;	/* List of bpf_sk_storage_elem */
	struct sock *sk;	/* The sk that owns the above "list" of
				 * bpf_sk_storage_elem.
				 */
	struct rcu_head rcu;
	raw_spinlock_t lock;	/* Protect adding/removing from the "list" */
};
95 
96 static struct bucket *select_bucket(struct bpf_sk_storage_map *smap,
97 				    struct bpf_sk_storage_elem *selem)
98 {
99 	return &smap->buckets[hash_ptr(selem, smap->bucket_log)];
100 }
101 
102 static int omem_charge(struct sock *sk, unsigned int size)
103 {
104 	/* same check as in sock_kmalloc() */
105 	if (size <= sysctl_optmem_max &&
106 	    atomic_read(&sk->sk_omem_alloc) + size < sysctl_optmem_max) {
107 		atomic_add(size, &sk->sk_omem_alloc);
108 		return 0;
109 	}
110 
111 	return -ENOMEM;
112 }
113 
114 static bool selem_linked_to_sk(const struct bpf_sk_storage_elem *selem)
115 {
116 	return !hlist_unhashed(&selem->snode);
117 }
118 
119 static bool selem_linked_to_map(const struct bpf_sk_storage_elem *selem)
120 {
121 	return !hlist_unhashed(&selem->map_node);
122 }
123 
124 static struct bpf_sk_storage_elem *selem_alloc(struct bpf_sk_storage_map *smap,
125 					       struct sock *sk, void *value,
126 					       bool charge_omem)
127 {
128 	struct bpf_sk_storage_elem *selem;
129 
130 	if (charge_omem && omem_charge(sk, smap->elem_size))
131 		return NULL;
132 
133 	selem = kzalloc(smap->elem_size, GFP_ATOMIC | __GFP_NOWARN);
134 	if (selem) {
135 		if (value)
136 			memcpy(SDATA(selem)->data, value, smap->map.value_size);
137 		return selem;
138 	}
139 
140 	if (charge_omem)
141 		atomic_sub(smap->elem_size, &sk->sk_omem_alloc);
142 
143 	return NULL;
144 }
145 
/* Unlink @selem from @sk_storage->list and free it with kfree_rcu().
 *
 * sk_storage->lock must be held and selem->sk_storage == sk_storage.
 * The caller must ensure selem->smap is still valid to be
 * dereferenced for its smap->elem_size and smap->cache_idx.
 *
 * @uncharge_omem: also subtract smap->elem_size from sk->sk_omem_alloc.
 *
 * Returns true if @selem was the last elem.  In that case
 * sk->sk_bpf_storage has been cleared, and the caller must
 * kfree_rcu(sk_storage, rcu) after dropping sk_storage->lock.
 */
static bool __selem_unlink_sk(struct bpf_sk_storage *sk_storage,
			      struct bpf_sk_storage_elem *selem,
			      bool uncharge_omem)
{
	struct bpf_sk_storage_map *smap;
	bool free_sk_storage;
	struct sock *sk;

	smap = rcu_dereference(SDATA(selem)->smap);
	sk = sk_storage->sk;

	/* All uncharging on sk->sk_omem_alloc must be done first.
	 * sk may be freed once the last selem is unlinked from sk_storage.
	 */
	if (uncharge_omem)
		atomic_sub(smap->elem_size, &sk->sk_omem_alloc);

	free_sk_storage = hlist_is_singular_node(&selem->snode,
						 &sk_storage->list);
	if (free_sk_storage) {
		/* Last elem: drop the charge for sk_storage itself too */
		atomic_sub(sizeof(struct bpf_sk_storage), &sk->sk_omem_alloc);
		sk_storage->sk = NULL;
		/* After this RCU_INIT, sk may be freed and cannot be used */
		RCU_INIT_POINTER(sk->sk_bpf_storage, NULL);

		/* sk_storage is not freed now.  sk_storage->lock is
		 * still held and raw_spin_unlock_bh(&sk_storage->lock)
		 * will be done by the caller.
		 *
		 * Although the unlock will be done under
		 * rcu_read_lock(), it is more intuitive to
		 * read if kfree_rcu(sk_storage, rcu) is done
		 * after the raw_spin_unlock_bh(&sk_storage->lock).
		 *
		 * Hence, a "bool free_sk_storage" is returned
		 * to the caller which then calls the kfree_rcu()
		 * after unlock.
		 */
	}
	hlist_del_init_rcu(&selem->snode);
	/* Clear this map's cache slot if it points at the elem being freed */
	if (rcu_access_pointer(sk_storage->cache[smap->cache_idx]) ==
	    SDATA(selem))
		RCU_INIT_POINTER(sk_storage->cache[smap->cache_idx], NULL);

	kfree_rcu(selem, rcu);

	return free_sk_storage;
}
198 
199 static void selem_unlink_sk(struct bpf_sk_storage_elem *selem)
200 {
201 	struct bpf_sk_storage *sk_storage;
202 	bool free_sk_storage = false;
203 
204 	if (unlikely(!selem_linked_to_sk(selem)))
205 		/* selem has already been unlinked from sk */
206 		return;
207 
208 	sk_storage = rcu_dereference(selem->sk_storage);
209 	raw_spin_lock_bh(&sk_storage->lock);
210 	if (likely(selem_linked_to_sk(selem)))
211 		free_sk_storage = __selem_unlink_sk(sk_storage, selem, true);
212 	raw_spin_unlock_bh(&sk_storage->lock);
213 
214 	if (free_sk_storage)
215 		kfree_rcu(sk_storage, rcu);
216 }
217 
218 static void __selem_link_sk(struct bpf_sk_storage *sk_storage,
219 			    struct bpf_sk_storage_elem *selem)
220 {
221 	RCU_INIT_POINTER(selem->sk_storage, sk_storage);
222 	hlist_add_head(&selem->snode, &sk_storage->list);
223 }
224 
225 static void selem_unlink_map(struct bpf_sk_storage_elem *selem)
226 {
227 	struct bpf_sk_storage_map *smap;
228 	struct bucket *b;
229 
230 	if (unlikely(!selem_linked_to_map(selem)))
231 		/* selem has already be unlinked from smap */
232 		return;
233 
234 	smap = rcu_dereference(SDATA(selem)->smap);
235 	b = select_bucket(smap, selem);
236 	raw_spin_lock_bh(&b->lock);
237 	if (likely(selem_linked_to_map(selem)))
238 		hlist_del_init_rcu(&selem->map_node);
239 	raw_spin_unlock_bh(&b->lock);
240 }
241 
242 static void selem_link_map(struct bpf_sk_storage_map *smap,
243 			   struct bpf_sk_storage_elem *selem)
244 {
245 	struct bucket *b = select_bucket(smap, selem);
246 
247 	raw_spin_lock_bh(&b->lock);
248 	RCU_INIT_POINTER(SDATA(selem)->smap, smap);
249 	hlist_add_head_rcu(&selem->map_node, &b->list);
250 	raw_spin_unlock_bh(&b->lock);
251 }
252 
/* Fully remove @selem: first from its map bucket, then from its sk. */
static void selem_unlink(struct bpf_sk_storage_elem *selem)
{
	/* Always unlink from map before unlinking from sk_storage
	 * because selem will be freed after successfully unlinked from
	 * the sk_storage.
	 */
	selem_unlink_map(selem);
	selem_unlink_sk(selem);
}
262 
/* Find @smap's data in @sk_storage.
 *
 * This uses rcu_dereference() and hlist_for_each_entry_rcu(), so the
 * caller must be in an RCU read-side section.  The per-map cache slot is
 * checked first, then sk_storage->list is walked.  When @cacheit_lockit
 * is true, a slow-path hit is published into cache[smap->cache_idx]
 * under sk_storage->lock.
 *
 * Returns NULL when @smap has no elem in @sk_storage.
 */
static struct bpf_sk_storage_data *
__sk_storage_lookup(struct bpf_sk_storage *sk_storage,
		    struct bpf_sk_storage_map *smap,
		    bool cacheit_lockit)
{
	struct bpf_sk_storage_data *sdata;
	struct bpf_sk_storage_elem *selem;

	/* Fast path (cache hit).  The slot may be shared with other maps,
	 * so the cached sdata must be checked for the owner too.
	 */
	sdata = rcu_dereference(sk_storage->cache[smap->cache_idx]);
	if (sdata && rcu_access_pointer(sdata->smap) == smap)
		return sdata;

	/* Slow path (cache miss) */
	hlist_for_each_entry_rcu(selem, &sk_storage->list, snode)
		if (rcu_access_pointer(SDATA(selem)->smap) == smap)
			break;

	/* selem is NULL when the walk finished without a match */
	if (!selem)
		return NULL;

	sdata = SDATA(selem);
	if (cacheit_lockit) {
		/* spinlock is needed to avoid racing with the
		 * parallel delete.  Otherwise, publishing an already
		 * deleted sdata to the cache will become a use-after-free
		 * problem in the next __sk_storage_lookup().
		 */
		raw_spin_lock_bh(&sk_storage->lock);
		if (selem_linked_to_sk(selem))
			rcu_assign_pointer(sk_storage->cache[smap->cache_idx],
					   sdata);
		raw_spin_unlock_bh(&sk_storage->lock);
	}

	return sdata;
}
300 
301 static struct bpf_sk_storage_data *
302 sk_storage_lookup(struct sock *sk, struct bpf_map *map, bool cacheit_lockit)
303 {
304 	struct bpf_sk_storage *sk_storage;
305 	struct bpf_sk_storage_map *smap;
306 
307 	sk_storage = rcu_dereference(sk->sk_bpf_storage);
308 	if (!sk_storage)
309 		return NULL;
310 
311 	smap = (struct bpf_sk_storage_map *)map;
312 	return __sk_storage_lookup(sk_storage, smap, cacheit_lockit);
313 }
314 
315 static int check_flags(const struct bpf_sk_storage_data *old_sdata,
316 		       u64 map_flags)
317 {
318 	if (old_sdata && (map_flags & ~BPF_F_LOCK) == BPF_NOEXIST)
319 		/* elem already exists */
320 		return -EEXIST;
321 
322 	if (!old_sdata && (map_flags & ~BPF_F_LOCK) == BPF_EXIST)
323 		/* elem doesn't exist, cannot update it */
324 		return -ENOENT;
325 
326 	return 0;
327 }
328 
/* Allocate @sk's bpf_sk_storage, link @first_selem into it and into
 * @smap, then publish it to sk->sk_bpf_storage with cmpxchg().
 *
 * Returns 0 on success.  On failure it returns -ENOMEM (charge or
 * allocation failed) or -EAGAIN (another sk_storage was published
 * concurrently).  On failure first_selem is left allocated and charged;
 * the caller frees it and undoes its charge.
 */
static int sk_storage_alloc(struct sock *sk,
			    struct bpf_sk_storage_map *smap,
			    struct bpf_sk_storage_elem *first_selem)
{
	struct bpf_sk_storage *prev_sk_storage, *sk_storage;
	int err;

	err = omem_charge(sk, sizeof(*sk_storage));
	if (err)
		return err;

	sk_storage = kzalloc(sizeof(*sk_storage), GFP_ATOMIC | __GFP_NOWARN);
	if (!sk_storage) {
		err = -ENOMEM;
		goto uncharge;
	}
	INIT_HLIST_HEAD(&sk_storage->list);
	raw_spin_lock_init(&sk_storage->lock);
	sk_storage->sk = sk;

	__selem_link_sk(sk_storage, first_selem);
	selem_link_map(smap, first_selem);
	/* Publish sk_storage to sk.  sk->sk_lock cannot be acquired.
	 * Hence, an atomic op is used to set sk->sk_bpf_storage
	 * from NULL to the newly allocated sk_storage ptr.
	 *
	 * From now on, the sk->sk_bpf_storage pointer is protected
	 * by the sk_storage->lock.  Hence,  when freeing
	 * the sk->sk_bpf_storage, the sk_storage->lock must
	 * be held before setting sk->sk_bpf_storage to NULL.
	 */
	prev_sk_storage = cmpxchg((struct bpf_sk_storage **)&sk->sk_bpf_storage,
				  NULL, sk_storage);
	if (unlikely(prev_sk_storage)) {
		/* Note that even first_selem was linked to smap's
		 * bucket->list, first_selem can be freed immediately
		 * (instead of kfree_rcu) because
		 * bpf_sk_storage_map_free() does a
		 * synchronize_rcu() before walking the bucket->list.
		 * Hence, no one is accessing selem from the
		 * bucket->list under rcu_read_lock().
		 */
		selem_unlink_map(first_selem);
		err = -EAGAIN;
		goto uncharge;
	}

	return 0;

uncharge:
	/* Also reached when kzalloc() failed; kfree(NULL) is a no-op */
	kfree(sk_storage);
	atomic_sub(sizeof(*sk_storage), &sk->sk_omem_alloc);
	return err;
}
384 
/* Create or replace @sk's data for @map with @value, honoring the
 * BPF_NOEXIST / BPF_EXIST / BPF_F_LOCK bits of @map_flags.
 *
 * sk cannot be going away because it is linking new elem
 * to sk->sk_bpf_storage. (i.e. sk->sk_refcnt cannot be 0).
 * Otherwise, it will become a leak (and other memory issues
 * during map destruction).
 *
 * Returns the new (or in-place updated) sdata, or an ERR_PTR():
 * -EINVAL for bad flags, -EEXIST/-ENOENT from check_flags(), -ENOMEM
 * on charge/allocation failure, or -EAGAIN when racing with storage
 * teardown or with a concurrent first-elem creation.
 */
static struct bpf_sk_storage_data *sk_storage_update(struct sock *sk,
						     struct bpf_map *map,
						     void *value,
						     u64 map_flags)
{
	struct bpf_sk_storage_data *old_sdata = NULL;
	struct bpf_sk_storage_elem *selem;
	struct bpf_sk_storage *sk_storage;
	struct bpf_sk_storage_map *smap;
	int err;

	/* BPF_EXIST and BPF_NOEXIST cannot be both set */
	if (unlikely((map_flags & ~BPF_F_LOCK) > BPF_EXIST) ||
	    /* BPF_F_LOCK can only be used in a value with spin_lock */
	    unlikely((map_flags & BPF_F_LOCK) && !map_value_has_spin_lock(map)))
		return ERR_PTR(-EINVAL);

	smap = (struct bpf_sk_storage_map *)map;
	sk_storage = rcu_dereference(sk->sk_bpf_storage);
	if (!sk_storage || hlist_empty(&sk_storage->list)) {
		/* Very first elem for this sk */
		err = check_flags(NULL, map_flags);
		if (err)
			return ERR_PTR(err);

		selem = selem_alloc(smap, sk, value, true);
		if (!selem)
			return ERR_PTR(-ENOMEM);

		err = sk_storage_alloc(sk, smap, selem);
		if (err) {
			/* selem is unreachable by readers here; free it
			 * and undo the charge taken by selem_alloc().
			 */
			kfree(selem);
			atomic_sub(smap->elem_size, &sk->sk_omem_alloc);
			return ERR_PTR(err);
		}

		return SDATA(selem);
	}

	if ((map_flags & BPF_F_LOCK) && !(map_flags & BPF_NOEXIST)) {
		/* Hoping to find an old_sdata to do inline update
		 * such that it can avoid taking the sk_storage->lock
		 * and changing the lists.
		 */
		old_sdata = __sk_storage_lookup(sk_storage, smap, false);
		err = check_flags(old_sdata, map_flags);
		if (err)
			return ERR_PTR(err);
		if (old_sdata && selem_linked_to_sk(SELEM(old_sdata))) {
			copy_map_value_locked(map, old_sdata->data,
					      value, false);
			return old_sdata;
		}
	}

	raw_spin_lock_bh(&sk_storage->lock);

	/* Recheck sk_storage->list under sk_storage->lock */
	if (unlikely(hlist_empty(&sk_storage->list))) {
		/* A parallel del is happening and sk_storage is going
		 * away.  It has just been checked before, so very
		 * unlikely.  Return instead of retry to keep things
		 * simple.
		 */
		err = -EAGAIN;
		goto unlock_err;
	}

	old_sdata = __sk_storage_lookup(sk_storage, smap, false);
	err = check_flags(old_sdata, map_flags);
	if (err)
		goto unlock_err;

	if (old_sdata && (map_flags & BPF_F_LOCK)) {
		copy_map_value_locked(map, old_sdata->data, value, false);
		selem = SELEM(old_sdata);
		goto unlock;
	}

	/* sk_storage->lock is held.  Hence, we are sure
	 * we can unlink and uncharge the old_sdata successfully
	 * later.  Hence, instead of charging the new selem now
	 * and then uncharge the old selem later (which may cause
	 * a potential but unnecessary charge failure),  avoid taking
	 * a charge at all here (the "!old_sdata" check) and the
	 * old_sdata will not be uncharged later during __selem_unlink_sk().
	 */
	selem = selem_alloc(smap, sk, value, !old_sdata);
	if (!selem) {
		err = -ENOMEM;
		goto unlock_err;
	}

	/* First, link the new selem to the map */
	selem_link_map(smap, selem);

	/* Second, link (and publish) the new selem to sk_storage */
	__selem_link_sk(sk_storage, selem);

	/* Third, remove old selem, SELEM(old_sdata) */
	if (old_sdata) {
		selem_unlink_map(SELEM(old_sdata));
		__selem_unlink_sk(sk_storage, SELEM(old_sdata), false);
	}

unlock:
	raw_spin_unlock_bh(&sk_storage->lock);
	return SDATA(selem);

unlock_err:
	raw_spin_unlock_bh(&sk_storage->lock);
	return ERR_PTR(err);
}
503 
504 static int sk_storage_delete(struct sock *sk, struct bpf_map *map)
505 {
506 	struct bpf_sk_storage_data *sdata;
507 
508 	sdata = sk_storage_lookup(sk, map, false);
509 	if (!sdata)
510 		return -ENOENT;
511 
512 	selem_unlink(SELEM(sdata));
513 
514 	return 0;
515 }
516 
517 static u16 cache_idx_get(void)
518 {
519 	u64 min_usage = U64_MAX;
520 	u16 i, res = 0;
521 
522 	spin_lock(&cache_idx_lock);
523 
524 	for (i = 0; i < BPF_SK_STORAGE_CACHE_SIZE; i++) {
525 		if (cache_idx_usage_counts[i] < min_usage) {
526 			min_usage = cache_idx_usage_counts[i];
527 			res = i;
528 
529 			/* Found a free cache_idx */
530 			if (!min_usage)
531 				break;
532 		}
533 	}
534 	cache_idx_usage_counts[res]++;
535 
536 	spin_unlock(&cache_idx_lock);
537 
538 	return res;
539 }
540 
541 static void cache_idx_free(u16 idx)
542 {
543 	spin_lock(&cache_idx_lock);
544 	cache_idx_usage_counts[idx]--;
545 	spin_unlock(&cache_idx_lock);
546 }
547 
/* Unlink and free every elem of @sk's storage, then the storage itself.
 *
 * Called by __sk_destruct() & bpf_sk_storage_clone()
 */
void bpf_sk_storage_free(struct sock *sk)
{
	struct bpf_sk_storage_elem *selem;
	struct bpf_sk_storage *sk_storage;
	bool free_sk_storage = false;
	struct hlist_node *n;

	rcu_read_lock();
	sk_storage = rcu_dereference(sk->sk_bpf_storage);
	if (!sk_storage) {
		rcu_read_unlock();
		return;
	}

	/* Neither the bpf_prog nor the bpf-map's syscall
	 * could be modifying the sk_storage->list now.
	 * Thus, no elem can be added-to or deleted-from the
	 * sk_storage->list by the bpf_prog or by the bpf-map's syscall.
	 *
	 * It is racing with bpf_sk_storage_map_free() alone
	 * when unlinking elem from the sk_storage->list and
	 * the map's bucket->list.
	 */
	raw_spin_lock_bh(&sk_storage->lock);
	hlist_for_each_entry_safe(selem, n, &sk_storage->list, snode) {
		/* Always unlink from map before unlinking from
		 * sk_storage.
		 */
		selem_unlink_map(selem);
		/* Only the last unlink returns true */
		free_sk_storage = __selem_unlink_sk(sk_storage, selem, true);
	}
	raw_spin_unlock_bh(&sk_storage->lock);
	rcu_read_unlock();

	if (free_sk_storage)
		kfree_rcu(sk_storage, rcu);
}
586 
/* map_free hook: release the cache slot, unlink every remaining elem of
 * this map from its sk, then free the buckets and the map.
 */
static void bpf_sk_storage_map_free(struct bpf_map *map)
{
	struct bpf_sk_storage_elem *selem;
	struct bpf_sk_storage_map *smap;
	struct bucket *b;
	unsigned int i;

	smap = (struct bpf_sk_storage_map *)map;

	cache_idx_free(smap->cache_idx);

	/* Note that this map might be concurrently cloned from
	 * bpf_sk_storage_clone. Wait for any existing bpf_sk_storage_clone
	 * RCU read section to finish before proceeding. New RCU
	 * read sections should be prevented via bpf_map_inc_not_zero.
	 */
	synchronize_rcu();

	/* bpf prog and the userspace can no longer access this map
	 * now.  No new selem (of this map) can be added
	 * to the sk->sk_bpf_storage or to the map bucket's list.
	 *
	 * The elem of this map can be cleaned up here
	 * or
	 * by bpf_sk_storage_free() during __sk_destruct().
	 */
	for (i = 0; i < (1U << smap->bucket_log); i++) {
		b = &smap->buckets[i];

		rcu_read_lock();
		/* No one is adding to b->list now.  Repeatedly take the
		 * head; selem_unlink() removes it from b->list.
		 */
		while ((selem = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(&b->list)),
						 struct bpf_sk_storage_elem,
						 map_node))) {
			selem_unlink(selem);
			cond_resched_rcu();
		}
		rcu_read_unlock();
	}

	/* bpf_sk_storage_free() may still need to access the map.
	 * e.g. bpf_sk_storage_free() has unlinked selem from the map
	 * which then made the above while((selem = ...)) loop
	 * exited immediately.
	 *
	 * However, the bpf_sk_storage_free() still needs to access
	 * the smap->elem_size to do the uncharging in
	 * __selem_unlink_sk().
	 *
	 * Hence, wait another rcu grace period for the
	 * bpf_sk_storage_free() to finish.
	 */
	synchronize_rcu();

	kvfree(smap->buckets);
	kfree(map);
}
644 
/* U16_MAX is much more than enough for sk local storage
 * considering a tcp_sock is ~2k.
 *
 * Bounding value_size this way keeps elem_size (a u16 holding
 * sizeof(struct bpf_sk_storage_elem) + value_size) from overflowing,
 * and keeps it below KMALLOC_MAX_SIZE.
 */
#define MAX_VALUE_SIZE							\
	min_t(u32,							\
	      (KMALLOC_MAX_SIZE - MAX_BPF_STACK - sizeof(struct bpf_sk_storage_elem)), \
	      (U16_MAX - sizeof(struct bpf_sk_storage_elem)))
652 
653 static int bpf_sk_storage_map_alloc_check(union bpf_attr *attr)
654 {
655 	if (attr->map_flags & ~SK_STORAGE_CREATE_FLAG_MASK ||
656 	    !(attr->map_flags & BPF_F_NO_PREALLOC) ||
657 	    attr->max_entries ||
658 	    attr->key_size != sizeof(int) || !attr->value_size ||
659 	    /* Enforce BTF for userspace sk dumping */
660 	    !attr->btf_key_type_id || !attr->btf_value_type_id)
661 		return -EINVAL;
662 
663 	if (!bpf_capable())
664 		return -EPERM;
665 
666 	if (attr->value_size > MAX_VALUE_SIZE)
667 		return -E2BIG;
668 
669 	return 0;
670 }
671 
672 static struct bpf_map *bpf_sk_storage_map_alloc(union bpf_attr *attr)
673 {
674 	struct bpf_sk_storage_map *smap;
675 	unsigned int i;
676 	u32 nbuckets;
677 	u64 cost;
678 	int ret;
679 
680 	smap = kzalloc(sizeof(*smap), GFP_USER | __GFP_NOWARN);
681 	if (!smap)
682 		return ERR_PTR(-ENOMEM);
683 	bpf_map_init_from_attr(&smap->map, attr);
684 
685 	nbuckets = roundup_pow_of_two(num_possible_cpus());
686 	/* Use at least 2 buckets, select_bucket() is undefined behavior with 1 bucket */
687 	nbuckets = max_t(u32, 2, nbuckets);
688 	smap->bucket_log = ilog2(nbuckets);
689 	cost = sizeof(*smap->buckets) * nbuckets + sizeof(*smap);
690 
691 	ret = bpf_map_charge_init(&smap->map.memory, cost);
692 	if (ret < 0) {
693 		kfree(smap);
694 		return ERR_PTR(ret);
695 	}
696 
697 	smap->buckets = kvcalloc(sizeof(*smap->buckets), nbuckets,
698 				 GFP_USER | __GFP_NOWARN);
699 	if (!smap->buckets) {
700 		bpf_map_charge_finish(&smap->map.memory);
701 		kfree(smap);
702 		return ERR_PTR(-ENOMEM);
703 	}
704 
705 	for (i = 0; i < nbuckets; i++) {
706 		INIT_HLIST_HEAD(&smap->buckets[i].list);
707 		raw_spin_lock_init(&smap->buckets[i].lock);
708 	}
709 
710 	smap->elem_size = sizeof(struct bpf_sk_storage_elem) + attr->value_size;
711 	smap->cache_idx = cache_idx_get();
712 
713 	return &smap->map;
714 }
715 
/* map_get_next_key hook: key iteration is not supported for this map
 * type, so this always returns -ENOTSUPP.
 */
static int notsupp_get_next_key(struct bpf_map *map, void *key,
				void *next_key)
{
	return -ENOTSUPP;
}
721 
722 static int bpf_sk_storage_map_check_btf(const struct bpf_map *map,
723 					const struct btf *btf,
724 					const struct btf_type *key_type,
725 					const struct btf_type *value_type)
726 {
727 	u32 int_data;
728 
729 	if (BTF_INFO_KIND(key_type->info) != BTF_KIND_INT)
730 		return -EINVAL;
731 
732 	int_data = *(u32 *)(key_type + 1);
733 	if (BTF_INT_BITS(int_data) != 32 || BTF_INT_OFFSET(int_data))
734 		return -EINVAL;
735 
736 	return 0;
737 }
738 
739 static void *bpf_fd_sk_storage_lookup_elem(struct bpf_map *map, void *key)
740 {
741 	struct bpf_sk_storage_data *sdata;
742 	struct socket *sock;
743 	int fd, err;
744 
745 	fd = *(int *)key;
746 	sock = sockfd_lookup(fd, &err);
747 	if (sock) {
748 		sdata = sk_storage_lookup(sock->sk, map, true);
749 		sockfd_put(sock);
750 		return sdata ? sdata->data : NULL;
751 	}
752 
753 	return ERR_PTR(err);
754 }
755 
756 static int bpf_fd_sk_storage_update_elem(struct bpf_map *map, void *key,
757 					 void *value, u64 map_flags)
758 {
759 	struct bpf_sk_storage_data *sdata;
760 	struct socket *sock;
761 	int fd, err;
762 
763 	fd = *(int *)key;
764 	sock = sockfd_lookup(fd, &err);
765 	if (sock) {
766 		sdata = sk_storage_update(sock->sk, map, value, map_flags);
767 		sockfd_put(sock);
768 		return PTR_ERR_OR_ZERO(sdata);
769 	}
770 
771 	return err;
772 }
773 
774 static int bpf_fd_sk_storage_delete_elem(struct bpf_map *map, void *key)
775 {
776 	struct socket *sock;
777 	int fd, err;
778 
779 	fd = *(int *)key;
780 	sock = sockfd_lookup(fd, &err);
781 	if (sock) {
782 		err = sk_storage_delete(sock->sk, map);
783 		sockfd_put(sock);
784 		return err;
785 	}
786 
787 	return err;
788 }
789 
790 static struct bpf_sk_storage_elem *
791 bpf_sk_storage_clone_elem(struct sock *newsk,
792 			  struct bpf_sk_storage_map *smap,
793 			  struct bpf_sk_storage_elem *selem)
794 {
795 	struct bpf_sk_storage_elem *copy_selem;
796 
797 	copy_selem = selem_alloc(smap, newsk, NULL, true);
798 	if (!copy_selem)
799 		return NULL;
800 
801 	if (map_value_has_spin_lock(&smap->map))
802 		copy_map_value_locked(&smap->map, SDATA(copy_selem)->data,
803 				      SDATA(selem)->data, true);
804 	else
805 		copy_map_value(&smap->map, SDATA(copy_selem)->data,
806 			       SDATA(selem)->data);
807 
808 	return copy_selem;
809 }
810 
/* Copy every elem of @sk whose map has BPF_F_CLONE into @newsk.
 *
 * Returns 0 or a negative errno.  On error, nothing is freed here;
 * the caller is expected to call bpf_sk_storage_free(newsk).
 */
int bpf_sk_storage_clone(const struct sock *sk, struct sock *newsk)
{
	struct bpf_sk_storage *new_sk_storage = NULL;
	struct bpf_sk_storage *sk_storage;
	struct bpf_sk_storage_elem *selem;
	int ret = 0;

	/* newsk starts without storage of its own */
	RCU_INIT_POINTER(newsk->sk_bpf_storage, NULL);

	rcu_read_lock();
	sk_storage = rcu_dereference(sk->sk_bpf_storage);

	if (!sk_storage || hlist_empty(&sk_storage->list))
		goto out;

	hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
		struct bpf_sk_storage_elem *copy_selem;
		struct bpf_sk_storage_map *smap;
		struct bpf_map *map;

		smap = rcu_dereference(SDATA(selem)->smap);
		/* Only maps created with BPF_F_CLONE are copied */
		if (!(smap->map.map_flags & BPF_F_CLONE))
			continue;

		/* Note that for lockless listeners adding new element
		 * here can race with cleanup in bpf_sk_storage_map_free.
		 * Try to grab map refcnt to make sure that it's still
		 * alive and prevent concurrent removal.
		 */
		map = bpf_map_inc_not_zero(&smap->map);
		if (IS_ERR(map))
			continue;

		copy_selem = bpf_sk_storage_clone_elem(newsk, smap, selem);
		if (!copy_selem) {
			ret = -ENOMEM;
			bpf_map_put(map);
			goto out;
		}

		if (new_sk_storage) {
			selem_link_map(smap, copy_selem);
			__selem_link_sk(new_sk_storage, copy_selem);
		} else {
			/* First cloned elem: create and publish newsk's storage */
			ret = sk_storage_alloc(newsk, smap, copy_selem);
			if (ret) {
				kfree(copy_selem);
				atomic_sub(smap->elem_size,
					   &newsk->sk_omem_alloc);
				bpf_map_put(map);
				goto out;
			}

			new_sk_storage = rcu_dereference(copy_selem->sk_storage);
		}
		bpf_map_put(map);
	}

out:
	rcu_read_unlock();

	/* In case of an error, don't free anything explicitly here, the
	 * caller is responsible to call bpf_sk_storage_free.
	 */

	return ret;
}
878 
879 BPF_CALL_4(bpf_sk_storage_get, struct bpf_map *, map, struct sock *, sk,
880 	   void *, value, u64, flags)
881 {
882 	struct bpf_sk_storage_data *sdata;
883 
884 	if (flags > BPF_SK_STORAGE_GET_F_CREATE)
885 		return (unsigned long)NULL;
886 
887 	sdata = sk_storage_lookup(sk, map, true);
888 	if (sdata)
889 		return (unsigned long)sdata->data;
890 
891 	if (flags == BPF_SK_STORAGE_GET_F_CREATE &&
892 	    /* Cannot add new elem to a going away sk.
893 	     * Otherwise, the new elem may become a leak
894 	     * (and also other memory issues during map
895 	     *  destruction).
896 	     */
897 	    refcount_inc_not_zero(&sk->sk_refcnt)) {
898 		sdata = sk_storage_update(sk, map, value, BPF_NOEXIST);
899 		/* sk must be a fullsock (guaranteed by verifier),
900 		 * so sock_gen_put() is unnecessary.
901 		 */
902 		sock_put(sk);
903 		return IS_ERR(sdata) ?
904 			(unsigned long)NULL : (unsigned long)sdata->data;
905 	}
906 
907 	return (unsigned long)NULL;
908 }
909 
910 BPF_CALL_2(bpf_sk_storage_delete, struct bpf_map *, map, struct sock *, sk)
911 {
912 	if (refcount_inc_not_zero(&sk->sk_refcnt)) {
913 		int err;
914 
915 		err = sk_storage_delete(sk, map);
916 		sock_put(sk);
917 		return err;
918 	}
919 
920 	return -ENOENT;
921 }
922 
/* Referenced through .map_btf_id below; not written in this file. */
static int sk_storage_map_btf_id;
/* Map operations for BPF_MAP_TYPE_SK_STORAGE */
const struct bpf_map_ops sk_storage_map_ops = {
	.map_alloc_check = bpf_sk_storage_map_alloc_check,
	.map_alloc = bpf_sk_storage_map_alloc,
	.map_free = bpf_sk_storage_map_free,
	.map_get_next_key = notsupp_get_next_key,
	.map_lookup_elem = bpf_fd_sk_storage_lookup_elem,
	.map_update_elem = bpf_fd_sk_storage_update_elem,
	.map_delete_elem = bpf_fd_sk_storage_delete_elem,
	.map_check_btf = bpf_sk_storage_map_check_btf,
	.map_btf_name = "bpf_sk_storage_map",
	.map_btf_id = &sk_storage_map_btf_id,
};

/* bpf_sk_storage_get() for programs passing a socket pointer */
const struct bpf_func_proto bpf_sk_storage_get_proto = {
	.func		= bpf_sk_storage_get,
	.gpl_only	= false,
	.ret_type	= RET_PTR_TO_MAP_VALUE_OR_NULL,
	.arg1_type	= ARG_CONST_MAP_PTR,
	.arg2_type	= ARG_PTR_TO_SOCKET,
	.arg3_type	= ARG_PTR_TO_MAP_VALUE_OR_NULL,
	.arg4_type	= ARG_ANYTHING,
};

/* Same helper, for programs whose context is the socket itself */
const struct bpf_func_proto bpf_sk_storage_get_cg_sock_proto = {
	.func		= bpf_sk_storage_get,
	.gpl_only	= false,
	.ret_type	= RET_PTR_TO_MAP_VALUE_OR_NULL,
	.arg1_type	= ARG_CONST_MAP_PTR,
	.arg2_type	= ARG_PTR_TO_CTX, /* context is 'struct sock' */
	.arg3_type	= ARG_PTR_TO_MAP_VALUE_OR_NULL,
	.arg4_type	= ARG_ANYTHING,
};

/* bpf_sk_storage_delete() helper */
const struct bpf_func_proto bpf_sk_storage_delete_proto = {
	.func		= bpf_sk_storage_delete,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_CONST_MAP_PTR,
	.arg2_type	= ARG_PTR_TO_SOCKET,
};
964 
/* Maps requested for a sock_diag dump.  Each entry holds a map reference
 * taken in bpf_sk_storage_diag_alloc() and dropped in
 * bpf_sk_storage_diag_free().
 */
struct bpf_sk_storage_diag {
	u32 nr_maps;
	struct bpf_map *maps[];
};
969 
970 /* The reply will be like:
971  * INET_DIAG_BPF_SK_STORAGES (nla_nest)
972  *	SK_DIAG_BPF_STORAGE (nla_nest)
973  *		SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
974  *		SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
975  *	SK_DIAG_BPF_STORAGE (nla_nest)
976  *		SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
977  *		SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
978  *	....
979  */
980 static int nla_value_size(u32 value_size)
981 {
982 	/* SK_DIAG_BPF_STORAGE (nla_nest)
983 	 *	SK_DIAG_BPF_STORAGE_MAP_ID (nla_put_u32)
984 	 *	SK_DIAG_BPF_STORAGE_MAP_VALUE (nla_reserve_64bit)
985 	 */
986 	return nla_total_size(0) + nla_total_size(sizeof(u32)) +
987 		nla_total_size_64bit(value_size);
988 }
989 
990 void bpf_sk_storage_diag_free(struct bpf_sk_storage_diag *diag)
991 {
992 	u32 i;
993 
994 	if (!diag)
995 		return;
996 
997 	for (i = 0; i < diag->nr_maps; i++)
998 		bpf_map_put(diag->maps[i]);
999 
1000 	kfree(diag);
1001 }
1002 EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_free);
1003 
1004 static bool diag_check_dup(const struct bpf_sk_storage_diag *diag,
1005 			   const struct bpf_map *map)
1006 {
1007 	u32 i;
1008 
1009 	for (i = 0; i < diag->nr_maps; i++) {
1010 		if (diag->maps[i] == map)
1011 			return true;
1012 	}
1013 
1014 	return false;
1015 }
1016 
1017 struct bpf_sk_storage_diag *
1018 bpf_sk_storage_diag_alloc(const struct nlattr *nla_stgs)
1019 {
1020 	struct bpf_sk_storage_diag *diag;
1021 	struct nlattr *nla;
1022 	u32 nr_maps = 0;
1023 	int rem, err;
1024 
1025 	/* bpf_sk_storage_map is currently limited to CAP_SYS_ADMIN as
1026 	 * the map_alloc_check() side also does.
1027 	 */
1028 	if (!bpf_capable())
1029 		return ERR_PTR(-EPERM);
1030 
1031 	nla_for_each_nested(nla, nla_stgs, rem) {
1032 		if (nla_type(nla) == SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
1033 			nr_maps++;
1034 	}
1035 
1036 	diag = kzalloc(sizeof(*diag) + sizeof(diag->maps[0]) * nr_maps,
1037 		       GFP_KERNEL);
1038 	if (!diag)
1039 		return ERR_PTR(-ENOMEM);
1040 
1041 	nla_for_each_nested(nla, nla_stgs, rem) {
1042 		struct bpf_map *map;
1043 		int map_fd;
1044 
1045 		if (nla_type(nla) != SK_DIAG_BPF_STORAGE_REQ_MAP_FD)
1046 			continue;
1047 
1048 		map_fd = nla_get_u32(nla);
1049 		map = bpf_map_get(map_fd);
1050 		if (IS_ERR(map)) {
1051 			err = PTR_ERR(map);
1052 			goto err_free;
1053 		}
1054 		if (map->map_type != BPF_MAP_TYPE_SK_STORAGE) {
1055 			bpf_map_put(map);
1056 			err = -EINVAL;
1057 			goto err_free;
1058 		}
1059 		if (diag_check_dup(diag, map)) {
1060 			bpf_map_put(map);
1061 			err = -EEXIST;
1062 			goto err_free;
1063 		}
1064 		diag->maps[diag->nr_maps++] = map;
1065 	}
1066 
1067 	return diag;
1068 
1069 err_free:
1070 	bpf_sk_storage_diag_free(diag);
1071 	return ERR_PTR(err);
1072 }
1073 EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_alloc);
1074 
1075 static int diag_get(struct bpf_sk_storage_data *sdata, struct sk_buff *skb)
1076 {
1077 	struct nlattr *nla_stg, *nla_value;
1078 	struct bpf_sk_storage_map *smap;
1079 
1080 	/* It cannot exceed max nlattr's payload */
1081 	BUILD_BUG_ON(U16_MAX - NLA_HDRLEN < MAX_VALUE_SIZE);
1082 
1083 	nla_stg = nla_nest_start(skb, SK_DIAG_BPF_STORAGE);
1084 	if (!nla_stg)
1085 		return -EMSGSIZE;
1086 
1087 	smap = rcu_dereference(sdata->smap);
1088 	if (nla_put_u32(skb, SK_DIAG_BPF_STORAGE_MAP_ID, smap->map.id))
1089 		goto errout;
1090 
1091 	nla_value = nla_reserve_64bit(skb, SK_DIAG_BPF_STORAGE_MAP_VALUE,
1092 				      smap->map.value_size,
1093 				      SK_DIAG_BPF_STORAGE_PAD);
1094 	if (!nla_value)
1095 		goto errout;
1096 
1097 	if (map_value_has_spin_lock(&smap->map))
1098 		copy_map_value_locked(&smap->map, nla_data(nla_value),
1099 				      sdata->data, true);
1100 	else
1101 		copy_map_value(&smap->map, nla_data(nla_value), sdata->data);
1102 
1103 	nla_nest_end(skb, nla_stg);
1104 	return 0;
1105 
1106 errout:
1107 	nla_nest_cancel(skb, nla_stg);
1108 	return -EMSGSIZE;
1109 }
1110 
1111 static int bpf_sk_storage_diag_put_all(struct sock *sk, struct sk_buff *skb,
1112 				       int stg_array_type,
1113 				       unsigned int *res_diag_size)
1114 {
1115 	/* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
1116 	unsigned int diag_size = nla_total_size(0);
1117 	struct bpf_sk_storage *sk_storage;
1118 	struct bpf_sk_storage_elem *selem;
1119 	struct bpf_sk_storage_map *smap;
1120 	struct nlattr *nla_stgs;
1121 	unsigned int saved_len;
1122 	int err = 0;
1123 
1124 	rcu_read_lock();
1125 
1126 	sk_storage = rcu_dereference(sk->sk_bpf_storage);
1127 	if (!sk_storage || hlist_empty(&sk_storage->list)) {
1128 		rcu_read_unlock();
1129 		return 0;
1130 	}
1131 
1132 	nla_stgs = nla_nest_start(skb, stg_array_type);
1133 	if (!nla_stgs)
1134 		/* Continue to learn diag_size */
1135 		err = -EMSGSIZE;
1136 
1137 	saved_len = skb->len;
1138 	hlist_for_each_entry_rcu(selem, &sk_storage->list, snode) {
1139 		smap = rcu_dereference(SDATA(selem)->smap);
1140 		diag_size += nla_value_size(smap->map.value_size);
1141 
1142 		if (nla_stgs && diag_get(SDATA(selem), skb))
1143 			/* Continue to learn diag_size */
1144 			err = -EMSGSIZE;
1145 	}
1146 
1147 	rcu_read_unlock();
1148 
1149 	if (nla_stgs) {
1150 		if (saved_len == skb->len)
1151 			nla_nest_cancel(skb, nla_stgs);
1152 		else
1153 			nla_nest_end(skb, nla_stgs);
1154 	}
1155 
1156 	if (diag_size == nla_total_size(0)) {
1157 		*res_diag_size = 0;
1158 		return 0;
1159 	}
1160 
1161 	*res_diag_size = diag_size;
1162 	return err;
1163 }
1164 
1165 int bpf_sk_storage_diag_put(struct bpf_sk_storage_diag *diag,
1166 			    struct sock *sk, struct sk_buff *skb,
1167 			    int stg_array_type,
1168 			    unsigned int *res_diag_size)
1169 {
1170 	/* stg_array_type (e.g. INET_DIAG_BPF_SK_STORAGES) */
1171 	unsigned int diag_size = nla_total_size(0);
1172 	struct bpf_sk_storage *sk_storage;
1173 	struct bpf_sk_storage_data *sdata;
1174 	struct nlattr *nla_stgs;
1175 	unsigned int saved_len;
1176 	int err = 0;
1177 	u32 i;
1178 
1179 	*res_diag_size = 0;
1180 
1181 	/* No map has been specified.  Dump all. */
1182 	if (!diag->nr_maps)
1183 		return bpf_sk_storage_diag_put_all(sk, skb, stg_array_type,
1184 						   res_diag_size);
1185 
1186 	rcu_read_lock();
1187 	sk_storage = rcu_dereference(sk->sk_bpf_storage);
1188 	if (!sk_storage || hlist_empty(&sk_storage->list)) {
1189 		rcu_read_unlock();
1190 		return 0;
1191 	}
1192 
1193 	nla_stgs = nla_nest_start(skb, stg_array_type);
1194 	if (!nla_stgs)
1195 		/* Continue to learn diag_size */
1196 		err = -EMSGSIZE;
1197 
1198 	saved_len = skb->len;
1199 	for (i = 0; i < diag->nr_maps; i++) {
1200 		sdata = __sk_storage_lookup(sk_storage,
1201 				(struct bpf_sk_storage_map *)diag->maps[i],
1202 				false);
1203 
1204 		if (!sdata)
1205 			continue;
1206 
1207 		diag_size += nla_value_size(diag->maps[i]->value_size);
1208 
1209 		if (nla_stgs && diag_get(sdata, skb))
1210 			/* Continue to learn diag_size */
1211 			err = -EMSGSIZE;
1212 	}
1213 	rcu_read_unlock();
1214 
1215 	if (nla_stgs) {
1216 		if (saved_len == skb->len)
1217 			nla_nest_cancel(skb, nla_stgs);
1218 		else
1219 			nla_nest_end(skb, nla_stgs);
1220 	}
1221 
1222 	if (diag_size == nla_total_size(0)) {
1223 		*res_diag_size = 0;
1224 		return 0;
1225 	}
1226 
1227 	*res_diag_size = diag_size;
1228 	return err;
1229 }
1230 EXPORT_SYMBOL_GPL(bpf_sk_storage_diag_put);
1231 
/* Per-seq_file state of the bpf_sk_storage_map iterator. */
struct bpf_iter_seq_sk_storage_map_info {
	/* The sk storage map being walked. */
	struct bpf_map *map;
	/* Bucket currently visited; its lock stays held while one of its
	 * elems is handed out.
	 */
	unsigned int bucket_id;
	/* Position within that bucket's list where the walk resumes. */
	unsigned int skip_elems;
};

/* Find the next elem of the iterated map whose sk_storage pointer is set,
 * walking the buckets in order starting at info->bucket_id.
 *
 * @info:	iterator state.  bucket_id/skip_elems say where to resume, and
 *		they are updated to the position of the returned elem.
 * @prev_selem:	elem returned by the previous call, whose bucket lock is
 *		still held, or NULL to (re)start from @info's position.
 *
 * Locking: a returned elem comes with the raw spinlock of its bucket
 * (smap->buckets[info->bucket_id]) held.  That lock is released either by
 * the next call here, once the bucket is exhausted, or by
 * bpf_sk_storage_map_seq_stop().  No bucket lock is held on a NULL return.
 */
static struct bpf_sk_storage_elem *
bpf_sk_storage_map_seq_find_next(struct bpf_iter_seq_sk_storage_map_info *info,
				 struct bpf_sk_storage_elem *prev_selem)
{
	struct bpf_sk_storage *sk_storage;
	struct bpf_sk_storage_elem *selem;
	u32 skip_elems = info->skip_elems;
	struct bpf_sk_storage_map *smap;
	u32 bucket_id = info->bucket_id;
	u32 i, count, n_buckets;
	struct bucket *b;

	smap = (struct bpf_sk_storage_map *)info->map;
	n_buckets = 1U << smap->bucket_log;
	/* Every bucket has already been walked. */
	if (bucket_id >= n_buckets)
		return NULL;

	/* try to find next selem in the same bucket */
	/* prev_selem's bucket is still locked here, so its map_node chain
	 * can be followed directly.
	 */
	selem = prev_selem;
	count = 0;
	while (selem) {
		selem = hlist_entry_safe(selem->map_node.next,
					 struct bpf_sk_storage_elem, map_node);
		if (!selem) {
			/* not found, unlock and go to the next bucket */
			b = &smap->buckets[bucket_id++];
			raw_spin_unlock_bh(&b->lock);
			skip_elems = 0;
			break;
		}
		sk_storage = rcu_dereference_raw(selem->sk_storage);
		if (sk_storage) {
			/* Keep the bucket locked; remember the elem's position
			 * within the bucket.
			 */
			info->skip_elems = skip_elems + count;
			return selem;
		}
		/* Elems without sk_storage are skipped but still counted. */
		count++;
	}

	/* Scan buckets from bucket_id onwards, each under its own lock.
	 * Only the first bucket scanned can honour a non-zero skip_elems;
	 * it is reset to 0 after every bucket.
	 */
	for (i = bucket_id; i < (1U << smap->bucket_log); i++) {
		b = &smap->buckets[i];
		raw_spin_lock_bh(&b->lock);
		count = 0;
		hlist_for_each_entry(selem, &b->list, map_node) {
			sk_storage = rcu_dereference_raw(selem->sk_storage);
			if (sk_storage && count >= skip_elems) {
				/* Return with this bucket's lock held. */
				info->bucket_id = i;
				info->skip_elems = count;
				return selem;
			}
			count++;
		}
		raw_spin_unlock_bh(&b->lock);
		skip_elems = 0;
	}

	/* Nothing left: bucket_id now points past the last bucket, so later
	 * calls return NULL straight away.
	 */
	info->bucket_id = i;
	info->skip_elems = 0;
	return NULL;
}

1298 static void *bpf_sk_storage_map_seq_start(struct seq_file *seq, loff_t *pos)
1299 {
1300 	struct bpf_sk_storage_elem *selem;
1301 
1302 	selem = bpf_sk_storage_map_seq_find_next(seq->private, NULL);
1303 	if (!selem)
1304 		return NULL;
1305 
1306 	if (*pos == 0)
1307 		++*pos;
1308 	return selem;
1309 }
1310 
1311 static void *bpf_sk_storage_map_seq_next(struct seq_file *seq, void *v,
1312 					 loff_t *pos)
1313 {
1314 	struct bpf_iter_seq_sk_storage_map_info *info = seq->private;
1315 
1316 	++*pos;
1317 	++info->skip_elems;
1318 	return bpf_sk_storage_map_seq_find_next(seq->private, v);
1319 }
1320 
/* Context handed to a "bpf_sk_storage_map" iterator program.
 *
 * The layout is visible to programs: bpf_sk_storage_map_reg_info refers to
 * the sk and value members by offsetof(), and every member is declared
 * through __bpf_md_ptr().
 */
struct bpf_iter__bpf_sk_storage_map {
	__bpf_md_ptr(struct bpf_iter_meta *, meta);
	/* The map being iterated. */
	__bpf_md_ptr(struct bpf_map *, map);
	/* Socket owning the current elem; NULL on the final call. */
	__bpf_md_ptr(struct sock *, sk);
	/* The elem's storage data (read/write); NULL on the final call. */
	__bpf_md_ptr(void *, value);
};

/* Declare the iterator target.  The argument list mirrors the members of
 * struct bpf_iter__bpf_sk_storage_map, in the same order.
 * NOTE(review): DEFINE_BPF_ITER_FUNC is defined outside this file
 * (presumably include/linux/bpf.h); confirm what it expands to there.
 */
DEFINE_BPF_ITER_FUNC(bpf_sk_storage_map, struct bpf_iter_meta *meta,
		     struct bpf_map *map, struct sock *sk,
		     void *value)

1332 static int __bpf_sk_storage_map_seq_show(struct seq_file *seq,
1333 					 struct bpf_sk_storage_elem *selem)
1334 {
1335 	struct bpf_iter_seq_sk_storage_map_info *info = seq->private;
1336 	struct bpf_iter__bpf_sk_storage_map ctx = {};
1337 	struct bpf_sk_storage *sk_storage;
1338 	struct bpf_iter_meta meta;
1339 	struct bpf_prog *prog;
1340 	int ret = 0;
1341 
1342 	meta.seq = seq;
1343 	prog = bpf_iter_get_info(&meta, selem == NULL);
1344 	if (prog) {
1345 		ctx.meta = &meta;
1346 		ctx.map = info->map;
1347 		if (selem) {
1348 			sk_storage = rcu_dereference_raw(selem->sk_storage);
1349 			ctx.sk = sk_storage->sk;
1350 			ctx.value = SDATA(selem)->data;
1351 		}
1352 		ret = bpf_iter_run_prog(prog, &ctx);
1353 	}
1354 
1355 	return ret;
1356 }
1357 
/* seq_file ->show(): run the iterator program on the elem @v. */
static int bpf_sk_storage_map_seq_show(struct seq_file *seq, void *v)
{
	struct bpf_sk_storage_elem *selem = v;

	return __bpf_sk_storage_map_seq_show(seq, selem);
}

1363 static void bpf_sk_storage_map_seq_stop(struct seq_file *seq, void *v)
1364 {
1365 	struct bpf_iter_seq_sk_storage_map_info *info = seq->private;
1366 	struct bpf_sk_storage_map *smap;
1367 	struct bucket *b;
1368 
1369 	if (!v) {
1370 		(void)__bpf_sk_storage_map_seq_show(seq, v);
1371 	} else {
1372 		smap = (struct bpf_sk_storage_map *)info->map;
1373 		b = &smap->buckets[info->bucket_id];
1374 		raw_spin_unlock_bh(&b->lock);
1375 	}
1376 }
1377 
1378 static int bpf_iter_init_sk_storage_map(void *priv_data,
1379 					struct bpf_iter_aux_info *aux)
1380 {
1381 	struct bpf_iter_seq_sk_storage_map_info *seq_info = priv_data;
1382 
1383 	seq_info->map = aux->map;
1384 	return 0;
1385 }
1386 
1387 static int bpf_iter_attach_map(struct bpf_prog *prog,
1388 			       union bpf_iter_link_info *linfo,
1389 			       struct bpf_iter_aux_info *aux)
1390 {
1391 	struct bpf_map *map;
1392 	int err = -EINVAL;
1393 
1394 	if (!linfo->map.map_fd)
1395 		return -EBADF;
1396 
1397 	map = bpf_map_get_with_uref(linfo->map.map_fd);
1398 	if (IS_ERR(map))
1399 		return PTR_ERR(map);
1400 
1401 	if (map->map_type != BPF_MAP_TYPE_SK_STORAGE)
1402 		goto put_map;
1403 
1404 	if (prog->aux->max_rdonly_access > map->value_size) {
1405 		err = -EACCES;
1406 		goto put_map;
1407 	}
1408 
1409 	aux->map = map;
1410 	return 0;
1411 
1412 put_map:
1413 	bpf_map_put_with_uref(map);
1414 	return err;
1415 }
1416 
1417 static void bpf_iter_detach_map(struct bpf_iter_aux_info *aux)
1418 {
1419 	bpf_map_put_with_uref(aux->map);
1420 }
1421 
1422 static const struct seq_operations bpf_sk_storage_map_seq_ops = {
1423 	.start  = bpf_sk_storage_map_seq_start,
1424 	.next   = bpf_sk_storage_map_seq_next,
1425 	.stop   = bpf_sk_storage_map_seq_stop,
1426 	.show   = bpf_sk_storage_map_seq_show,
1427 };
1428 
1429 static const struct bpf_iter_seq_info iter_seq_info = {
1430 	.seq_ops		= &bpf_sk_storage_map_seq_ops,
1431 	.init_seq_private	= bpf_iter_init_sk_storage_map,
1432 	.fini_seq_private	= NULL,
1433 	.seq_priv_size		= sizeof(struct bpf_iter_seq_sk_storage_map_info),
1434 };
1435 
1436 static struct bpf_iter_reg bpf_sk_storage_map_reg_info = {
1437 	.target			= "bpf_sk_storage_map",
1438 	.attach_target		= bpf_iter_attach_map,
1439 	.detach_target		= bpf_iter_detach_map,
1440 	.ctx_arg_info_size	= 2,
1441 	.ctx_arg_info		= {
1442 		{ offsetof(struct bpf_iter__bpf_sk_storage_map, sk),
1443 		  PTR_TO_BTF_ID_OR_NULL },
1444 		{ offsetof(struct bpf_iter__bpf_sk_storage_map, value),
1445 		  PTR_TO_RDWR_BUF_OR_NULL },
1446 	},
1447 	.seq_info		= &iter_seq_info,
1448 };
1449 
1450 static int __init bpf_sk_storage_map_iter_init(void)
1451 {
1452 	bpf_sk_storage_map_reg_info.ctx_arg_info[0].btf_id =
1453 		btf_sock_ids[BTF_SOCK_TYPE_SOCK];
1454 	return bpf_iter_reg_target(&bpf_sk_storage_map_reg_info);
1455 }
1456 late_initcall(bpf_sk_storage_map_iter_init);
1457