1 // SPDX-License-Identifier: GPL-2.0
2 #include <vmlinux.h>
3 #include <bpf/bpf_tracing.h>
4 #include <bpf/bpf_helpers.h>
5 
6 struct map_value {
7 	struct prog_test_ref_kfunc __kptr *unref_ptr;
8 	struct prog_test_ref_kfunc __kptr_ref *ref_ptr;
9 };
10 
11 struct array_map {
12 	__uint(type, BPF_MAP_TYPE_ARRAY);
13 	__type(key, int);
14 	__type(value, struct map_value);
15 	__uint(max_entries, 1);
16 } array_map SEC(".maps");
17 
18 struct hash_map {
19 	__uint(type, BPF_MAP_TYPE_HASH);
20 	__type(key, int);
21 	__type(value, struct map_value);
22 	__uint(max_entries, 1);
23 } hash_map SEC(".maps");
24 
25 struct hash_malloc_map {
26 	__uint(type, BPF_MAP_TYPE_HASH);
27 	__type(key, int);
28 	__type(value, struct map_value);
29 	__uint(max_entries, 1);
30 	__uint(map_flags, BPF_F_NO_PREALLOC);
31 } hash_malloc_map SEC(".maps");
32 
33 struct lru_hash_map {
34 	__uint(type, BPF_MAP_TYPE_LRU_HASH);
35 	__type(key, int);
36 	__type(value, struct map_value);
37 	__uint(max_entries, 1);
38 } lru_hash_map SEC(".maps");
39 
40 #define DEFINE_MAP_OF_MAP(map_type, inner_map_type, name)       \
41 	struct {                                                \
42 		__uint(type, map_type);                         \
43 		__uint(max_entries, 1);                         \
44 		__uint(key_size, sizeof(int));                  \
45 		__uint(value_size, sizeof(int));                \
46 		__array(values, struct inner_map_type);         \
47 	} name SEC(".maps") = {                                 \
48 		.values = { [0] = &inner_map_type },            \
49 	}
50 
51 DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_ARRAY_OF_MAPS, array_map, array_of_array_maps);
52 DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_ARRAY_OF_MAPS, hash_map, array_of_hash_maps);
53 DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_ARRAY_OF_MAPS, hash_malloc_map, array_of_hash_malloc_maps);
54 DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_ARRAY_OF_MAPS, lru_hash_map, array_of_lru_hash_maps);
55 DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_HASH_OF_MAPS, array_map, hash_of_array_maps);
56 DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_HASH_OF_MAPS, hash_map, hash_of_hash_maps);
57 DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_HASH_OF_MAPS, hash_malloc_map, hash_of_hash_malloc_maps);
58 DEFINE_MAP_OF_MAP(BPF_MAP_TYPE_HASH_OF_MAPS, lru_hash_map, hash_of_lru_hash_maps);
59 
60 extern struct prog_test_ref_kfunc *bpf_kfunc_call_test_acquire(unsigned long *sp) __ksym;
61 extern struct prog_test_ref_kfunc *
62 bpf_kfunc_call_test_kptr_get(struct prog_test_ref_kfunc **p, int a, int b) __ksym;
63 extern void bpf_kfunc_call_test_release(struct prog_test_ref_kfunc *p) __ksym;
64 
65 #define WRITE_ONCE(x, val) ((*(volatile typeof(x) *) &(x)) = (val))
66 
67 static void test_kptr_unref(struct map_value *v)
68 {
69 	struct prog_test_ref_kfunc *p;
70 
71 	p = v->unref_ptr;
72 	/* store untrusted_ptr_or_null_ */
73 	WRITE_ONCE(v->unref_ptr, p);
74 	if (!p)
75 		return;
76 	if (p->a + p->b > 100)
77 		return;
78 	/* store untrusted_ptr_ */
79 	WRITE_ONCE(v->unref_ptr, p);
80 	/* store NULL */
81 	WRITE_ONCE(v->unref_ptr, NULL);
82 }
83 
84 static void test_kptr_ref(struct map_value *v)
85 {
86 	struct prog_test_ref_kfunc *p;
87 
88 	p = v->ref_ptr;
89 	/* store ptr_or_null_ */
90 	WRITE_ONCE(v->unref_ptr, p);
91 	if (!p)
92 		return;
93 	if (p->a + p->b > 100)
94 		return;
95 	/* store NULL */
96 	p = bpf_kptr_xchg(&v->ref_ptr, NULL);
97 	if (!p)
98 		return;
99 	if (p->a + p->b > 100) {
100 		bpf_kfunc_call_test_release(p);
101 		return;
102 	}
103 	/* store ptr_ */
104 	WRITE_ONCE(v->unref_ptr, p);
105 	bpf_kfunc_call_test_release(p);
106 
107 	p = bpf_kfunc_call_test_acquire(&(unsigned long){0});
108 	if (!p)
109 		return;
110 	/* store ptr_ */
111 	p = bpf_kptr_xchg(&v->ref_ptr, p);
112 	if (!p)
113 		return;
114 	if (p->a + p->b > 100) {
115 		bpf_kfunc_call_test_release(p);
116 		return;
117 	}
118 	bpf_kfunc_call_test_release(p);
119 }
120 
121 static void test_kptr_get(struct map_value *v)
122 {
123 	struct prog_test_ref_kfunc *p;
124 
125 	p = bpf_kfunc_call_test_kptr_get(&v->ref_ptr, 0, 0);
126 	if (!p)
127 		return;
128 	if (p->a + p->b > 100) {
129 		bpf_kfunc_call_test_release(p);
130 		return;
131 	}
132 	bpf_kfunc_call_test_release(p);
133 }
134 
135 static void test_kptr(struct map_value *v)
136 {
137 	test_kptr_unref(v);
138 	test_kptr_ref(v);
139 	test_kptr_get(v);
140 }
141 
142 SEC("tc")
143 int test_map_kptr(struct __sk_buff *ctx)
144 {
145 	struct map_value *v;
146 	int key = 0;
147 
148 #define TEST(map)					\
149 	v = bpf_map_lookup_elem(&map, &key);		\
150 	if (!v)						\
151 		return 0;				\
152 	test_kptr(v)
153 
154 	TEST(array_map);
155 	TEST(hash_map);
156 	TEST(hash_malloc_map);
157 	TEST(lru_hash_map);
158 
159 #undef TEST
160 	return 0;
161 }
162 
163 SEC("tc")
164 int test_map_in_map_kptr(struct __sk_buff *ctx)
165 {
166 	struct map_value *v;
167 	int key = 0;
168 	void *map;
169 
170 #define TEST(map_in_map)                                \
171 	map = bpf_map_lookup_elem(&map_in_map, &key);   \
172 	if (!map)                                       \
173 		return 0;                               \
174 	v = bpf_map_lookup_elem(map, &key);		\
175 	if (!v)						\
176 		return 0;				\
177 	test_kptr(v)
178 
179 	TEST(array_of_array_maps);
180 	TEST(array_of_hash_maps);
181 	TEST(array_of_hash_malloc_maps);
182 	TEST(array_of_lru_hash_maps);
183 	TEST(hash_of_array_maps);
184 	TEST(hash_of_hash_maps);
185 	TEST(hash_of_hash_malloc_maps);
186 	TEST(hash_of_lru_hash_maps);
187 
188 #undef TEST
189 	return 0;
190 }
191 
192 SEC("tc")
193 int test_map_kptr_ref(struct __sk_buff *ctx)
194 {
195 	struct prog_test_ref_kfunc *p, *p_st;
196 	unsigned long arg = 0;
197 	struct map_value *v;
198 	int key = 0, ret;
199 
200 	p = bpf_kfunc_call_test_acquire(&arg);
201 	if (!p)
202 		return 1;
203 
204 	p_st = p->next;
205 	if (p_st->cnt.refs.counter != 2) {
206 		ret = 2;
207 		goto end;
208 	}
209 
210 	v = bpf_map_lookup_elem(&array_map, &key);
211 	if (!v) {
212 		ret = 3;
213 		goto end;
214 	}
215 
216 	p = bpf_kptr_xchg(&v->ref_ptr, p);
217 	if (p) {
218 		ret = 4;
219 		goto end;
220 	}
221 	if (p_st->cnt.refs.counter != 2)
222 		return 5;
223 
224 	p = bpf_kfunc_call_test_kptr_get(&v->ref_ptr, 0, 0);
225 	if (!p)
226 		return 6;
227 	if (p_st->cnt.refs.counter != 3) {
228 		ret = 7;
229 		goto end;
230 	}
231 	bpf_kfunc_call_test_release(p);
232 	if (p_st->cnt.refs.counter != 2)
233 		return 8;
234 
235 	p = bpf_kptr_xchg(&v->ref_ptr, NULL);
236 	if (!p)
237 		return 9;
238 	bpf_kfunc_call_test_release(p);
239 	if (p_st->cnt.refs.counter != 1)
240 		return 10;
241 
242 	p = bpf_kfunc_call_test_acquire(&arg);
243 	if (!p)
244 		return 11;
245 	p = bpf_kptr_xchg(&v->ref_ptr, p);
246 	if (p) {
247 		ret = 12;
248 		goto end;
249 	}
250 	if (p_st->cnt.refs.counter != 2)
251 		return 13;
252 	/* Leave in map */
253 
254 	return 0;
255 end:
256 	bpf_kfunc_call_test_release(p);
257 	return ret;
258 }
259 
260 SEC("tc")
261 int test_map_kptr_ref2(struct __sk_buff *ctx)
262 {
263 	struct prog_test_ref_kfunc *p, *p_st;
264 	struct map_value *v;
265 	int key = 0;
266 
267 	v = bpf_map_lookup_elem(&array_map, &key);
268 	if (!v)
269 		return 1;
270 
271 	p_st = v->ref_ptr;
272 	if (!p_st || p_st->cnt.refs.counter != 2)
273 		return 2;
274 
275 	p = bpf_kptr_xchg(&v->ref_ptr, NULL);
276 	if (!p)
277 		return 3;
278 	if (p_st->cnt.refs.counter != 2) {
279 		bpf_kfunc_call_test_release(p);
280 		return 4;
281 	}
282 
283 	p = bpf_kptr_xchg(&v->ref_ptr, p);
284 	if (p) {
285 		bpf_kfunc_call_test_release(p);
286 		return 5;
287 	}
288 	if (p_st->cnt.refs.counter != 2)
289 		return 6;
290 
291 	return 0;
292 }
293 
294 char _license[] SEC("license") = "GPL";
295