xref: /openbmc/linux/tools/testing/selftests/bpf/prog_tests/ctx_rewrite.c (revision 1ac731c529cd4d6adbce134754b51ff7d822b145)
1  // SPDX-License-Identifier: GPL-2.0
2  
#include <limits.h>
#include <stdio.h>
#include <string.h>
#include <ctype.h>
#include <errno.h>
#include <regex.h>
#include <test_progs.h>

#include "bpf/btf.h"
#include "bpf_util.h"
#include "linux/filter.h"
#include "disasm.h"
14  
15  #define MAX_PROG_TEXT_SZ (32 * 1024)
16  
/* The code in this file serves the sole purpose of executing test cases
 * specified in the test_cases array. Each test case specifies a program
 * type, context field offset, and disassembly patterns that correspond
 * to read and write instructions generated by
 * verifier.c:convert_ctx_access() for accessing that field.
 *
 * For each test case, up to three programs are created:
 * - One that uses BPF_LDX_MEM to read the context field.
 * - One that uses BPF_STX_MEM to write to the context field.
 * - One that uses BPF_ST_MEM to write to the context field.
 *
 * The disassembly of each program is then compared with the pattern
 * specified in the test case.
 */
struct test_case {
	/* Subtest name, built by the N() macro as "<PROG_TYPE>.<field>[suffix]" */
	char *name;
	enum bpf_prog_type prog_type;
	/* Passed to bpf_prog_load() as expected_attach_type (0 when unset) */
	enum bpf_attach_type expected_attach_type;
	/* Offset and size in bytes of the field in the UAPI context struct */
	int field_offset;
	int field_sz;
	/* Program generated for BPF_ST_MEM stores the immediate 42 by default,
	 * this field allows specifying a custom immediate value.
	 */
	struct {
		bool use;
		int value;
	} st_value;
	/* Pattern for BPF_LDX_MEM(field_sz, dst, ctx, field_offset).
	 * NULL means no LDX program is generated.
	 */
	char *read;
	/* Pattern for BPF_STX_MEM(field_sz, ctx, src, field_offset) and
	 *             BPF_ST_MEM (field_sz, ctx, field_offset, imm).
	 * When this, write_st and write_stx are all NULL, no write
	 * programs are generated.
	 */
	char *write;
	/* Pattern for BPF_ST_MEM(field_sz, ctx, field_offset, imm),
	 * takes priority over `write`.
	 */
	char *write_st;
	/* Pattern for BPF_STX_MEM(field_sz, ctx, src, field_offset),
	 * takes priority over `write`.
	 */
	char *write_stx;
};
59  
/* Expand to the designated initializers common to every test case:
 * the subtest name ("<PROG_TYPE>.<field>" followed by an optional string
 * suffix), the program type, and the offset and size of `_field` inside
 * the context struct `_ctx_type`.
 */
#define N(_prog_type, _ctx_type, _field, _suffix...)		\
	.name = #_prog_type "." #_field _suffix,		\
	.prog_type = BPF_PROG_TYPE_##_prog_type,		\
	.field_offset = offsetof(_ctx_type, _field),		\
	.field_sz = sizeof(((_ctx_type *)NULL)->_field)
65  
/* All test cases; each entry is run as its own subtest by
 * run_one_testcase(). Patterns use the syntax accepted by
 * match_pattern(). A NULL pattern means the corresponding program is
 * not generated.
 */
static struct test_case test_cases[] = {
/* Sign extension on s390 changes the pattern */
#if defined(__x86_64__) || defined(__aarch64__)
	{
		N(SCHED_CLS, struct __sk_buff, tstamp),
		.read  = "r11 = *(u8 *)($ctx + sk_buff::__mono_tc_offset);"
			 "w11 &= 3;"
			 "if w11 != 0x3 goto pc+2;"
			 "$dst = 0;"
			 "goto pc+1;"
			 "$dst = *(u64 *)($ctx + sk_buff::tstamp);",
		.write = "r11 = *(u8 *)($ctx + sk_buff::__mono_tc_offset);"
			 "if w11 & 0x2 goto pc+1;"
			 "goto pc+2;"
			 "w11 &= -2;"
			 "*(u8 *)($ctx + sk_buff::__mono_tc_offset) = r11;"
			 "*(u64 *)($ctx + sk_buff::tstamp) = $src;",
	},
#endif
	{
		N(SCHED_CLS, struct __sk_buff, priority),
		.read  = "$dst = *(u32 *)($ctx + sk_buff::priority);",
		.write = "*(u32 *)($ctx + sk_buff::priority) = $src;",
	},
	{
		N(SCHED_CLS, struct __sk_buff, mark),
		.read  = "$dst = *(u32 *)($ctx + sk_buff::mark);",
		.write = "*(u32 *)($ctx + sk_buff::mark) = $src;",
	},
	{
		/* Expected offset is the sum sk_buff::cb + qdisc_skb_cb::data */
		N(SCHED_CLS, struct __sk_buff, cb[0]),
		.read  = "$dst = *(u32 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::data));",
		.write = "*(u32 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::data)) = $src;",
	},
	{
		N(SCHED_CLS, struct __sk_buff, tc_classid),
		.read  = "$dst = *(u16 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::tc_classid));",
		.write = "*(u16 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::tc_classid)) = $src;",
	},
	{
		N(SCHED_CLS, struct __sk_buff, tc_index),
		.read  = "$dst = *(u16 *)($ctx + sk_buff::tc_index);",
		.write = "*(u16 *)($ctx + sk_buff::tc_index) = $src;",
	},
	{
		/* The STX store is expected to be preceded by a range check of
		 * the source register; the ST store of an immediate is not.
		 */
		N(SCHED_CLS, struct __sk_buff, queue_mapping),
		.read      = "$dst = *(u16 *)($ctx + sk_buff::queue_mapping);",
		.write_stx = "if $src >= 0xffff goto pc+1;"
			     "*(u16 *)($ctx + sk_buff::queue_mapping) = $src;",
		.write_st  = "*(u16 *)($ctx + sk_buff::queue_mapping) = $src;",
	},
	{
		/* This is a corner case in filter.c:bpf_convert_ctx_access():
		 * an ST of USHRT_MAX is expected to be rewritten to "goto pc+0".
		 */
		N(SCHED_CLS, struct __sk_buff, queue_mapping, ".ushrt_max"),
		.st_value = { true, USHRT_MAX },
		.write_st = "goto pc+0;",
	},
	{
		N(CGROUP_SOCK, struct bpf_sock, bound_dev_if),
		.read  = "$dst = *(u32 *)($ctx + sock_common::skc_bound_dev_if);",
		.write = "*(u32 *)($ctx + sock_common::skc_bound_dev_if) = $src;",
	},
	{
		N(CGROUP_SOCK, struct bpf_sock, mark),
		.read  = "$dst = *(u32 *)($ctx + sock::sk_mark);",
		.write = "*(u32 *)($ctx + sock::sk_mark) = $src;",
	},
	{
		N(CGROUP_SOCK, struct bpf_sock, priority),
		.read  = "$dst = *(u32 *)($ctx + sock::sk_priority);",
		.write = "*(u32 *)($ctx + sock::sk_priority) = $src;",
	},
	{
		N(SOCK_OPS, struct bpf_sock_ops, replylong[0]),
		.read  = "$dst = *(u32 *)($ctx + bpf_sock_ops_kern::replylong);",
		.write = "*(u32 *)($ctx + bpf_sock_ops_kern::replylong) = $src;",
	},
	{
		/* The u32 is accessed at byte +0 of the u64 loaded from ppos on
		 * little endian and at +4 on big endian. Writes save r9 in
		 * tmp_reg, use it as scratch and restore it afterwards.
		 */
		N(CGROUP_SYSCTL, struct bpf_sysctl, file_pos),
#if __BYTE_ORDER == __LITTLE_ENDIAN
		.read  = "$dst = *(u64 *)($ctx + bpf_sysctl_kern::ppos);"
			 "$dst = *(u32 *)($dst +0);",
		.write = "*(u64 *)($ctx + bpf_sysctl_kern::tmp_reg) = r9;"
			 "r9 = *(u64 *)($ctx + bpf_sysctl_kern::ppos);"
			 "*(u32 *)(r9 +0) = $src;"
			 "r9 = *(u64 *)($ctx + bpf_sysctl_kern::tmp_reg);",
#else
		.read  = "$dst = *(u64 *)($ctx + bpf_sysctl_kern::ppos);"
			 "$dst = *(u32 *)($dst +4);",
		.write = "*(u64 *)($ctx + bpf_sysctl_kern::tmp_reg) = r9;"
			 "r9 = *(u64 *)($ctx + bpf_sysctl_kern::ppos);"
			 "*(u32 *)(r9 +4) = $src;"
			 "r9 = *(u64 *)($ctx + bpf_sysctl_kern::tmp_reg);",
#endif
	},
	{
		/* Read-only here: no write pattern, so only LDX is generated */
		N(CGROUP_SOCKOPT, struct bpf_sockopt, sk),
		.read  = "$dst = *(u64 *)($ctx + bpf_sockopt_kern::sk);",
		.expected_attach_type = BPF_CGROUP_GETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, level),
		.read  = "$dst = *(u32 *)($ctx + bpf_sockopt_kern::level);",
		.write = "*(u32 *)($ctx + bpf_sockopt_kern::level) = $src;",
		.expected_attach_type = BPF_CGROUP_SETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, optname),
		.read  = "$dst = *(u32 *)($ctx + bpf_sockopt_kern::optname);",
		.write = "*(u32 *)($ctx + bpf_sockopt_kern::optname) = $src;",
		.expected_attach_type = BPF_CGROUP_SETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, optlen),
		.read  = "$dst = *(u32 *)($ctx + bpf_sockopt_kern::optlen);",
		.write = "*(u32 *)($ctx + bpf_sockopt_kern::optlen) = $src;",
		.expected_attach_type = BPF_CGROUP_SETSOCKOPT,
	},
	{
		/* retval is reached through current_task->bpf_ctx; the write
		 * uses r9 as scratch, saved in and restored from tmp_reg.
		 */
		N(CGROUP_SOCKOPT, struct bpf_sockopt, retval),
		.read  = "$dst = *(u64 *)($ctx + bpf_sockopt_kern::current_task);"
			 "$dst = *(u64 *)($dst + task_struct::bpf_ctx);"
			 "$dst = *(u32 *)($dst + bpf_cg_run_ctx::retval);",
		.write = "*(u64 *)($ctx + bpf_sockopt_kern::tmp_reg) = r9;"
			 "r9 = *(u64 *)($ctx + bpf_sockopt_kern::current_task);"
			 "r9 = *(u64 *)(r9 + task_struct::bpf_ctx);"
			 "*(u32 *)(r9 + bpf_cg_run_ctx::retval) = $src;"
			 "r9 = *(u64 *)($ctx + bpf_sockopt_kern::tmp_reg);",
		.expected_attach_type = BPF_CGROUP_GETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, optval),
		.read  = "$dst = *(u64 *)($ctx + bpf_sockopt_kern::optval);",
		.expected_attach_type = BPF_CGROUP_GETSOCKOPT,
	},
	{
		N(CGROUP_SOCKOPT, struct bpf_sockopt, optval_end),
		.read  = "$dst = *(u64 *)($ctx + bpf_sockopt_kern::optval_end);",
		.expected_attach_type = BPF_CGROUP_GETSOCKOPT,
	},
};

/* N() is only meant for building the table above */
#undef N
209  
/* Regexes compiled once by test_ctx_rewrite() and consulted by
 * match_pattern(): a plain identifier, and a "<type>::<field>" reference.
 */
static regex_t *ident_regex, *field_regex;
212  
/* Return a pointer to the first non-whitespace character of `str`
 * (its terminating NUL if the rest of the string is whitespace).
 */
static char *skip_space(char *str)
{
	/* <ctype.h> functions require an unsigned char value (or EOF);
	 * passing a negative plain char is undefined behavior.
	 */
	while (*str && isspace((unsigned char)*str))
		++str;
	return str;
}
218  }
219  
/* Return a pointer past any leading whitespace and ';' characters of
 * `str` (the pattern syntax treats ';' as whitespace).
 */
static char *skip_space_and_semi(char *str)
{
	/* Cast: <ctype.h> requires an unsigned char value or EOF. */
	while (*str && (isspace((unsigned char)*str) || *str == ';'))
		++str;
	return str;
}
225  }
226  
/* If `str` begins with `prefix`, return a pointer into `str` just past
 * the prefix; otherwise return NULL. An empty prefix always matches.
 */
static char *match_str(char *str, char *prefix)
{
	size_t prefix_len = strlen(prefix);

	return strncmp(str, prefix, prefix_len) == 0 ? str + prefix_len : NULL;
}
236  }
237  
/* Parse a base-10 integer at the start of `str` using strtol() rules
 * (leading whitespace and a sign are accepted). Return a pointer just
 * past the parsed number if its value equals `num`, otherwise NULL.
 */
static char *match_number(char *str, int num)
{
	char *next;
	long snum;

	errno = 0;
	snum = strtol(str, &next, 10);
	/* Compare as long: truncating to int first could turn an
	 * out-of-range number in the text into a false match.
	 */
	if (next == str || errno == ERANGE || snum != num)
		return NULL;

	return next;
}
247  }
248  
find_field_offset_aux(struct btf * btf,int btf_id,char * field_name,int off)249  static int find_field_offset_aux(struct btf *btf, int btf_id, char *field_name, int off)
250  {
251  	const struct btf_type *type = btf__type_by_id(btf, btf_id);
252  	const struct btf_member *m;
253  	__u16 mnum;
254  	int i;
255  
256  	if (!type) {
257  		PRINT_FAIL("Can't find btf_type for id %d\n", btf_id);
258  		return -1;
259  	}
260  
261  	if (!btf_is_struct(type) && !btf_is_union(type)) {
262  		PRINT_FAIL("BTF id %d is not struct or union\n", btf_id);
263  		return -1;
264  	}
265  
266  	m = btf_members(type);
267  	mnum = btf_vlen(type);
268  
269  	for (i = 0; i < mnum; ++i, ++m) {
270  		const char *mname = btf__name_by_offset(btf, m->name_off);
271  
272  		if (strcmp(mname, "") == 0) {
273  			int msize = find_field_offset_aux(btf, m->type, field_name,
274  							  off + m->offset);
275  			if (msize >= 0)
276  				return msize;
277  		}
278  
279  		if (strcmp(mname, field_name))
280  			continue;
281  
282  		return (off + m->offset) / 8;
283  	}
284  
285  	return -1;
286  }
287  
/* Resolve the "<type>::<field>" reference that field_regex matched in
 * `pattern` (matches[1] is the type name, matches[2] the field name) to
 * the field's byte offset using `btf`. Returns -1, after reporting a
 * test failure, if the names are too long or can't be found in BTF.
 */
static int find_field_offset(struct btf *btf, char *pattern, regmatch_t *matches)
{
	char type_str[128] = {};
	char field_str[128] = {};
	int type_len = matches[1].rm_eo - matches[1].rm_so;
	int field_len = matches[2].rm_eo - matches[2].rm_so;
	int btf_id, offset;

	if (type_len >= sizeof(type_str)) {
		PRINT_FAIL("Malformed pattern: type ident is too long: %d\n", type_len);
		return -1;
	}
	if (field_len >= sizeof(field_str)) {
		PRINT_FAIL("Malformed pattern: field ident is too long: %d\n", field_len);
		return -1;
	}

	/* Buffers are zero-filled and lengths checked above, so both
	 * copies stay NUL-terminated.
	 */
	memcpy(type_str, pattern + matches[1].rm_so, type_len);
	memcpy(field_str, pattern + matches[2].rm_so, field_len);

	btf_id = btf__find_by_name(btf, type_str);
	if (btf_id < 0) {
		PRINT_FAIL("No BTF info for type %s\n", type_str);
		return -1;
	}

	offset = find_field_offset_aux(btf, btf_id, field_str, 0);
	if (offset < 0) {
		PRINT_FAIL("No BTF info for field %s::%s\n", type_str, field_str);
		return -1;
	}

	return offset;
}
323  }
324  
/* Compile POSIX extended regex `pat` into a heap-allocated regex_t.
 * Returns NULL, after reporting a test failure, if allocation or
 * compilation fails. Release the result with free_regex().
 */
static regex_t *compile_regex(char *pat)
{
	char errbuf[512];
	regex_t *re = malloc(sizeof(*re));
	int err;

	if (!re) {
		PRINT_FAIL("Can't alloc regex\n");
		return NULL;
	}

	err = regcomp(re, pat, REG_EXTENDED);
	if (!err)
		return re;

	regerror(err, re, errbuf, sizeof(errbuf));
	PRINT_FAIL("Can't compile regex: %s\n", errbuf);
	free(re);
	return NULL;
}
347  }
348  
/* Release a regex_t obtained from compile_regex(); NULL is ignored. */
static void free_regex(regex_t *re)
{
	if (re) {
		regfree(re);
		free(re);
	}
}
356  }
357  
max_line_len(char * str)358  static u32 max_line_len(char *str)
359  {
360  	u32 max_line = 0;
361  	char *next = str;
362  
363  	while (next) {
364  		next = strchr(str, '\n');
365  		if (next) {
366  			max_line = max_t(u32, max_line, (next - str));
367  			str = next + 1;
368  		} else {
369  			max_line = max_t(u32, max_line, strlen(str));
370  		}
371  	}
372  
373  	return min(max_line, 60u);
374  }
375  
/* Print strings `pattern_origin` and `text_origin` side by side to `out`.
 * `pattern_pos` and `text_pos` point into the corresponding origin
 * strings at the position where matching diverged. Each of these
 * positions is highlighted once with '^' on an extra line below the
 * line that contains it. Text lines end at '\n', pattern lines at ';'.
 * The output should look like:
 *
 *   Can't match disassembly(left) with pattern(right):
 *   r2 = *(u64 *)(r1 +0)  ;  $dst = *(u64 *)($ctx + bpf_sockopt_kern::sk1)
 *                     ^                             ^
 *   r0 = 0                ;
 *   exit                  ;
 */
static void print_match_error(FILE *out,
			      char *pattern_origin, char *text_origin,
			      char *pattern_pos, char *text_pos)
{
	char *pattern = pattern_origin;
	char *text = text_origin;
	/* Column of the ';' separator: longest text line (capped) + 2 */
	int middle = max_line_len(text) + 2;

	fprintf(out, "Can't match disassembly(left) with pattern(right):\n");
	while (*pattern || *text) {
		int column = 0;
		/* Output columns of the divergence points, -1 if not on this line */
		int mark1 = -1;
		int mark2 = -1;

		/* Print one line from text */
		while (*text && *text != '\n') {
			if (text == text_pos)
				mark1 = column;
			fputc(*text, out);
			++text;
			++column;
		}
		if (text == text_pos)
			mark1 = column;

		/* Pad to the middle */
		while (column < middle) {
			fputc(' ', out);
			++column;
		}
		fputs(";  ", out);
		column += 3;

		/* Print one line from pattern, pattern lines are terminated by ';' */
		while (*pattern && *pattern != ';') {
			if (pattern == pattern_pos)
				mark2 = column;
			fputc(*pattern, out);
			++pattern;
			++column;
		}
		if (pattern == pattern_pos)
			mark2 = column;

		fputc('\n', out);
		if (*pattern)
			++pattern;
		if (*text)
			++text;

		/* A position at the end of a string would otherwise be
		 * re-marked on every following line; mark each one once.
		 */
		if (mark1 >= 0)
			text_pos = NULL;
		if (mark2 >= 0)
			pattern_pos = NULL;

		/* If pattern and text diverge at this line, print an
		 * additional line with '^' marks, highlighting
		 * positions where match fails. Column 0 is a valid
		 * mark position, so test against -1, not 0.
		 */
		if (mark1 >= 0 || mark2 >= 0) {
			for (column = 0; column <= max(mark1, mark2); ++column) {
				if (column == mark1 || column == mark2)
					fputc('^', out);
				else
					fputc(' ', out);
			}
			fputc('\n', out);
		}
	}
}
452  
453  /* Test if `text` matches `pattern`. Pattern consists of the following elements:
454   *
455   * - Field offset references:
456   *
457   *     <type>::<field>
458   *
459   *   When such reference is encountered BTF is used to compute numerical
460   *   value for the offset of <field> in <type>. The `text` is expected to
461   *   contain matching numerical value.
462   *
463   * - Field groups:
464   *
465   *     $(<type>::<field> [+ <type>::<field>]*)
466   *
467   *   Allows to specify an offset that is a sum of multiple field offsets.
468   *   The `text` is expected to contain matching numerical value.
469   *
470   * - Variable references, e.g. `$src`, `$dst`, `$ctx`.
471   *   These are substitutions specified in `reg_map` array.
472   *   If a substring of pattern is equal to `reg_map[i][0]` the `text` is
473   *   expected to contain `reg_map[i][1]` in the matching position.
474   *
475   * - Whitespace is ignored, ';' counts as whitespace for `pattern`.
476   *
477   * - Any other characters, `pattern` and `text` should match one-to-one.
478   *
479   * Example of a pattern:
480   *
481   *                    __________ fields group ________________
482   *                   '                                        '
483   *   *(u16 *)($ctx + $(sk_buff::cb + qdisc_skb_cb::tc_classid)) = $src;
484   *            ^^^^                   '______________________'
485   *     variable reference             field offset reference
486   */
match_pattern(struct btf * btf,char * pattern,char * text,char * reg_map[][2])487  static bool match_pattern(struct btf *btf, char *pattern, char *text, char *reg_map[][2])
488  {
489  	char *pattern_origin = pattern;
490  	char *text_origin = text;
491  	regmatch_t matches[3];
492  
493  _continue:
494  	while (*pattern) {
495  		if (!*text)
496  			goto err;
497  
498  		/* Skip whitespace */
499  		if (isspace(*pattern) || *pattern == ';') {
500  			if (!isspace(*text) && text != text_origin && isalnum(text[-1]))
501  				goto err;
502  			pattern = skip_space_and_semi(pattern);
503  			text = skip_space(text);
504  			continue;
505  		}
506  
507  		/* Check for variable references */
508  		for (int i = 0; reg_map[i][0]; ++i) {
509  			char *pattern_next, *text_next;
510  
511  			pattern_next = match_str(pattern, reg_map[i][0]);
512  			if (!pattern_next)
513  				continue;
514  
515  			text_next = match_str(text, reg_map[i][1]);
516  			if (!text_next)
517  				goto err;
518  
519  			pattern = pattern_next;
520  			text = text_next;
521  			goto _continue;
522  		}
523  
524  		/* Match field group:
525  		 *   $(sk_buff::cb + qdisc_skb_cb::tc_classid)
526  		 */
527  		if (strncmp(pattern, "$(", 2) == 0) {
528  			char *group_start = pattern, *text_next;
529  			int acc_offset = 0;
530  
531  			pattern += 2;
532  
533  			for (;;) {
534  				int field_offset;
535  
536  				pattern = skip_space(pattern);
537  				if (!*pattern) {
538  					PRINT_FAIL("Unexpected end of pattern\n");
539  					goto err;
540  				}
541  
542  				if (*pattern == ')') {
543  					++pattern;
544  					break;
545  				}
546  
547  				if (*pattern == '+') {
548  					++pattern;
549  					continue;
550  				}
551  
552  				printf("pattern: %s\n", pattern);
553  				if (regexec(field_regex, pattern, 3, matches, 0) != 0) {
554  					PRINT_FAIL("Field reference expected\n");
555  					goto err;
556  				}
557  
558  				field_offset = find_field_offset(btf, pattern, matches);
559  				if (field_offset < 0)
560  					goto err;
561  
562  				pattern += matches[0].rm_eo;
563  				acc_offset += field_offset;
564  			}
565  
566  			text_next = match_number(text, acc_offset);
567  			if (!text_next) {
568  				PRINT_FAIL("No match for group offset %.*s (%d)\n",
569  					   (int)(pattern - group_start),
570  					   group_start,
571  					   acc_offset);
572  				goto err;
573  			}
574  			text = text_next;
575  		}
576  
577  		/* Match field reference:
578  		 *   sk_buff::cb
579  		 */
580  		if (regexec(field_regex, pattern, 3, matches, 0) == 0) {
581  			int field_offset;
582  			char *text_next;
583  
584  			field_offset = find_field_offset(btf, pattern, matches);
585  			if (field_offset < 0)
586  				goto err;
587  
588  			text_next = match_number(text, field_offset);
589  			if (!text_next) {
590  				PRINT_FAIL("No match for field offset %.*s (%d)\n",
591  					   (int)matches[0].rm_eo, pattern, field_offset);
592  				goto err;
593  			}
594  
595  			pattern += matches[0].rm_eo;
596  			text = text_next;
597  			continue;
598  		}
599  
600  		/* If pattern points to identifier not followed by '::'
601  		 * skip the identifier to avoid n^2 application of the
602  		 * field reference rule.
603  		 */
604  		if (regexec(ident_regex, pattern, 1, matches, 0) == 0) {
605  			if (strncmp(pattern, text, matches[0].rm_eo) != 0)
606  				goto err;
607  
608  			pattern += matches[0].rm_eo;
609  			text += matches[0].rm_eo;
610  			continue;
611  		}
612  
613  		/* Match literally */
614  		if (*pattern != *text)
615  			goto err;
616  
617  		++pattern;
618  		++text;
619  	}
620  
621  	return true;
622  
623  err:
624  	test__fail();
625  	print_match_error(stdout, pattern_origin, text_origin, pattern, text);
626  	return false;
627  }
628  
629  /* Request BPF program instructions after all rewrites are applied,
630   * e.g. verifier.c:convert_ctx_access() is done.
631   */
get_xlated_program(int fd_prog,struct bpf_insn ** buf,__u32 * cnt)632  static int get_xlated_program(int fd_prog, struct bpf_insn **buf, __u32 *cnt)
633  {
634  	struct bpf_prog_info info = {};
635  	__u32 info_len = sizeof(info);
636  	__u32 xlated_prog_len;
637  	__u32 buf_element_size = sizeof(struct bpf_insn);
638  
639  	if (bpf_prog_get_info_by_fd(fd_prog, &info, &info_len)) {
640  		perror("bpf_prog_get_info_by_fd failed");
641  		return -1;
642  	}
643  
644  	xlated_prog_len = info.xlated_prog_len;
645  	if (xlated_prog_len % buf_element_size) {
646  		printf("Program length %d is not multiple of %d\n",
647  		       xlated_prog_len, buf_element_size);
648  		return -1;
649  	}
650  
651  	*cnt = xlated_prog_len / buf_element_size;
652  	*buf = calloc(*cnt, buf_element_size);
653  	if (!buf) {
654  		perror("can't allocate xlated program buffer");
655  		return -ENOMEM;
656  	}
657  
658  	bzero(&info, sizeof(info));
659  	info.xlated_prog_len = xlated_prog_len;
660  	info.xlated_prog_insns = (__u64)(unsigned long)*buf;
661  	if (bpf_prog_get_info_by_fd(fd_prog, &info, &info_len)) {
662  		perror("second bpf_prog_get_info_by_fd failed");
663  		goto out_free_buf;
664  	}
665  
666  	return 0;
667  
668  out_free_buf:
669  	free(*buf);
670  	return -1;
671  }
672  
/* printf-style output callback for the disassembler: `private_data`
 * is the FILE * stream that receives the formatted text.
 */
static void print_insn(void *private_data, const char *fmt, ...)
{
	FILE *stream = private_data;
	va_list ap;

	va_start(ap, fmt);
	vfprintf(stream, fmt, ap);
	va_end(ap);
}
681  
682  /* Disassemble instructions to a stream */
print_xlated(FILE * out,struct bpf_insn * insn,__u32 len)683  static void print_xlated(FILE *out, struct bpf_insn *insn, __u32 len)
684  {
685  	const struct bpf_insn_cbs cbs = {
686  		.cb_print	= print_insn,
687  		.cb_call	= NULL,
688  		.cb_imm		= NULL,
689  		.private_data	= out,
690  	};
691  	bool double_insn = false;
692  	int i;
693  
694  	for (i = 0; i < len; i++) {
695  		if (double_insn) {
696  			double_insn = false;
697  			continue;
698  		}
699  
700  		double_insn = insn[i].code == (BPF_LD | BPF_IMM | BPF_DW);
701  		print_bpf_insn(&cbs, insn + i, true);
702  	}
703  }
704  
/* We share code with the kernel BPF disassembler, which adds a '(FF) '
 * prefix to each instruction (FF stands for the instruction `code`
 * byte). Strip that 5-character prefix in place from every line of
 * `str`, reading at most the first `size` bytes.
 */
static void remove_insn_prefix(char *str, int size)
{
	const int prefix_size = 5;
	int limit = (int)strlen(str);
	int src = prefix_size, dst = 0;

	if (size < limit)
		limit = size;

	while (src < limit) {
		char ch = str[src++];

		if (!ch)
			break;
		str[dst++] = ch;
		/* Every new line starts with another prefix to skip. */
		if (ch == '\n')
			src += prefix_size;
	}
	str[dst] = '\0';
}
729  
/* One program for match_program() to load, disassemble and check. */
struct prog_info {
	/* Label used in the assertion message ("LDX", "STX" or "ST") */
	char *prog_kind;
	enum bpf_prog_type prog_type;
	enum bpf_attach_type expected_attach_type;
	/* Instructions to load and their number */
	struct bpf_insn *prog;
	u32 prog_len;
};
737  
match_program(struct btf * btf,struct prog_info * pinfo,char * pattern,char * reg_map[][2],bool skip_first_insn)738  static void match_program(struct btf *btf,
739  			  struct prog_info *pinfo,
740  			  char *pattern,
741  			  char *reg_map[][2],
742  			  bool skip_first_insn)
743  {
744  	struct bpf_insn *buf = NULL;
745  	int err = 0, prog_fd = 0;
746  	FILE *prog_out = NULL;
747  	char *text = NULL;
748  	__u32 cnt = 0;
749  
750  	text = calloc(MAX_PROG_TEXT_SZ, 1);
751  	if (!text) {
752  		PRINT_FAIL("Can't allocate %d bytes\n", MAX_PROG_TEXT_SZ);
753  		goto out;
754  	}
755  
756  	// TODO: log level
757  	LIBBPF_OPTS(bpf_prog_load_opts, opts);
758  	opts.log_buf = text;
759  	opts.log_size = MAX_PROG_TEXT_SZ;
760  	opts.log_level = 1 | 2 | 4;
761  	opts.expected_attach_type = pinfo->expected_attach_type;
762  
763  	prog_fd = bpf_prog_load(pinfo->prog_type, NULL, "GPL",
764  				pinfo->prog, pinfo->prog_len, &opts);
765  	if (prog_fd < 0) {
766  		PRINT_FAIL("Can't load program, errno %d (%s), verifier log:\n%s\n",
767  			   errno, strerror(errno), text);
768  		goto out;
769  	}
770  
771  	memset(text, 0, MAX_PROG_TEXT_SZ);
772  
773  	err = get_xlated_program(prog_fd, &buf, &cnt);
774  	if (err) {
775  		PRINT_FAIL("Can't load back BPF program\n");
776  		goto out;
777  	}
778  
779  	prog_out = fmemopen(text, MAX_PROG_TEXT_SZ - 1, "w");
780  	if (!prog_out) {
781  		PRINT_FAIL("Can't open memory stream\n");
782  		goto out;
783  	}
784  	if (skip_first_insn)
785  		print_xlated(prog_out, buf + 1, cnt - 1);
786  	else
787  		print_xlated(prog_out, buf, cnt);
788  	fclose(prog_out);
789  	remove_insn_prefix(text, MAX_PROG_TEXT_SZ);
790  
791  	ASSERT_TRUE(match_pattern(btf, pattern, text, reg_map),
792  		    pinfo->prog_kind);
793  
794  out:
795  	if (prog_fd)
796  		close(prog_fd);
797  	free(buf);
798  	free(text);
799  }
800  
run_one_testcase(struct btf * btf,struct test_case * test)801  static void run_one_testcase(struct btf *btf, struct test_case *test)
802  {
803  	struct prog_info pinfo = {};
804  	int bpf_sz;
805  
806  	if (!test__start_subtest(test->name))
807  		return;
808  
809  	switch (test->field_sz) {
810  	case 8:
811  		bpf_sz = BPF_DW;
812  		break;
813  	case 4:
814  		bpf_sz = BPF_W;
815  		break;
816  	case 2:
817  		bpf_sz = BPF_H;
818  		break;
819  	case 1:
820  		bpf_sz = BPF_B;
821  		break;
822  	default:
823  		PRINT_FAIL("Unexpected field size: %d, want 8,4,2 or 1\n", test->field_sz);
824  		return;
825  	}
826  
827  	pinfo.prog_type = test->prog_type;
828  	pinfo.expected_attach_type = test->expected_attach_type;
829  
830  	if (test->read) {
831  		struct bpf_insn ldx_prog[] = {
832  			BPF_LDX_MEM(bpf_sz, BPF_REG_2, BPF_REG_1, test->field_offset),
833  			BPF_MOV64_IMM(BPF_REG_0, 0),
834  			BPF_EXIT_INSN(),
835  		};
836  		char *reg_map[][2] = {
837  			{ "$ctx", "r1" },
838  			{ "$dst", "r2" },
839  			{}
840  		};
841  
842  		pinfo.prog_kind = "LDX";
843  		pinfo.prog = ldx_prog;
844  		pinfo.prog_len = ARRAY_SIZE(ldx_prog);
845  		match_program(btf, &pinfo, test->read, reg_map, false);
846  	}
847  
848  	if (test->write || test->write_st || test->write_stx) {
849  		struct bpf_insn stx_prog[] = {
850  			BPF_MOV64_IMM(BPF_REG_2, 0),
851  			BPF_STX_MEM(bpf_sz, BPF_REG_1, BPF_REG_2, test->field_offset),
852  			BPF_MOV64_IMM(BPF_REG_0, 0),
853  			BPF_EXIT_INSN(),
854  		};
855  		char *stx_reg_map[][2] = {
856  			{ "$ctx", "r1" },
857  			{ "$src", "r2" },
858  			{}
859  		};
860  		struct bpf_insn st_prog[] = {
861  			BPF_ST_MEM(bpf_sz, BPF_REG_1, test->field_offset,
862  				   test->st_value.use ? test->st_value.value : 42),
863  			BPF_MOV64_IMM(BPF_REG_0, 0),
864  			BPF_EXIT_INSN(),
865  		};
866  		char *st_reg_map[][2] = {
867  			{ "$ctx", "r1" },
868  			{ "$src", "42" },
869  			{}
870  		};
871  
872  		if (test->write || test->write_stx) {
873  			char *pattern = test->write_stx ? test->write_stx : test->write;
874  
875  			pinfo.prog_kind = "STX";
876  			pinfo.prog = stx_prog;
877  			pinfo.prog_len = ARRAY_SIZE(stx_prog);
878  			match_program(btf, &pinfo, pattern, stx_reg_map, true);
879  		}
880  
881  		if (test->write || test->write_st) {
882  			char *pattern = test->write_st ? test->write_st : test->write;
883  
884  			pinfo.prog_kind = "ST";
885  			pinfo.prog = st_prog;
886  			pinfo.prog_len = ARRAY_SIZE(st_prog);
887  			match_program(btf, &pinfo, pattern, st_reg_map, false);
888  		}
889  	}
890  
891  	test__end_subtest();
892  }
893  
test_ctx_rewrite(void)894  void test_ctx_rewrite(void)
895  {
896  	struct btf *btf;
897  	int i;
898  
899  	field_regex = compile_regex("^([[:alpha:]_][[:alnum:]_]+)::([[:alpha:]_][[:alnum:]_]+)");
900  	ident_regex = compile_regex("^[[:alpha:]_][[:alnum:]_]+");
901  	if (!field_regex || !ident_regex)
902  		return;
903  
904  	btf = btf__load_vmlinux_btf();
905  	if (!btf) {
906  		PRINT_FAIL("Can't load vmlinux BTF, errno %d (%s)\n", errno, strerror(errno));
907  		goto out;
908  	}
909  
910  	for (i = 0; i < ARRAY_SIZE(test_cases); ++i)
911  		run_one_testcase(btf, &test_cases[i]);
912  
913  out:
914  	btf__free(btf);
915  	free_regex(field_regex);
916  	free_regex(ident_regex);
917  }
918