1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2023 Meta Platforms, Inc. and affiliates. */
3 
4 #include <stdbool.h>
5 #include <linux/bpf.h>
6 #include <bpf/bpf_helpers.h>
7 #include "bpf_misc.h"
8 
9 #define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0]))
10 
11 static volatile int zero = 0;
12 
13 int my_pid;
14 int arr[256];
15 int small_arr[16] SEC(".data.small_arr");
16 
17 #ifdef REAL_TEST
18 #define MY_PID_GUARD() if (my_pid != (bpf_get_current_pid_tgid() >> 32)) return 0
19 #else
20 #define MY_PID_GUARD() ({ })
21 #endif
22 
23 SEC("?raw_tp")
24 __failure __msg("math between map_value pointer and register with unbounded min value is not allowed")
25 int iter_err_unsafe_c_loop(const void *ctx)
26 {
27 	struct bpf_iter_num it;
28 	int *v, i = zero; /* obscure initial value of i */
29 
30 	MY_PID_GUARD();
31 
32 	bpf_iter_num_new(&it, 0, 1000);
33 	while ((v = bpf_iter_num_next(&it))) {
34 		i++;
35 	}
36 	bpf_iter_num_destroy(&it);
37 
38 	small_arr[i] = 123; /* invalid */
39 
40 	return 0;
41 }
42 
43 SEC("?raw_tp")
44 __failure __msg("unbounded memory access")
45 int iter_err_unsafe_asm_loop(const void *ctx)
46 {
47 	struct bpf_iter_num it;
48 
49 	MY_PID_GUARD();
50 
51 	asm volatile (
52 		"r6 = %[zero];" /* iteration counter */
53 		"r1 = %[it];" /* iterator state */
54 		"r2 = 0;"
55 		"r3 = 1000;"
56 		"r4 = 1;"
57 		"call %[bpf_iter_num_new];"
58 	"loop:"
59 		"r1 = %[it];"
60 		"call %[bpf_iter_num_next];"
61 		"if r0 == 0 goto out;"
62 		"r6 += 1;"
63 		"goto loop;"
64 	"out:"
65 		"r1 = %[it];"
66 		"call %[bpf_iter_num_destroy];"
67 		"r1 = %[small_arr];"
68 		"r2 = r6;"
69 		"r2 <<= 2;"
70 		"r1 += r2;"
71 		"*(u32 *)(r1 + 0) = r6;" /* invalid */
72 		:
73 		: [it]"r"(&it),
74 		  [small_arr]"p"(small_arr),
75 		  [zero]"p"(zero),
76 		  __imm(bpf_iter_num_new),
77 		  __imm(bpf_iter_num_next),
78 		  __imm(bpf_iter_num_destroy)
79 		: __clobber_common, "r6"
80 	);
81 
82 	return 0;
83 }
84 
85 SEC("raw_tp")
86 __success
87 int iter_while_loop(const void *ctx)
88 {
89 	struct bpf_iter_num it;
90 	int *v;
91 
92 	MY_PID_GUARD();
93 
94 	bpf_iter_num_new(&it, 0, 3);
95 	while ((v = bpf_iter_num_next(&it))) {
96 		bpf_printk("ITER_BASIC: E1 VAL: v=%d", *v);
97 	}
98 	bpf_iter_num_destroy(&it);
99 
100 	return 0;
101 }
102 
103 SEC("raw_tp")
104 __success
105 int iter_while_loop_auto_cleanup(const void *ctx)
106 {
107 	__attribute__((cleanup(bpf_iter_num_destroy))) struct bpf_iter_num it;
108 	int *v;
109 
110 	MY_PID_GUARD();
111 
112 	bpf_iter_num_new(&it, 0, 3);
113 	while ((v = bpf_iter_num_next(&it))) {
114 		bpf_printk("ITER_BASIC: E1 VAL: v=%d", *v);
115 	}
116 	/* (!) no explicit bpf_iter_num_destroy() */
117 
118 	return 0;
119 }
120 
121 SEC("raw_tp")
122 __success
123 int iter_for_loop(const void *ctx)
124 {
125 	struct bpf_iter_num it;
126 	int *v;
127 
128 	MY_PID_GUARD();
129 
130 	bpf_iter_num_new(&it, 5, 10);
131 	for (v = bpf_iter_num_next(&it); v; v = bpf_iter_num_next(&it)) {
132 		bpf_printk("ITER_BASIC: E2 VAL: v=%d", *v);
133 	}
134 	bpf_iter_num_destroy(&it);
135 
136 	return 0;
137 }
138 
139 SEC("raw_tp")
140 __success
141 int iter_bpf_for_each_macro(const void *ctx)
142 {
143 	int *v;
144 
145 	MY_PID_GUARD();
146 
147 	bpf_for_each(num, v, 5, 10) {
148 		bpf_printk("ITER_BASIC: E2 VAL: v=%d", *v);
149 	}
150 
151 	return 0;
152 }
153 
154 SEC("raw_tp")
155 __success
156 int iter_bpf_for_macro(const void *ctx)
157 {
158 	int i;
159 
160 	MY_PID_GUARD();
161 
162 	bpf_for(i, 5, 10) {
163 		bpf_printk("ITER_BASIC: E2 VAL: v=%d", i);
164 	}
165 
166 	return 0;
167 }
168 
169 SEC("raw_tp")
170 __success
171 int iter_pragma_unroll_loop(const void *ctx)
172 {
173 	struct bpf_iter_num it;
174 	int *v, i;
175 
176 	MY_PID_GUARD();
177 
178 	bpf_iter_num_new(&it, 0, 2);
179 #pragma nounroll
180 	for (i = 0; i < 3; i++) {
181 		v = bpf_iter_num_next(&it);
182 		bpf_printk("ITER_BASIC: E3 VAL: i=%d v=%d", i, v ? *v : -1);
183 	}
184 	bpf_iter_num_destroy(&it);
185 
186 	return 0;
187 }
188 
189 SEC("raw_tp")
190 __success
191 int iter_manual_unroll_loop(const void *ctx)
192 {
193 	struct bpf_iter_num it;
194 	int *v;
195 
196 	MY_PID_GUARD();
197 
198 	bpf_iter_num_new(&it, 100, 200);
199 	v = bpf_iter_num_next(&it);
200 	bpf_printk("ITER_BASIC: E4 VAL: v=%d", v ? *v : -1);
201 	v = bpf_iter_num_next(&it);
202 	bpf_printk("ITER_BASIC: E4 VAL: v=%d", v ? *v : -1);
203 	v = bpf_iter_num_next(&it);
204 	bpf_printk("ITER_BASIC: E4 VAL: v=%d", v ? *v : -1);
205 	v = bpf_iter_num_next(&it);
206 	bpf_printk("ITER_BASIC: E4 VAL: v=%d\n", v ? *v : -1);
207 	bpf_iter_num_destroy(&it);
208 
209 	return 0;
210 }
211 
212 SEC("raw_tp")
213 __success
214 int iter_multiple_sequential_loops(const void *ctx)
215 {
216 	struct bpf_iter_num it;
217 	int *v, i;
218 
219 	MY_PID_GUARD();
220 
221 	bpf_iter_num_new(&it, 0, 3);
222 	while ((v = bpf_iter_num_next(&it))) {
223 		bpf_printk("ITER_BASIC: E1 VAL: v=%d", *v);
224 	}
225 	bpf_iter_num_destroy(&it);
226 
227 	bpf_iter_num_new(&it, 5, 10);
228 	for (v = bpf_iter_num_next(&it); v; v = bpf_iter_num_next(&it)) {
229 		bpf_printk("ITER_BASIC: E2 VAL: v=%d", *v);
230 	}
231 	bpf_iter_num_destroy(&it);
232 
233 	bpf_iter_num_new(&it, 0, 2);
234 #pragma nounroll
235 	for (i = 0; i < 3; i++) {
236 		v = bpf_iter_num_next(&it);
237 		bpf_printk("ITER_BASIC: E3 VAL: i=%d v=%d", i, v ? *v : -1);
238 	}
239 	bpf_iter_num_destroy(&it);
240 
241 	bpf_iter_num_new(&it, 100, 200);
242 	v = bpf_iter_num_next(&it);
243 	bpf_printk("ITER_BASIC: E4 VAL: v=%d", v ? *v : -1);
244 	v = bpf_iter_num_next(&it);
245 	bpf_printk("ITER_BASIC: E4 VAL: v=%d", v ? *v : -1);
246 	v = bpf_iter_num_next(&it);
247 	bpf_printk("ITER_BASIC: E4 VAL: v=%d", v ? *v : -1);
248 	v = bpf_iter_num_next(&it);
249 	bpf_printk("ITER_BASIC: E4 VAL: v=%d\n", v ? *v : -1);
250 	bpf_iter_num_destroy(&it);
251 
252 	return 0;
253 }
254 
255 SEC("raw_tp")
256 __success
257 int iter_limit_cond_break_loop(const void *ctx)
258 {
259 	struct bpf_iter_num it;
260 	int *v, i = 0, sum = 0;
261 
262 	MY_PID_GUARD();
263 
264 	bpf_iter_num_new(&it, 0, 10);
265 	while ((v = bpf_iter_num_next(&it))) {
266 		bpf_printk("ITER_SIMPLE: i=%d v=%d", i, *v);
267 		sum += *v;
268 
269 		i++;
270 		if (i > 3)
271 			break;
272 	}
273 	bpf_iter_num_destroy(&it);
274 
275 	bpf_printk("ITER_SIMPLE: sum=%d\n", sum);
276 
277 	return 0;
278 }
279 
280 SEC("raw_tp")
281 __success
282 int iter_obfuscate_counter(const void *ctx)
283 {
284 	struct bpf_iter_num it;
285 	int *v, sum = 0;
286 	/* Make i's initial value unknowable for verifier to prevent it from
287 	 * pruning if/else branch inside the loop body and marking i as precise.
288 	 */
289 	int i = zero;
290 
291 	MY_PID_GUARD();
292 
293 	bpf_iter_num_new(&it, 0, 10);
294 	while ((v = bpf_iter_num_next(&it))) {
295 		int x;
296 
297 		i += 1;
298 
299 		/* If we initialized i as `int i = 0;` above, verifier would
300 		 * track that i becomes 1 on first iteration after increment
301 		 * above, and here verifier would eagerly prune else branch
302 		 * and mark i as precise, ruining open-coded iterator logic
303 		 * completely, as each next iteration would have a different
304 		 * *precise* value of i, and thus there would be no
305 		 * convergence of state. This would result in reaching maximum
306 		 * instruction limit, no matter what the limit is.
307 		 */
308 		if (i == 1)
309 			x = 123;
310 		else
311 			x = i * 3 + 1;
312 
313 		bpf_printk("ITER_OBFUSCATE_COUNTER: i=%d v=%d x=%d", i, *v, x);
314 
315 		sum += x;
316 	}
317 	bpf_iter_num_destroy(&it);
318 
319 	bpf_printk("ITER_OBFUSCATE_COUNTER: sum=%d\n", sum);
320 
321 	return 0;
322 }
323 
324 SEC("raw_tp")
325 __success
326 int iter_search_loop(const void *ctx)
327 {
328 	struct bpf_iter_num it;
329 	int *v, *elem = NULL;
330 	bool found = false;
331 
332 	MY_PID_GUARD();
333 
334 	bpf_iter_num_new(&it, 0, 10);
335 
336 	while ((v = bpf_iter_num_next(&it))) {
337 		bpf_printk("ITER_SEARCH_LOOP: v=%d", *v);
338 
339 		if (*v == 2) {
340 			found = true;
341 			elem = v;
342 			barrier_var(elem);
343 		}
344 	}
345 
346 	/* should fail to verify if bpf_iter_num_destroy() is here */
347 
348 	if (found)
349 		/* here found element will be wrong, we should have copied
350 		 * value to a variable, but here we want to make sure we can
351 		 * access memory after the loop anyways
352 		 */
353 		bpf_printk("ITER_SEARCH_LOOP: FOUND IT = %d!\n", *elem);
354 	else
355 		bpf_printk("ITER_SEARCH_LOOP: NOT FOUND IT!\n");
356 
357 	bpf_iter_num_destroy(&it);
358 
359 	return 0;
360 }
361 
362 SEC("raw_tp")
363 __success
364 int iter_array_fill(const void *ctx)
365 {
366 	int sum, i;
367 
368 	MY_PID_GUARD();
369 
370 	bpf_for(i, 0, ARRAY_SIZE(arr)) {
371 		arr[i] = i * 2;
372 	}
373 
374 	sum = 0;
375 	bpf_for(i, 0, ARRAY_SIZE(arr)) {
376 		sum += arr[i];
377 	}
378 
379 	bpf_printk("ITER_ARRAY_FILL: sum=%d (should be %d)\n", sum, 255 * 256);
380 
381 	return 0;
382 }
383 
384 static int arr2d[4][5];
385 static int arr2d_row_sums[4];
386 static int arr2d_col_sums[5];
387 
388 SEC("raw_tp")
389 __success
390 int iter_nested_iters(const void *ctx)
391 {
392 	int sum, row, col;
393 
394 	MY_PID_GUARD();
395 
396 	bpf_for(row, 0, ARRAY_SIZE(arr2d)) {
397 		bpf_for( col, 0, ARRAY_SIZE(arr2d[0])) {
398 			arr2d[row][col] = row * col;
399 		}
400 	}
401 
402 	/* zero-initialize sums */
403 	sum = 0;
404 	bpf_for(row, 0, ARRAY_SIZE(arr2d)) {
405 		arr2d_row_sums[row] = 0;
406 	}
407 	bpf_for(col, 0, ARRAY_SIZE(arr2d[0])) {
408 		arr2d_col_sums[col] = 0;
409 	}
410 
411 	/* calculate sums */
412 	bpf_for(row, 0, ARRAY_SIZE(arr2d)) {
413 		bpf_for(col, 0, ARRAY_SIZE(arr2d[0])) {
414 			sum += arr2d[row][col];
415 			arr2d_row_sums[row] += arr2d[row][col];
416 			arr2d_col_sums[col] += arr2d[row][col];
417 		}
418 	}
419 
420 	bpf_printk("ITER_NESTED_ITERS: total sum=%d", sum);
421 	bpf_for(row, 0, ARRAY_SIZE(arr2d)) {
422 		bpf_printk("ITER_NESTED_ITERS: row #%d sum=%d", row, arr2d_row_sums[row]);
423 	}
424 	bpf_for(col, 0, ARRAY_SIZE(arr2d[0])) {
425 		bpf_printk("ITER_NESTED_ITERS: col #%d sum=%d%s",
426 			   col, arr2d_col_sums[col],
427 			   col == ARRAY_SIZE(arr2d[0]) - 1 ? "\n" : "");
428 	}
429 
430 	return 0;
431 }
432 
433 SEC("raw_tp")
434 __success
435 int iter_nested_deeply_iters(const void *ctx)
436 {
437 	int sum = 0;
438 
439 	MY_PID_GUARD();
440 
441 	bpf_repeat(10) {
442 		bpf_repeat(10) {
443 			bpf_repeat(10) {
444 				bpf_repeat(10) {
445 					bpf_repeat(10) {
446 						sum += 1;
447 					}
448 				}
449 			}
450 		}
451 		/* validate that we can break from inside bpf_repeat() */
452 		break;
453 	}
454 
455 	return sum;
456 }
457 
458 static __noinline void fill_inner_dimension(int row)
459 {
460 	int col;
461 
462 	bpf_for(col, 0, ARRAY_SIZE(arr2d[0])) {
463 		arr2d[row][col] = row * col;
464 	}
465 }
466 
467 static __noinline int sum_inner_dimension(int row)
468 {
469 	int sum = 0, col;
470 
471 	bpf_for(col, 0, ARRAY_SIZE(arr2d[0])) {
472 		sum += arr2d[row][col];
473 		arr2d_row_sums[row] += arr2d[row][col];
474 		arr2d_col_sums[col] += arr2d[row][col];
475 	}
476 
477 	return sum;
478 }
479 
480 SEC("raw_tp")
481 __success
482 int iter_subprog_iters(const void *ctx)
483 {
484 	int sum, row, col;
485 
486 	MY_PID_GUARD();
487 
488 	bpf_for(row, 0, ARRAY_SIZE(arr2d)) {
489 		fill_inner_dimension(row);
490 	}
491 
492 	/* zero-initialize sums */
493 	sum = 0;
494 	bpf_for(row, 0, ARRAY_SIZE(arr2d)) {
495 		arr2d_row_sums[row] = 0;
496 	}
497 	bpf_for(col, 0, ARRAY_SIZE(arr2d[0])) {
498 		arr2d_col_sums[col] = 0;
499 	}
500 
501 	/* calculate sums */
502 	bpf_for(row, 0, ARRAY_SIZE(arr2d)) {
503 		sum += sum_inner_dimension(row);
504 	}
505 
506 	bpf_printk("ITER_SUBPROG_ITERS: total sum=%d", sum);
507 	bpf_for(row, 0, ARRAY_SIZE(arr2d)) {
508 		bpf_printk("ITER_SUBPROG_ITERS: row #%d sum=%d",
509 			   row, arr2d_row_sums[row]);
510 	}
511 	bpf_for(col, 0, ARRAY_SIZE(arr2d[0])) {
512 		bpf_printk("ITER_SUBPROG_ITERS: col #%d sum=%d%s",
513 			   col, arr2d_col_sums[col],
514 			   col == ARRAY_SIZE(arr2d[0]) - 1 ? "\n" : "");
515 	}
516 
517 	return 0;
518 }
519 
520 struct {
521 	__uint(type, BPF_MAP_TYPE_ARRAY);
522 	__type(key, int);
523 	__type(value, int);
524 	__uint(max_entries, 1000);
525 } arr_map SEC(".maps");
526 
527 SEC("?raw_tp")
528 __failure __msg("invalid mem access 'scalar'")
529 int iter_err_too_permissive1(const void *ctx)
530 {
531 	int *map_val = NULL;
532 	int key = 0;
533 
534 	MY_PID_GUARD();
535 
536 	map_val = bpf_map_lookup_elem(&arr_map, &key);
537 	if (!map_val)
538 		return 0;
539 
540 	bpf_repeat(1000000) {
541 		map_val = NULL;
542 	}
543 
544 	*map_val = 123;
545 
546 	return 0;
547 }
548 
549 SEC("?raw_tp")
550 __failure __msg("invalid mem access 'map_value_or_null'")
551 int iter_err_too_permissive2(const void *ctx)
552 {
553 	int *map_val = NULL;
554 	int key = 0;
555 
556 	MY_PID_GUARD();
557 
558 	map_val = bpf_map_lookup_elem(&arr_map, &key);
559 	if (!map_val)
560 		return 0;
561 
562 	bpf_repeat(1000000) {
563 		map_val = bpf_map_lookup_elem(&arr_map, &key);
564 	}
565 
566 	*map_val = 123;
567 
568 	return 0;
569 }
570 
571 SEC("?raw_tp")
572 __failure __msg("invalid mem access 'map_value_or_null'")
573 int iter_err_too_permissive3(const void *ctx)
574 {
575 	int *map_val = NULL;
576 	int key = 0;
577 	bool found = false;
578 
579 	MY_PID_GUARD();
580 
581 	bpf_repeat(1000000) {
582 		map_val = bpf_map_lookup_elem(&arr_map, &key);
583 		found = true;
584 	}
585 
586 	if (found)
587 		*map_val = 123;
588 
589 	return 0;
590 }
591 
592 SEC("raw_tp")
593 __success
594 int iter_tricky_but_fine(const void *ctx)
595 {
596 	int *map_val = NULL;
597 	int key = 0;
598 	bool found = false;
599 
600 	MY_PID_GUARD();
601 
602 	bpf_repeat(1000000) {
603 		map_val = bpf_map_lookup_elem(&arr_map, &key);
604 		if (map_val) {
605 			found = true;
606 			break;
607 		}
608 	}
609 
610 	if (found)
611 		*map_val = 123;
612 
613 	return 0;
614 }
615 
616 #define __bpf_memzero(p, sz) bpf_probe_read_kernel((p), (sz), 0)
617 
618 SEC("raw_tp")
619 __success
620 int iter_stack_array_loop(const void *ctx)
621 {
622 	long arr1[16], arr2[16], sum = 0;
623 	int i;
624 
625 	MY_PID_GUARD();
626 
627 	/* zero-init arr1 and arr2 in such a way that verifier doesn't know
628 	 * it's all zeros; if we don't do that, we'll make BPF verifier track
629 	 * all combination of zero/non-zero stack slots for arr1/arr2, which
630 	 * will lead to O(2^(ARRAY_SIZE(arr1)+ARRAY_SIZE(arr2))) different
631 	 * states
632 	 */
633 	__bpf_memzero(arr1, sizeof(arr1));
634 	__bpf_memzero(arr2, sizeof(arr1));
635 
636 	/* validate that we can break and continue when using bpf_for() */
637 	bpf_for(i, 0, ARRAY_SIZE(arr1)) {
638 		if (i & 1) {
639 			arr1[i] = i;
640 			continue;
641 		} else {
642 			arr2[i] = i;
643 			break;
644 		}
645 	}
646 
647 	bpf_for(i, 0, ARRAY_SIZE(arr1)) {
648 		sum += arr1[i] + arr2[i];
649 	}
650 
651 	return sum;
652 }
653 
654 #define ARR_SZ 16
655 
656 static __noinline void fill(struct bpf_iter_num *it, int *arr, int mul)
657 {
658 	int *t;
659 	__u64 i;
660 
661 	while ((t = bpf_iter_num_next(it))) {
662 		i = *t;
663 		if (i >= ARR_SZ)
664 			break;
665 		arr[i] =  i * mul;
666 	}
667 }
668 
669 static __noinline int sum(struct bpf_iter_num *it, int *arr)
670 {
671 	int *t, sum = 0;;
672 	__u64 i;
673 
674 	while ((t = bpf_iter_num_next(it))) {
675 		i = *t;
676 		if (i >= ARR_SZ)
677 			break;
678 		sum += arr[i];
679 	}
680 
681 	return sum;
682 }
683 
684 SEC("raw_tp")
685 __success
686 int iter_pass_iter_ptr_to_subprog(const void *ctx)
687 {
688 	int arr1[ARR_SZ], arr2[ARR_SZ];
689 	struct bpf_iter_num it;
690 	int n, sum1, sum2;
691 
692 	MY_PID_GUARD();
693 
694 	/* fill arr1 */
695 	n = ARRAY_SIZE(arr1);
696 	bpf_iter_num_new(&it, 0, n);
697 	fill(&it, arr1, 2);
698 	bpf_iter_num_destroy(&it);
699 
700 	/* fill arr2 */
701 	n = ARRAY_SIZE(arr2);
702 	bpf_iter_num_new(&it, 0, n);
703 	fill(&it, arr2, 10);
704 	bpf_iter_num_destroy(&it);
705 
706 	/* sum arr1 */
707 	n = ARRAY_SIZE(arr1);
708 	bpf_iter_num_new(&it, 0, n);
709 	sum1 = sum(&it, arr1);
710 	bpf_iter_num_destroy(&it);
711 
712 	/* sum arr2 */
713 	n = ARRAY_SIZE(arr2);
714 	bpf_iter_num_new(&it, 0, n);
715 	sum2 = sum(&it, arr2);
716 	bpf_iter_num_destroy(&it);
717 
718 	bpf_printk("sum1=%d, sum2=%d", sum1, sum2);
719 
720 	return 0;
721 }
722 
723 char _license[] SEC("license") = "GPL";
724