// SPDX-License-Identifier: GPL-2.0
#include <vmlinux.h>
#include <bpf/bpf_tracing.h>
#include <bpf/bpf_helpers.h>

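/*
 * Value layout shared by every map in this test. A map value may embed
 * kernel pointers ("kptrs") once the fields are tagged in BTF:
 *
 *   - __kptr_untrusted: an unreferenced kptr. Loads and stores are plain
 *     memory accesses; the loaded pointer is untrusted and may only be
 *     read through, never passed to helpers that need a trusted pointer.
 *   - __kptr: a referenced kptr. The slot owns a reference and may only
 *     be updated with bpf_kptr_xchg(), which transfers ownership.
 *
 * A minimal sketch of the usual referenced-kptr pattern, using this
 * file's test kfuncs:
 *
 *	p = bpf_kfunc_call_test_acquire(&arg);	// take a reference
 *	p = bpf_kptr_xchg(&v->ref_ptr, p);	// move it into the map
 *	if (p)					// any displaced old value
 *		bpf_kfunc_call_test_release(p);	// is now owned by us
 */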
struct map_value {
	struct prog_test_ref_kfunc __kptr_untrusted *unref_ptr;
	struct prog_test_ref_kfunc __kptr *ref_ptr;
};

struct array_map {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
} array_map SEC(".maps");

struct pcpu_array_map {
	__uint(type, BPF_MAP_TYPE_PERCPU_ARRAY);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
} pcpu_array_map SEC(".maps");

struct hash_map {
	__uint(type, BPF_MAP_TYPE_HASH);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
} hash_map SEC(".maps");

struct pcpu_hash_map {
	__uint(type, BPF_MAP_TYPE_PERCPU_HASH);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
} pcpu_hash_map SEC(".maps");

struct hash_malloc_map {
	__uint(type, BPF_MAP_TYPE_HASH);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
	__uint(map_flags, BPF_F_NO_PREALLOC);
} hash_malloc_map SEC(".maps");

struct pcpu_hash_malloc_map {
	__uint(type, BPF_MAP_TYPE_PERCPU_HASH);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
	__uint(map_flags, BPF_F_NO_PREALLOC);
} pcpu_hash_malloc_map SEC(".maps");

struct lru_hash_map {
	__uint(type, BPF_MAP_TYPE_LRU_HASH);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
} lru_hash_map SEC(".maps");

struct lru_pcpu_hash_map {
	__uint(type, BPF_MAP_TYPE_LRU_PERCPU_HASH);
	__type(key, int);
	__type(value, struct map_value);
	__uint(max_entries, 1);
} lru_pcpu_hash_map SEC(".maps");

struct cgrp_ls_map {
	__uint(type, BPF_MAP_TYPE_CGRP_STORAGE);
	__uint(map_flags, BPF_F_NO_PREALLOC);
	__type(key, int);
	__type(value, struct map_value);
} cgrp_ls_map SEC(".maps");

struct task_ls_map {
	__uint(type, BPF_MAP_TYPE_TASK_STORAGE);
	__uint(map_flags, BPF_F_NO_PREALLOC);
	__type(key, int);
	__type(value, struct map_value);
} task_ls_map SEC(".maps");

struct inode_ls_map {
	__uint(type, BPF_MAP_TYPE_INODE_STORAGE);
	__uint(map_flags, BPF_F_NO_PREALLOC);
	__type(key, int);
	__type(value, struct map_value);
} inode_ls_map SEC(".maps");

struct sk_ls_map {
	__uint(type, BPF_MAP_TYPE_SK_STORAGE);
	__uint(map_flags, BPF_F_NO_PREALLOC);
	__type(key, int);
	__type(value, struct map_value);
} sk_ls_map SEC(".maps");

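/*
 * Stamp out outer map-in-map definitions. The __array() member pins the
 * inner map's definition as the outer map's value type, and the static
 * initializer places &inner_map_type at index 0, so lookups through the
 * outer map still see the kptr fields of struct map_value.
 */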
#define DEFINE_MAP_OF_MAP(map_type, inner_map_type, name)       \
	struct {                                                \
		__uint(type, map_type);                         \
		__uint(max_entries, 1);                         \
		__uint(key_size, sizeof(int));                  \
		__uint(value_size, sizeof(int));                \
		__array(values, struct inner_map_type);         \
	} name SEC(".maps") = {                                 \
		.values = { [0] = &inner_map_type },            \
	}

DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_ARRAY_OF_MAPS, array_map, array_of_array_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_ARRAY_OF_MAPS, hash_map, array_of_hash_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_ARRAY_OF_MAPS, hash_malloc_map, array_of_hash_malloc_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_ARRAY_OF_MAPS, lru_hash_map, array_of_lru_hash_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_HASH_OF_MAPS, array_map, hash_of_array_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_HASH_OF_MAPS, hash_map, hash_of_hash_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_HASH_OF_MAPS, hash_malloc_map, hash_of_hash_malloc_maps);
DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_HASH_OF_MAPS, lru_hash_map, hash_of_lru_hash_maps);

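/*
 * Kfuncs exercised below, exported for selftests by the kernel's BPF
 * test infrastructure: _acquire() returns a refcounted test object (or
 * NULL), _release() drops a reference, and _kptr_get() takes a pointer
 * to a referenced kptr slot and acquires an additional reference to the
 * object stored there without emptying the slot.
 */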
extern struct prog_test_ref_kfunc *bpf_kfunc_call_test_acquire(unsigned long *sp) __ksym;
extern struct prog_test_ref_kfunc *
bpf_kfunc_call_test_kptr_get(struct prog_test_ref_kfunc **p, int a, int b) __ksym;
extern void bpf_kfunc_call_test_release(struct prog_test_ref_kfunc *p) __ksym;
void bpf_kfunc_call_test_ref(struct prog_test_ref_kfunc *p) __ksym;

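/*
 * Unreferenced kptr fields can be written with ordinary stores, but the
 * slot may be read concurrently by other programs, so use a volatile
 * store to keep the compiler from tearing or duplicating it.
 */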
#define WRITE_ONCE(x, val) ((*(volatile typeof(x) *) &(x)) = (val))

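/* Load, test and store back an unreferenced (untrusted) kptr. */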
static void test_kptr_unref(struct map_value *v)
{
	struct prog_test_ref_kfunc *p;

	p = v->unref_ptr;
	/* store untrusted_ptr_or_null_ */
	WRITE_ONCE(v->unref_ptr, p);
	if (!p)
		return;
	if (p->a + p->b > 100)
		return;
	/* store untrusted_ptr_ */
	WRITE_ONCE(v->unref_ptr, p);
	/* store NULL */
	WRITE_ONCE(v->unref_ptr, NULL);
}

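/*
 * Exercise the referenced kptr: read it (yielding an RCU-protected
 * pointer), swap it out and back in with bpf_kptr_xchg(), and release
 * every reference that ends up owned by the program.
 */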
static void test_kptr_ref(struct map_value *v)
{
	struct prog_test_ref_kfunc *p;

	p = v->ref_ptr;
	/* store ptr_or_null_ */
	WRITE_ONCE(v->unref_ptr, p);
	if (!p)
		return;
	/*
	 * p is rcu_ptr_prog_test_ref_kfunc because this BPF program is
	 * non-sleepable and therefore runs in an RCU critical section,
	 * so p can be passed to kfuncs that require KF_RCU arguments.
	 */
	bpf_kfunc_call_test_ref(p);
	if (p->a + p->b > 100)
		return;
	/* store NULL */
	p = bpf_kptr_xchg(&v->ref_ptr, NULL);
	if (!p)
		return;
	/*
	 * p is now trusted_ptr_prog_test_ref_kfunc: the xchg transferred
	 * ownership of the reference to the program. A trusted pointer
	 * can also be passed to kfuncs that require KF_RCU arguments.
	 */
	bpf_kfunc_call_test_ref(p);
	if (p->a + p->b > 100) {
		bpf_kfunc_call_test_release(p);
		return;
	}
	/* store ptr_ */
	WRITE_ONCE(v->unref_ptr, p);
	bpf_kfunc_call_test_release(p);

	p = bpf_kfunc_call_test_acquire(&(unsigned long){0});
	if (!p)
		return;
	/* store ptr_ */
	p = bpf_kptr_xchg(&v->ref_ptr, p);
	if (!p)
		return;
	if (p->a + p->b > 100) {
		bpf_kfunc_call_test_release(p);
		return;
	}
	bpf_kfunc_call_test_release(p);
}

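/* Acquire an extra reference from the kptr slot, then drop it. */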
static void test_kptr_get(struct map_value *v)
{
	struct prog_test_ref_kfunc *p;

	p = bpf_kfunc_call_test_kptr_get(&v->ref_ptr, 0, 0);
	if (!p)
		return;
	if (p->a + p->b > 100) {
		bpf_kfunc_call_test_release(p);
		return;
	}
	bpf_kfunc_call_test_release(p);
}

static void test_kptr(struct map_value *v)
{
	test_kptr_unref(v);
	test_kptr_ref(v);
	test_kptr_get(v);
}

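/*
 * The programs below just drive test_kptr() from different program and
 * map types: plain tc-attached maps here, cgroup/task/inode/socket
 * local storage and the map-in-map variants further down.
 */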
SEC("tc")
int test_map_kptr(struct __sk_buff *ctx)
{
	struct map_value *v;
	int key = 0;

#define TEST(map)					\
	v = bpf_map_lookup_elem(&map, &key);		\
	if (!v)						\
		return 0;				\
	test_kptr(v)

	TEST(array_map);
	TEST(hash_map);
	TEST(hash_malloc_map);
	TEST(lru_hash_map);

#undef TEST
	return 0;
}

SEC("tp_btf/cgroup_mkdir")
int BPF_PROG(test_cgrp_map_kptr, struct cgroup *cgrp, const char *path)
{
	struct map_value *v;

	v = bpf_cgrp_storage_get(&cgrp_ls_map, cgrp, NULL, BPF_LOCAL_STORAGE_GET_F_CREATE);
	if (v)
		test_kptr(v);
	return 0;
}

SEC("lsm/inode_unlink")
int BPF_PROG(test_task_map_kptr, struct inode *inode, struct dentry *victim)
{
	struct task_struct *task;
	struct map_value *v;

	task = bpf_get_current_task_btf();
	if (!task)
		return 0;
	v = bpf_task_storage_get(&task_ls_map, task, NULL, BPF_LOCAL_STORAGE_GET_F_CREATE);
	if (v)
		test_kptr(v);
	return 0;
}

SEC("lsm/inode_unlink")
int BPF_PROG(test_inode_map_kptr, struct inode *inode, struct dentry *victim)
{
	struct map_value *v;

	v = bpf_inode_storage_get(&inode_ls_map, inode, NULL, BPF_LOCAL_STORAGE_GET_F_CREATE);
	if (v)
		test_kptr(v);
	return 0;
}

SEC("tc")
int test_sk_map_kptr(struct __sk_buff *ctx)
{
	struct map_value *v;
	struct bpf_sock *sk;

	sk = ctx->sk;
	if (!sk)
		return 0;
	v = bpf_sk_storage_get(&sk_ls_map, sk, NULL, BPF_LOCAL_STORAGE_GET_F_CREATE);
	if (v)
		test_kptr(v);
	return 0;
}

SEC("tc")
int test_map_in_map_kptr(struct __sk_buff *ctx)
{
	struct map_value *v;
	int key = 0;
	void *map;

#define TEST(map_in_map)                                \
	map = bpf_map_lookup_elem(&map_in_map, &key);   \
	if (!map)                                       \
		return 0;                               \
	v = bpf_map_lookup_elem(map, &key);		\
	if (!v)						\
		return 0;				\
	test_kptr(v)

	TEST(array_of_array_maps);
	TEST(array_of_hash_maps);
	TEST(array_of_hash_malloc_maps);
	TEST(array_of_lru_hash_maps);
	TEST(hash_of_array_maps);
	TEST(hash_of_hash_maps);
	TEST(hash_of_hash_malloc_maps);
	TEST(hash_of_lru_hash_maps);

#undef TEST
	return 0;
}

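/*
 * Shadow of the shared test object's expected refcount. It starts at 1
 * because (per the test kfuncs' contract) the object exists with a
 * single base reference; every acquire below bumps this counter and
 * every release drops it, so it can be compared against
 * p->cnt.refs.counter after each step.
 */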
int ref = 1;

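/*
 * First half of the refcount test: acquire/xchg/get/release against the
 * given map value, checking the object's refcount after each step.
 * Returns the failing step number (non-zero) on mismatch and leaves one
 * reference stored in the map for test_map_kptr_ref_post() to find.
 */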
static __always_inline
int test_map_kptr_ref_pre(struct map_value *v)
{
	struct prog_test_ref_kfunc *p, *p_st;
	unsigned long arg = 0;
	int ret;

	p = bpf_kfunc_call_test_acquire(&arg);
	if (!p)
		return 1;
	ref++;

	p_st = p->next;
	if (p_st->cnt.refs.counter != ref) {
		ret = 2;
		goto end;
	}

	p = bpf_kptr_xchg(&v->ref_ptr, p);
	if (p) {
		ret = 3;
		goto end;
	}
	if (p_st->cnt.refs.counter != ref)
		return 4;

	p = bpf_kfunc_call_test_kptr_get(&v->ref_ptr, 0, 0);
	if (!p)
		return 5;
	ref++;
	if (p_st->cnt.refs.counter != ref) {
		ret = 6;
		goto end;
	}
	bpf_kfunc_call_test_release(p);
	ref--;
	if (p_st->cnt.refs.counter != ref)
		return 7;

	p = bpf_kptr_xchg(&v->ref_ptr, NULL);
	if (!p)
		return 8;
	bpf_kfunc_call_test_release(p);
	ref--;
	if (p_st->cnt.refs.counter != ref)
		return 9;

	p = bpf_kfunc_call_test_acquire(&arg);
	if (!p)
		return 10;
	ref++;
	p = bpf_kptr_xchg(&v->ref_ptr, p);
	if (p) {
		ret = 11;
		goto end;
	}
	if (p_st->cnt.refs.counter != ref)
		return 12;
	/* Leave in map */

	return 0;
end:
	ref--;
	bpf_kfunc_call_test_release(p);
	return ret;
}

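/*
 * Second half: verify the reference left behind by
 * test_map_kptr_ref_pre() is still in the map with the expected count,
 * then move it out and back in without disturbing the refcount.
 */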
static __always_inline
int test_map_kptr_ref_post(struct map_value *v)
{
	struct prog_test_ref_kfunc *p, *p_st;

	p_st = v->ref_ptr;
	if (!p_st || p_st->cnt.refs.counter != ref)
		return 1;

	p = bpf_kptr_xchg(&v->ref_ptr, NULL);
	if (!p)
		return 2;
	if (p_st->cnt.refs.counter != ref) {
		bpf_kfunc_call_test_release(p);
		return 3;
	}

	p = bpf_kptr_xchg(&v->ref_ptr, p);
	if (p) {
		bpf_kfunc_call_test_release(p);
		return 4;
	}
	if (p_st->cnt.refs.counter != ref)
		return 5;

	return 0;
}

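/*
 * test_map_kptr_ref1 seeds every map flavor (per-CPU maps via
 * bpf_map_lookup_percpu_elem() on CPU 0) and is expected to run before
 * test_map_kptr_ref2, which re-checks the references left in the maps.
 */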
#define TEST(map)                            \
	v = bpf_map_lookup_elem(&map, &key); \
	if (!v)                              \
		return -1;                   \
	ret = test_map_kptr_ref_pre(v);      \
	if (ret)                             \
		return ret;

#define TEST_PCPU(map)                                 \
	v = bpf_map_lookup_percpu_elem(&map, &key, 0); \
	if (!v)                                        \
		return -1;                             \
	ret = test_map_kptr_ref_pre(v);                \
	if (ret)                                       \
		return ret;

SEC("tc")
int test_map_kptr_ref1(struct __sk_buff *ctx)
{
	struct map_value *v, val = {};
	int key = 0, ret;

	bpf_map_update_elem(&hash_map, &key, &val, 0);
	bpf_map_update_elem(&hash_malloc_map, &key, &val, 0);
	bpf_map_update_elem(&lru_hash_map, &key, &val, 0);

	bpf_map_update_elem(&pcpu_hash_map, &key, &val, 0);
	bpf_map_update_elem(&pcpu_hash_malloc_map, &key, &val, 0);
	bpf_map_update_elem(&lru_pcpu_hash_map, &key, &val, 0);

	TEST(array_map);
	TEST(hash_map);
	TEST(hash_malloc_map);
	TEST(lru_hash_map);

	TEST_PCPU(pcpu_array_map);
	TEST_PCPU(pcpu_hash_map);
	TEST_PCPU(pcpu_hash_malloc_map);
	TEST_PCPU(lru_pcpu_hash_map);

	return 0;
}

#undef TEST
#undef TEST_PCPU

#define TEST(map)                            \
	v = bpf_map_lookup_elem(&map, &key); \
	if (!v)                              \
		return -1;                   \
	ret = test_map_kptr_ref_post(v);     \
	if (ret)                             \
		return ret;

#define TEST_PCPU(map)                                 \
	v = bpf_map_lookup_percpu_elem(&map, &key, 0); \
	if (!v)                                        \
		return -1;                             \
	ret = test_map_kptr_ref_post(v);               \
	if (ret)                                       \
		return ret;

SEC("tc")
int test_map_kptr_ref2(struct __sk_buff *ctx)
{
	struct map_value *v;
	int key = 0, ret;

	TEST(array_map);
	TEST(hash_map);
	TEST(hash_malloc_map);
	TEST(lru_hash_map);

	TEST_PCPU(pcpu_array_map);
	TEST_PCPU(pcpu_hash_map);
	TEST_PCPU(pcpu_hash_malloc_map);
	TEST_PCPU(lru_pcpu_hash_map);

	return 0;
}

#undef TEST
#undef TEST_PCPU

SEC("tc")
int test_map_kptr_ref3(struct __sk_buff *ctx)
{
	struct prog_test_ref_kfunc *p;
	unsigned long sp = 0;

	p = bpf_kfunc_call_test_acquire(&sp);
	if (!p)
		return 1;
	ref++;
	if (p->cnt.refs.counter != ref) {
		bpf_kfunc_call_test_release(p);
		return 2;
	}
	bpf_kfunc_call_test_release(p);
	ref--;
	return 0;
}

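/*
 * Local-storage variants, run as SEC("syscall") programs: ref1 creates
 * the task storage and leaves a reference stored in it, ref2
 * re-validates that reference, and ref_del deletes the entry, which
 * must release the stored kptr.
 */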
SEC("syscall")
int test_ls_map_kptr_ref1(void *ctx)
{
	struct task_struct *current;
	struct map_value *v;

	current = bpf_get_current_task_btf();
	if (!current)
		return 100;
	v = bpf_task_storage_get(&task_ls_map, current, NULL, 0);
	if (v)
		return 150;
	v = bpf_task_storage_get(&task_ls_map, current, NULL, BPF_LOCAL_STORAGE_GET_F_CREATE);
	if (!v)
		return 200;
	return test_map_kptr_ref_pre(v);
}

SEC("syscall")
int test_ls_map_kptr_ref2(void *ctx)
{
	struct task_struct *current;
	struct map_value *v;

	current = bpf_get_current_task_btf();
	if (!current)
		return 100;
	v = bpf_task_storage_get(&task_ls_map, current, NULL, 0);
	if (!v)
		return 200;
	return test_map_kptr_ref_post(v);
}

SEC("syscall")
int test_ls_map_kptr_ref_del(void *ctx)
{
	struct task_struct *current;
	struct map_value *v;

	current = bpf_get_current_task_btf();
	if (!current)
		return 100;
	v = bpf_task_storage_get(&task_ls_map, current, NULL, 0);
	if (!v)
		return 200;
	if (!v->ref_ptr)
		return 300;
	return bpf_task_storage_delete(&task_ls_map, current);
}

char _license[] SEC("license") = "GPL";