1 // SPDX-License-Identifier: GPL-2.0
2 #include <vmlinux.h>
3 #include <bpf/bpf_tracing.h>
4 #include <bpf/bpf_helpers.h>
5 
/* Value type shared by every test map: one unreferenced ("untrusted") kptr
 * field and one referenced kptr field, both pointing at the kernel's test
 * object type.
 */
struct map_value {
	struct prog_test_ref_kfunc __kptr_untrusted *unref_ptr;
	struct prog_test_ref_kfunc __kptr *ref_ptr;
};
10 
/* One instance of each generic map type that supports kptr fields in its
 * value; the programs below run identical kptr operations against all of
 * them.
 */
struct array_map {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
} array_map SEC(".maps");

struct pcpu_array_map {
	__uint(type, BPF_MAP_TYPE_PERCPU_ARRAY);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
} pcpu_array_map SEC(".maps");

struct hash_map {
	__uint(type, BPF_MAP_TYPE_HASH);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
} hash_map SEC(".maps");

struct pcpu_hash_map {
	__uint(type, BPF_MAP_TYPE_PERCPU_HASH);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
} pcpu_hash_map SEC(".maps");

/* BPF_F_NO_PREALLOC variants allocate their elements on demand, exercising a
 * different element-lifetime path than the preallocated maps above.
 */
struct hash_malloc_map {
	__uint(type, BPF_MAP_TYPE_HASH);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
	__uint(map_flags, BPF_F_NO_PREALLOC);
} hash_malloc_map SEC(".maps");

struct pcpu_hash_malloc_map {
	__uint(type, BPF_MAP_TYPE_PERCPU_HASH);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
	__uint(map_flags, BPF_F_NO_PREALLOC);
} pcpu_hash_malloc_map SEC(".maps");

struct lru_hash_map {
	__uint(type, BPF_MAP_TYPE_LRU_HASH);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
} lru_hash_map SEC(".maps");

struct lru_pcpu_hash_map {
	__uint(type, BPF_MAP_TYPE_LRU_PERCPU_HASH);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
} lru_pcpu_hash_map SEC(".maps");
68 
/* Local-storage map flavors (cgroup/task/inode/socket). Values are attached
 * to their owning kernel object via the *_storage_get() helpers below;
 * BPF_F_NO_PREALLOC is mandatory for these map types.
 */
struct cgrp_ls_map {
	__uint(type, BPF_MAP_TYPE_CGRP_STORAGE);
	__uint(map_flags, BPF_F_NO_PREALLOC);
	__type(key, int);
	__type(value, struct map_value);
} cgrp_ls_map SEC(".maps");

struct task_ls_map {
	__uint(type, BPF_MAP_TYPE_TASK_STORAGE);
	__uint(map_flags, BPF_F_NO_PREALLOC);
	__type(key, int);
	__type(value, struct map_value);
} task_ls_map SEC(".maps");

struct inode_ls_map {
	__uint(type, BPF_MAP_TYPE_INODE_STORAGE);
	__uint(map_flags, BPF_F_NO_PREALLOC);
	__type(key, int);
	__type(value, struct map_value);
} inode_ls_map SEC(".maps");

struct sk_ls_map {
	__uint(type, BPF_MAP_TYPE_SK_STORAGE);
	__uint(map_flags, BPF_F_NO_PREALLOC);
	__type(key, int);
	__type(value, struct map_value);
} sk_ls_map SEC(".maps");
96 
/* Declare an outer map-in-map with a single slot statically initialized to
 * @inner_map_type; libbpf also uses that inner map as the value template for
 * the outer map.
 */
#define DEFINE_MAP_OF_MAP(map_type, inner_map_type, name)       \
	struct {                                                \
		__uint(type, map_type);                         \
		__uint(max_entries, 1);                         \
		__uint(key_size, sizeof(int));                  \
		__uint(value_size, sizeof(int));                \
		__array(values, struct inner_map_type);         \
	} name SEC(".maps") = {                                 \
		.values = { [0] = &inner_map_type },            \
	}
107 
/* Outer maps wrapping each non-percpu inner map flavor, so kptr handling is
 * also verified for values reached through map-in-map lookups.
 */
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_ARRAY_OF_MAPS, array_map, array_of_array_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_ARRAY_OF_MAPS, hash_map, array_of_hash_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_ARRAY_OF_MAPS, hash_malloc_map, array_of_hash_malloc_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_ARRAY_OF_MAPS, lru_hash_map, array_of_lru_hash_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_HASH_OF_MAPS, array_map, hash_of_array_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_HASH_OF_MAPS, hash_map, hash_of_hash_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_HASH_OF_MAPS, hash_malloc_map, hash_of_hash_malloc_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_HASH_OF_MAPS, lru_hash_map, hash_of_lru_hash_maps);
116 
117 extern struct prog_test_ref_kfunc *bpf_kfunc_call_test_acquire(unsigned long *sp) __ksym;
118 extern void bpf_kfunc_call_test_release(struct prog_test_ref_kfunc *p) __ksym;
119 void bpf_kfunc_call_test_ref(struct prog_test_ref_kfunc *p) __ksym;
120 
121 #define WRITE_ONCE(x, val) ((*(volatile typeof(x) *) &(x)) = (val))
122 
/* Exercise the __kptr_untrusted field: plain loads and WRITE_ONCE stores of
 * the untrusted pointer (NULL-able and non-NULL forms) must be accepted, and
 * field reads are only performed after the NULL check.
 */
static void test_kptr_unref(struct map_value *v)
{
	struct prog_test_ref_kfunc *p;

	p = v->unref_ptr;
	/* store untrusted_ptr_or_null_ */
	WRITE_ONCE(v->unref_ptr, p);
	if (!p)
		return;
	/* reads through the untrusted pointer are allowed after the NULL check */
	if (p->a + p->b > 100)
		return;
	/* store untrusted_ptr_ */
	WRITE_ONCE(v->unref_ptr, p);
	/* store NULL */
	WRITE_ONCE(v->unref_ptr, NULL);
}
139 
/* Exercise the referenced __kptr field: plain loads yield RCU-protected
 * pointers, ownership moves in and out of the map slot only through
 * bpf_kptr_xchg(), and every reference taken out is released (or stored
 * back) on all paths.
 */
static void test_kptr_ref(struct map_value *v)
{
	struct prog_test_ref_kfunc *p;

	p = v->ref_ptr;
	/* store ptr_or_null_ */
	WRITE_ONCE(v->unref_ptr, p);
	if (!p)
		return;
	/*
	 * p is rcu_ptr_prog_test_ref_kfunc,
	 * because bpf prog is non-sleepable and runs in RCU CS.
	 * p can be passed to kfunc that requires KF_RCU.
	 */
	bpf_kfunc_call_test_ref(p);
	if (p->a + p->b > 100)
		return;
	/* store NULL, taking ownership of the reference held in the map */
	p = bpf_kptr_xchg(&v->ref_ptr, NULL);
	if (!p)
		return;
	/*
	 * p is trusted_ptr_prog_test_ref_kfunc.
	 * p can be passed to kfunc that requires KF_RCU.
	 */
	bpf_kfunc_call_test_ref(p);
	if (p->a + p->b > 100) {
		bpf_kfunc_call_test_release(p);
		return;
	}
	/* store ptr_ */
	WRITE_ONCE(v->unref_ptr, p);
	bpf_kfunc_call_test_release(p);

	/* acquire a fresh reference and move it into the now-empty slot */
	p = bpf_kfunc_call_test_acquire(&(unsigned long){0});
	if (!p)
		return;
	/* store ptr_; p becomes whatever the xchg displaced from the slot */
	p = bpf_kptr_xchg(&v->ref_ptr, p);
	if (!p)
		return;
	if (p->a + p->b > 100) {
		bpf_kfunc_call_test_release(p);
		return;
	}
	bpf_kfunc_call_test_release(p);
}
187 
/* Run both the untrusted and the referenced kptr tests on one map value. */
static void test_kptr(struct map_value *v)
{
	test_kptr_unref(v);
	test_kptr_ref(v);
}
193 
/* Look up element 0 in each plain map flavor and run the kptr tests on it. */
SEC("tc")
int test_map_kptr(struct __sk_buff *ctx)
{
	struct map_value *v;
	int key = 0;

/* Bail out (successfully) if the lookup fails, else run test_kptr(). */
#define TEST(map)					\
	v = bpf_map_lookup_elem(&map, &key);		\
	if (!v)						\
		return 0;				\
	test_kptr(v)

	TEST(array_map);
	TEST(hash_map);
	TEST(hash_malloc_map);
	TEST(lru_hash_map);

#undef TEST
	return 0;
}
214 
215 SEC("tp_btf/cgroup_mkdir")
216 int BPF_PROG(test_cgrp_map_kptr, struct cgroup *cgrp, const char *path)
217 {
218 	struct map_value *v;
219 
220 	v = bpf_cgrp_storage_get(&cgrp_ls_map, cgrp, NULL, BPF_LOCAL_STORAGE_GET_F_CREATE);
221 	if (v)
222 		test_kptr(v);
223 	return 0;
224 }
225 
226 SEC("lsm/inode_unlink")
227 int BPF_PROG(test_task_map_kptr, struct inode *inode, struct dentry *victim)
228 {
229 	struct task_struct *task;
230 	struct map_value *v;
231 
232 	task = bpf_get_current_task_btf();
233 	if (!task)
234 		return 0;
235 	v = bpf_task_storage_get(&task_ls_map, task, NULL, BPF_LOCAL_STORAGE_GET_F_CREATE);
236 	if (v)
237 		test_kptr(v);
238 	return 0;
239 }
240 
241 SEC("lsm/inode_unlink")
242 int BPF_PROG(test_inode_map_kptr, struct inode *inode, struct dentry *victim)
243 {
244 	struct map_value *v;
245 
246 	v = bpf_inode_storage_get(&inode_ls_map, inode, NULL, BPF_LOCAL_STORAGE_GET_F_CREATE);
247 	if (v)
248 		test_kptr(v);
249 	return 0;
250 }
251 
252 SEC("tc")
253 int test_sk_map_kptr(struct __sk_buff *ctx)
254 {
255 	struct map_value *v;
256 	struct bpf_sock *sk;
257 
258 	sk = ctx->sk;
259 	if (!sk)
260 		return 0;
261 	v = bpf_sk_storage_get(&sk_ls_map, sk, NULL, BPF_LOCAL_STORAGE_GET_F_CREATE);
262 	if (v)
263 		test_kptr(v);
264 	return 0;
265 }
266 
/* Look up the inner map at slot 0 of each outer map, then element 0 of that
 * inner map, and run the kptr tests on the resulting value.
 */
SEC("tc")
int test_map_in_map_kptr(struct __sk_buff *ctx)
{
	struct map_value *v;
	int key = 0;
	void *map;

/* Bail out (successfully) on any failed lookup, else run test_kptr(). */
#define TEST(map_in_map)                                \
	map = bpf_map_lookup_elem(&map_in_map, &key);   \
	if (!map)                                       \
		return 0;                               \
	v = bpf_map_lookup_elem(map, &key);		\
	if (!v)						\
		return 0;				\
	test_kptr(v)

	TEST(array_of_array_maps);
	TEST(array_of_hash_maps);
	TEST(array_of_hash_malloc_maps);
	TEST(array_of_lru_hash_maps);
	TEST(hash_of_array_maps);
	TEST(hash_of_hash_maps);
	TEST(hash_of_hash_malloc_maps);
	TEST(hash_of_lru_hash_maps);

#undef TEST
	return 0;
}
295 
296 int ref = 1;
297 
/* Phase one of the refcount test: acquire a reference to the shared test
 * object, move it into and out of the map slot with bpf_kptr_xchg(), and
 * compare the object's live refcount against the global 'ref' expectation
 * after every step. On success a freshly acquired reference is intentionally
 * left stored in the map for test_map_kptr_ref_post() to find.
 * Returns 0 on success, or a distinct step number (1-9) on failure.
 */
static __always_inline
int test_map_kptr_ref_pre(struct map_value *v)
{
	struct prog_test_ref_kfunc *p, *p_st;
	unsigned long arg = 0;
	int ret;

	p = bpf_kfunc_call_test_acquire(&arg);
	if (!p)
		return 1;
	ref++;

	/* p_st is used to observe the refcount after p's ownership moves */
	p_st = p->next;
	if (p_st->cnt.refs.counter != ref) {
		ret = 2;
		goto end;
	}

	/* move our reference into the map; the old slot value must be NULL */
	p = bpf_kptr_xchg(&v->ref_ptr, p);
	if (p) {
		ret = 3;
		goto end;
	}
	if (p_st->cnt.refs.counter != ref)
		return 4;

	/* take the reference back out of the map and drop it */
	p = bpf_kptr_xchg(&v->ref_ptr, NULL);
	if (!p)
		return 5;
	bpf_kfunc_call_test_release(p);
	ref--;
	if (p_st->cnt.refs.counter != ref)
		return 6;

	/* acquire again and store the reference for the "post" phase */
	p = bpf_kfunc_call_test_acquire(&arg);
	if (!p)
		return 7;
	ref++;
	p = bpf_kptr_xchg(&v->ref_ptr, p);
	if (p) {
		ret = 8;
		goto end;
	}
	if (p_st->cnt.refs.counter != ref)
		return 9;
	/* Leave in map */

	return 0;
end:
	ref--;
	bpf_kfunc_call_test_release(p);
	return ret;
}
351 
/* Phase two: verify the reference left in the map by the "pre" phase, pull
 * it out and store it back, checking that the live refcount matches the
 * global 'ref' expectation throughout. The reference stays in the map on
 * success. Returns 0 on success, or a distinct step number (1-5) on failure.
 */
static __always_inline
int test_map_kptr_ref_post(struct map_value *v)
{
	struct prog_test_ref_kfunc *p, *p_st;

	p_st = v->ref_ptr;
	if (!p_st || p_st->cnt.refs.counter != ref)
		return 1;

	/* take ownership of the stored reference */
	p = bpf_kptr_xchg(&v->ref_ptr, NULL);
	if (!p)
		return 2;
	if (p_st->cnt.refs.counter != ref) {
		bpf_kfunc_call_test_release(p);
		return 3;
	}

	/* put it back; the slot must have remained empty in between */
	p = bpf_kptr_xchg(&v->ref_ptr, p);
	if (p) {
		bpf_kfunc_call_test_release(p);
		return 4;
	}
	if (p_st->cnt.refs.counter != ref)
		return 5;

	return 0;
}
379 
/* Look up element 0 and run the "pre" phase, propagating its error code. */
#define TEST(map)                            \
	v = bpf_map_lookup_elem(&map, &key); \
	if (!v)                              \
		return -1;                   \
	ret = test_map_kptr_ref_pre(v);      \
	if (ret)                             \
		return ret;

/* Same, for the per-cpu maps: checks CPU 0's copy of element 0. */
#define TEST_PCPU(map)                                 \
	v = bpf_map_lookup_percpu_elem(&map, &key, 0); \
	if (!v)                                        \
		return -1;                             \
	ret = test_map_kptr_ref_pre(v);                \
	if (ret)                                       \
		return ret;
395 
/* Seed element 0 of every hash-type map (array maps already contain their
 * elements), then run the "pre" refcount phase on each flavor, including the
 * per-cpu variants.
 */
SEC("tc")
int test_map_kptr_ref1(struct __sk_buff *ctx)
{
	struct map_value *v, val = {};
	int key = 0, ret;

	bpf_map_update_elem(&hash_map, &key, &val, 0);
	bpf_map_update_elem(&hash_malloc_map, &key, &val, 0);
	bpf_map_update_elem(&lru_hash_map, &key, &val, 0);

	bpf_map_update_elem(&pcpu_hash_map, &key, &val, 0);
	bpf_map_update_elem(&pcpu_hash_malloc_map, &key, &val, 0);
	bpf_map_update_elem(&lru_pcpu_hash_map, &key, &val, 0);

	TEST(array_map);
	TEST(hash_map);
	TEST(hash_malloc_map);
	TEST(lru_hash_map);

	TEST_PCPU(pcpu_array_map);
	TEST_PCPU(pcpu_hash_map);
	TEST_PCPU(pcpu_hash_malloc_map);
	TEST_PCPU(lru_pcpu_hash_map);

	return 0;
}
422 
#undef TEST
#undef TEST_PCPU

/* Look up element 0 and run the "post" phase, propagating its error code. */
#define TEST(map)                            \
	v = bpf_map_lookup_elem(&map, &key); \
	if (!v)                              \
		return -1;                   \
	ret = test_map_kptr_ref_post(v);     \
	if (ret)                             \
		return ret;

/* Same, for the per-cpu maps: checks CPU 0's copy of element 0. */
#define TEST_PCPU(map)                                 \
	v = bpf_map_lookup_percpu_elem(&map, &key, 0); \
	if (!v)                                        \
		return -1;                             \
	ret = test_map_kptr_ref_post(v);               \
	if (ret)                                       \
		return ret;
441 
/* Run the "post" refcount phase on every map flavor that was populated by
 * test_map_kptr_ref1.
 */
SEC("tc")
int test_map_kptr_ref2(struct __sk_buff *ctx)
{
	struct map_value *v;
	int key = 0, ret;

	TEST(array_map);
	TEST(hash_map);
	TEST(hash_malloc_map);
	TEST(lru_hash_map);

	TEST_PCPU(pcpu_array_map);
	TEST_PCPU(pcpu_hash_map);
	TEST_PCPU(pcpu_hash_malloc_map);
	TEST_PCPU(lru_pcpu_hash_map);

	return 0;
}

#undef TEST
#undef TEST_PCPU
463 
464 SEC("tc")
465 int test_map_kptr_ref3(struct __sk_buff *ctx)
466 {
467 	struct prog_test_ref_kfunc *p;
468 	unsigned long sp = 0;
469 
470 	p = bpf_kfunc_call_test_acquire(&sp);
471 	if (!p)
472 		return 1;
473 	ref++;
474 	if (p->cnt.refs.counter != ref) {
475 		bpf_kfunc_call_test_release(p);
476 		return 2;
477 	}
478 	bpf_kfunc_call_test_release(p);
479 	ref--;
480 	return 0;
481 }
482 
483 SEC("syscall")
484 int test_ls_map_kptr_ref1(void *ctx)
485 {
486 	struct task_struct *current;
487 	struct map_value *v;
488 
489 	current = bpf_get_current_task_btf();
490 	if (!current)
491 		return 100;
492 	v = bpf_task_storage_get(&task_ls_map, current, NULL, 0);
493 	if (v)
494 		return 150;
495 	v = bpf_task_storage_get(&task_ls_map, current, NULL, BPF_LOCAL_STORAGE_GET_F_CREATE);
496 	if (!v)
497 		return 200;
498 	return test_map_kptr_ref_pre(v);
499 }
500 
501 SEC("syscall")
502 int test_ls_map_kptr_ref2(void *ctx)
503 {
504 	struct task_struct *current;
505 	struct map_value *v;
506 
507 	current = bpf_get_current_task_btf();
508 	if (!current)
509 		return 100;
510 	v = bpf_task_storage_get(&task_ls_map, current, NULL, 0);
511 	if (!v)
512 		return 200;
513 	return test_map_kptr_ref_post(v);
514 }
515 
516 SEC("syscall")
517 int test_ls_map_kptr_ref_del(void *ctx)
518 {
519 	struct task_struct *current;
520 	struct map_value *v;
521 
522 	current = bpf_get_current_task_btf();
523 	if (!current)
524 		return 100;
525 	v = bpf_task_storage_get(&task_ls_map, current, NULL, 0);
526 	if (!v)
527 		return 200;
528 	if (!v->ref_ptr)
529 		return 300;
530 	return bpf_task_storage_delete(&task_ls_map, current);
531 }
532 
533 char _license[] SEC("license") = "GPL";
534