1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (C) 2021. Huawei Technologies Co., Ltd */
3 #include <test_progs.h>
4 #include "strncmp_test.skel.h"
5 
6 static int trigger_strncmp(const struct strncmp_test *skel)
7 {
8 	int cmp;
9 
10 	usleep(1);
11 
12 	cmp = skel->bss->cmp_ret;
13 	if (cmp > 0)
14 		return 1;
15 	if (cmp < 0)
16 		return -1;
17 	return 0;
18 }
19 
20 /*
21  * Compare str and target after making str[i] != target[i].
22  * When exp is -1, make str[i] < target[i] and delta = -1.
23  */
24 static void strncmp_full_str_cmp(struct strncmp_test *skel, const char *name,
25 				 int exp)
26 {
27 	size_t nr = sizeof(skel->bss->str);
28 	char *str = skel->bss->str;
29 	int delta = exp;
30 	int got;
31 	size_t i;
32 
33 	memcpy(str, skel->rodata->target, nr);
34 	for (i = 0; i < nr - 1; i++) {
35 		str[i] += delta;
36 
37 		got = trigger_strncmp(skel);
38 		ASSERT_EQ(got, exp, name);
39 
40 		str[i] -= delta;
41 	}
42 }
43 
44 static void test_strncmp_ret(void)
45 {
46 	struct strncmp_test *skel;
47 	struct bpf_program *prog;
48 	int err, got;
49 
50 	skel = strncmp_test__open();
51 	if (!ASSERT_OK_PTR(skel, "strncmp_test open"))
52 		return;
53 
54 	bpf_object__for_each_program(prog, skel->obj)
55 		bpf_program__set_autoload(prog, false);
56 
57 	bpf_program__set_autoload(skel->progs.do_strncmp, true);
58 
59 	err = strncmp_test__load(skel);
60 	if (!ASSERT_EQ(err, 0, "strncmp_test load"))
61 		goto out;
62 
63 	err = strncmp_test__attach(skel);
64 	if (!ASSERT_EQ(err, 0, "strncmp_test attach"))
65 		goto out;
66 
67 	skel->bss->target_pid = getpid();
68 
69 	/* Empty str */
70 	skel->bss->str[0] = '\0';
71 	got = trigger_strncmp(skel);
72 	ASSERT_EQ(got, -1, "strncmp: empty str");
73 
74 	/* Same string */
75 	memcpy(skel->bss->str, skel->rodata->target, sizeof(skel->bss->str));
76 	got = trigger_strncmp(skel);
77 	ASSERT_EQ(got, 0, "strncmp: same str");
78 
79 	/* Not-null-termainted string  */
80 	memcpy(skel->bss->str, skel->rodata->target, sizeof(skel->bss->str));
81 	skel->bss->str[sizeof(skel->bss->str) - 1] = 'A';
82 	got = trigger_strncmp(skel);
83 	ASSERT_EQ(got, 1, "strncmp: not-null-term str");
84 
85 	strncmp_full_str_cmp(skel, "strncmp: less than", -1);
86 	strncmp_full_str_cmp(skel, "strncmp: greater than", 1);
87 out:
88 	strncmp_test__destroy(skel);
89 }
90 
91 static void test_strncmp_bad_not_const_str_size(void)
92 {
93 	struct strncmp_test *skel;
94 	struct bpf_program *prog;
95 	int err;
96 
97 	skel = strncmp_test__open();
98 	if (!ASSERT_OK_PTR(skel, "strncmp_test open"))
99 		return;
100 
101 	bpf_object__for_each_program(prog, skel->obj)
102 		bpf_program__set_autoload(prog, false);
103 
104 	bpf_program__set_autoload(skel->progs.strncmp_bad_not_const_str_size,
105 				  true);
106 
107 	err = strncmp_test__load(skel);
108 	ASSERT_ERR(err, "strncmp_test load bad_not_const_str_size");
109 
110 	strncmp_test__destroy(skel);
111 }
112 
113 static void test_strncmp_bad_writable_target(void)
114 {
115 	struct strncmp_test *skel;
116 	struct bpf_program *prog;
117 	int err;
118 
119 	skel = strncmp_test__open();
120 	if (!ASSERT_OK_PTR(skel, "strncmp_test open"))
121 		return;
122 
123 	bpf_object__for_each_program(prog, skel->obj)
124 		bpf_program__set_autoload(prog, false);
125 
126 	bpf_program__set_autoload(skel->progs.strncmp_bad_writable_target,
127 				  true);
128 
129 	err = strncmp_test__load(skel);
130 	ASSERT_ERR(err, "strncmp_test load bad_writable_target");
131 
132 	strncmp_test__destroy(skel);
133 }
134 
135 static void test_strncmp_bad_not_null_term_target(void)
136 {
137 	struct strncmp_test *skel;
138 	struct bpf_program *prog;
139 	int err;
140 
141 	skel = strncmp_test__open();
142 	if (!ASSERT_OK_PTR(skel, "strncmp_test open"))
143 		return;
144 
145 	bpf_object__for_each_program(prog, skel->obj)
146 		bpf_program__set_autoload(prog, false);
147 
148 	bpf_program__set_autoload(skel->progs.strncmp_bad_not_null_term_target,
149 				  true);
150 
151 	err = strncmp_test__load(skel);
152 	ASSERT_ERR(err, "strncmp_test load bad_not_null_term_target");
153 
154 	strncmp_test__destroy(skel);
155 }
156 
157 void test_test_strncmp(void)
158 {
159 	if (test__start_subtest("strncmp_ret"))
160 		test_strncmp_ret();
161 	if (test__start_subtest("strncmp_bad_not_const_str_size"))
162 		test_strncmp_bad_not_const_str_size();
163 	if (test__start_subtest("strncmp_bad_writable_target"))
164 		test_strncmp_bad_writable_target();
165 	if (test__start_subtest("strncmp_bad_not_null_term_target"))
166 		test_strncmp_bad_not_null_term_target();
167 }
168