1 /* SPDX-License-Identifier: GPL-2.0 */
2 #define _GNU_SOURCE
3 #include <linux/membarrier.h>
4 #include <syscall.h>
5 #include <stdio.h>
6 #include <errno.h>
7 #include <string.h>
8 #include <pthread.h>
9 
10 #include "../kselftest.h"
11 
/*
 * Mask of the MEMBARRIER_CMD_REGISTER_* commands that have succeeded so far;
 * test_membarrier_get_registrations() ORs each new one in and expects
 * MEMBARRIER_CMD_GET_REGISTRATIONS to report exactly this value.
 */
static int registrations;

/*
 * Invoke membarrier(2) directly through syscall() and hand back the raw
 * result (-1 with errno set on failure, as syscall(2) reports it).
 */
static int sys_membarrier(int cmd, int flags)
{
	long rc;

	rc = syscall(__NR_membarrier, cmd, flags);
	return rc;
}
18 
19 static int test_membarrier_get_registrations(int cmd)
20 {
21 	int ret, flags = 0;
22 	const char *test_name =
23 		"sys membarrier MEMBARRIER_CMD_GET_REGISTRATIONS";
24 
25 	registrations |= cmd;
26 
27 	ret = sys_membarrier(MEMBARRIER_CMD_GET_REGISTRATIONS, 0);
28 	if (ret < 0) {
29 		ksft_exit_fail_msg(
30 			"%s test: flags = %d, errno = %d\n",
31 			test_name, flags, errno);
32 	} else if (ret != registrations) {
33 		ksft_exit_fail_msg(
34 			"%s test: flags = %d, ret = %d, registrations = %d\n",
35 			test_name, flags, ret, registrations);
36 	}
37 	ksft_test_result_pass(
38 		"%s test: flags = %d, ret = %d, registrations = %d\n",
39 		test_name, flags, ret, registrations);
40 
41 	return 0;
42 }
43 
/*
 * An unknown command (-1) must be rejected with -1/EINVAL.
 * Failures are reported through ksft_exit_fail_msg(); returns 0 otherwise.
 */
static int test_membarrier_cmd_fail(void)
{
	const char *test_name = "sys membarrier invalid command";
	const int cmd = -1;
	const int flags = 0;
	int ret;

	ret = sys_membarrier(cmd, flags);
	if (ret != -1) {
		ksft_exit_fail_msg("%s test: command = %d, flags = %d. Should fail, but passed\n",
				   test_name, cmd, flags);
	}

	/* Only reached when the call failed, so errno is meaningful here. */
	if (errno != EINVAL) {
		ksft_exit_fail_msg("%s test: flags = %d. Should return (%d: \"%s\"), but returned (%d: \"%s\").\n",
				   test_name, flags, EINVAL, strerror(EINVAL),
				   errno, strerror(errno));
	}

	ksft_test_result_pass("%s test: command = %d, flags = %d, errno = %d. Failed as expected\n",
			      test_name, cmd, flags, errno);
	return 0;
}
66 
/*
 * MEMBARRIER_CMD_QUERY with a non-zero flags word (1) must be rejected
 * with -1/EINVAL.
 * Failures are reported through ksft_exit_fail_msg(); returns 0 otherwise.
 */
static int test_membarrier_flags_fail(void)
{
	const char *test_name = "sys membarrier MEMBARRIER_CMD_QUERY invalid flags";
	const int bad_flags = 1;
	int ret;

	ret = sys_membarrier(MEMBARRIER_CMD_QUERY, bad_flags);
	if (ret != -1) {
		ksft_exit_fail_msg("%s test: flags = %d. Should fail, but passed\n",
				   test_name, bad_flags);
	}

	if (errno != EINVAL) {
		ksft_exit_fail_msg("%s test: flags = %d. Should return (%d: \"%s\"), but returned (%d: \"%s\").\n",
				   test_name, bad_flags, EINVAL, strerror(EINVAL),
				   errno, strerror(errno));
	}

	ksft_test_result_pass("%s test: flags = %d, errno = %d. Failed as expected\n",
			      test_name, bad_flags, errno);
	return 0;
}
89 
/*
 * MEMBARRIER_CMD_GLOBAL is issued here without any prior registration and
 * must return 0.
 * Failures are reported through ksft_exit_fail_msg(); returns 0 otherwise.
 */
static int test_membarrier_global_success(void)
{
	const char *test_name = "sys membarrier MEMBARRIER_CMD_GLOBAL";
	const int flags = 0;
	int ret = sys_membarrier(MEMBARRIER_CMD_GLOBAL, flags);

	if (ret != 0) {
		ksft_exit_fail_msg("%s test: flags = %d, errno = %d\n",
				   test_name, flags, errno);
	}

	ksft_test_result_pass("%s test: flags = %d\n", test_name, flags);
	return 0;
}
105 
/*
 * Before the process registers, MEMBARRIER_CMD_PRIVATE_EXPEDITED must be
 * refused with -1/EPERM.
 * Failures are reported through ksft_exit_fail_msg(); returns 0 otherwise.
 */
static int test_membarrier_private_expedited_fail(void)
{
	const char *test_name = "sys membarrier MEMBARRIER_CMD_PRIVATE_EXPEDITED not registered failure";
	const int expected_errno = EPERM;
	const int flags = 0;
	int ret;

	ret = sys_membarrier(MEMBARRIER_CMD_PRIVATE_EXPEDITED, flags);
	if (ret != -1) {
		ksft_exit_fail_msg("%s test: flags = %d. Should fail, but passed\n",
				   test_name, flags);
	}

	if (errno != expected_errno) {
		ksft_exit_fail_msg("%s test: flags = %d. Should return (%d: \"%s\"), but returned (%d: \"%s\").\n",
				   test_name, flags, expected_errno,
				   strerror(expected_errno),
				   errno, strerror(errno));
	}

	ksft_test_result_pass("%s test: flags = %d, errno = %d\n",
			      test_name, flags, errno);
	return 0;
}
128 
/*
 * Register the process for private expedited membarriers and require
 * success. Then check that the new registration is reported by
 * MEMBARRIER_CMD_GET_REGISTRATIONS.
 * Failures are reported through ksft_exit_fail_msg(); returns 0 otherwise.
 */
static int test_membarrier_register_private_expedited_success(void)
{
	const char *test_name = "sys membarrier MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED";
	const int cmd = MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED;
	const int flags = 0;

	if (sys_membarrier(cmd, flags)) {
		ksft_exit_fail_msg("%s test: flags = %d, errno = %d\n",
				   test_name, flags, errno);
	}

	ksft_test_result_pass("%s test: flags = %d\n", test_name, flags);

	(void)test_membarrier_get_registrations(cmd);
	return 0;
}
147 
/*
 * Once registered (test_membarrier_success() registers first),
 * MEMBARRIER_CMD_PRIVATE_EXPEDITED must return 0.
 * Failures are reported through ksft_exit_fail_msg(); returns 0 otherwise.
 */
static int test_membarrier_private_expedited_success(void)
{
	const char *test_name = "sys membarrier MEMBARRIER_CMD_PRIVATE_EXPEDITED";
	const int flags = 0;
	int ret = sys_membarrier(MEMBARRIER_CMD_PRIVATE_EXPEDITED, flags);

	if (ret != 0) {
		ksft_exit_fail_msg("%s test: flags = %d, errno = %d\n",
				   test_name, flags, errno);
	}

	ksft_test_result_pass("%s test: flags = %d\n", test_name, flags);
	return 0;
}
164 
/*
 * Before the process registers for it,
 * MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE must be refused with -1/EPERM.
 * Failures are reported through ksft_exit_fail_msg(); returns 0 otherwise.
 */
static int test_membarrier_private_expedited_sync_core_fail(void)
{
	const char *test_name = "sys membarrier MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE not registered failure";
	const int expected_errno = EPERM;
	const int flags = 0;
	int ret;

	ret = sys_membarrier(MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE, flags);
	if (ret != -1) {
		ksft_exit_fail_msg("%s test: flags = %d. Should fail, but passed\n",
				   test_name, flags);
	}

	if (errno != expected_errno) {
		ksft_exit_fail_msg("%s test: flags = %d. Should return (%d: \"%s\"), but returned (%d: \"%s\").\n",
				   test_name, flags, expected_errno,
				   strerror(expected_errno),
				   errno, strerror(errno));
	}

	ksft_test_result_pass("%s test: flags = %d, errno = %d\n",
			      test_name, flags, errno);
	return 0;
}
187 
/*
 * Register the process for private expedited SYNC_CORE membarriers and
 * require success. Then check that the new registration is reported by
 * MEMBARRIER_CMD_GET_REGISTRATIONS.
 * Failures are reported through ksft_exit_fail_msg(); returns 0 otherwise.
 */
static int test_membarrier_register_private_expedited_sync_core_success(void)
{
	const char *test_name = "sys membarrier MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_SYNC_CORE";
	const int cmd = MEMBARRIER_CMD_REGISTER_PRIVATE_EXPEDITED_SYNC_CORE;
	const int flags = 0;

	if (sys_membarrier(cmd, flags)) {
		ksft_exit_fail_msg("%s test: flags = %d, errno = %d\n",
				   test_name, flags, errno);
	}

	ksft_test_result_pass("%s test: flags = %d\n", test_name, flags);

	(void)test_membarrier_get_registrations(cmd);
	return 0;
}
206 
/*
 * Issue MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE and require it to return
 * 0. test_membarrier_success() registers for SYNC_CORE right before calling
 * this.
 * Failures are reported through ksft_exit_fail_msg(); returns 0 otherwise.
 */
static int test_membarrier_private_expedited_sync_core_success(void)
{
	/*
	 * Must be the SYNC_CORE command itself. Plain PRIVATE_EXPEDITED is
	 * already covered by test_membarrier_private_expedited_success(), and
	 * using it here would leave SYNC_CORE untested despite the test name.
	 */
	int cmd = MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE, flags = 0;
	const char *test_name = "sys membarrier MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE";

	if (sys_membarrier(cmd, flags) != 0) {
		ksft_exit_fail_msg(
			"%s test: flags = %d, errno = %d\n",
			test_name, flags, errno);
	}

	ksft_test_result_pass(
		"%s test: flags = %d\n",
		test_name, flags);
	return 0;
}
223 
/*
 * Register the process for global expedited membarriers and require
 * success. Then check that the new registration is reported by
 * MEMBARRIER_CMD_GET_REGISTRATIONS.
 * Failures are reported through ksft_exit_fail_msg(); returns 0 otherwise.
 */
static int test_membarrier_register_global_expedited_success(void)
{
	const char *test_name = "sys membarrier MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED";
	const int cmd = MEMBARRIER_CMD_REGISTER_GLOBAL_EXPEDITED;
	const int flags = 0;

	if (sys_membarrier(cmd, flags)) {
		ksft_exit_fail_msg("%s test: flags = %d, errno = %d\n",
				   test_name, flags, errno);
	}

	ksft_test_result_pass("%s test: flags = %d\n", test_name, flags);

	(void)test_membarrier_get_registrations(cmd);
	return 0;
}
242 
/*
 * MEMBARRIER_CMD_GLOBAL_EXPEDITED must return 0.
 * test_membarrier_success() issues it both before and after registering
 * for it.
 * Failures are reported through ksft_exit_fail_msg(); returns 0 otherwise.
 */
static int test_membarrier_global_expedited_success(void)
{
	const char *test_name = "sys membarrier MEMBARRIER_CMD_GLOBAL_EXPEDITED";
	const int flags = 0;
	int ret = sys_membarrier(MEMBARRIER_CMD_GLOBAL_EXPEDITED, flags);

	if (ret != 0) {
		ksft_exit_fail_msg("%s test: flags = %d, errno = %d\n",
				   test_name, flags, errno);
	}

	ksft_test_result_pass("%s test: flags = %d\n", test_name, flags);
	return 0;
}
259 
260 static int test_membarrier_fail(void)
261 {
262 	int status;
263 
264 	status = test_membarrier_cmd_fail();
265 	if (status)
266 		return status;
267 	status = test_membarrier_flags_fail();
268 	if (status)
269 		return status;
270 	status = test_membarrier_private_expedited_fail();
271 	if (status)
272 		return status;
273 	status = sys_membarrier(MEMBARRIER_CMD_QUERY, 0);
274 	if (status < 0) {
275 		ksft_test_result_fail("sys_membarrier() failed\n");
276 		return status;
277 	}
278 	if (status & MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE) {
279 		status = test_membarrier_private_expedited_sync_core_fail();
280 		if (status)
281 			return status;
282 	}
283 	return 0;
284 }
285 
286 static int test_membarrier_success(void)
287 {
288 	int status;
289 
290 	status = test_membarrier_global_success();
291 	if (status)
292 		return status;
293 	status = test_membarrier_register_private_expedited_success();
294 	if (status)
295 		return status;
296 	status = test_membarrier_private_expedited_success();
297 	if (status)
298 		return status;
299 	status = sys_membarrier(MEMBARRIER_CMD_QUERY, 0);
300 	if (status < 0) {
301 		ksft_test_result_fail("sys_membarrier() failed\n");
302 		return status;
303 	}
304 	if (status & MEMBARRIER_CMD_PRIVATE_EXPEDITED_SYNC_CORE) {
305 		status = test_membarrier_register_private_expedited_sync_core_success();
306 		if (status)
307 			return status;
308 		status = test_membarrier_private_expedited_sync_core_success();
309 		if (status)
310 			return status;
311 	}
312 	/*
313 	 * It is valid to send a global membarrier from a non-registered
314 	 * process.
315 	 */
316 	status = test_membarrier_global_expedited_success();
317 	if (status)
318 		return status;
319 	status = test_membarrier_register_global_expedited_success();
320 	if (status)
321 		return status;
322 	status = test_membarrier_global_expedited_success();
323 	if (status)
324 		return status;
325 	return 0;
326 }
327 
/*
 * Gate the whole run on membarrier(2) being usable:
 *  - skip when the syscall reports ENOSYS (CONFIG_MEMBARRIER=n);
 *  - skip when MEMBARRIER_CMD_QUERY does not advertise MEMBARRIER_CMD_GLOBAL;
 *  - fail, reporting errno, on any other query error.
 * Returns 0 after reporting a pass.
 */
static int test_membarrier_query(void)
{
	int flags = 0, ret;

	ret = sys_membarrier(MEMBARRIER_CMD_QUERY, flags);
	if (ret < 0) {
		/* Snapshot errno before any further library call can clobber it. */
		int err = errno;

		if (err == ENOSYS) {
			/*
			 * It is valid to build a kernel with
			 * CONFIG_MEMBARRIER=n. However, this skips the tests.
			 */
			ksft_exit_skip(
				"sys membarrier (CONFIG_MEMBARRIER) is disabled.\n");
		}
		/* Include errno so an unexpected failure can be diagnosed. */
		ksft_exit_fail_msg("sys_membarrier() failed: errno = %d (%s)\n",
				   err, strerror(err));
	}
	if (!(ret & MEMBARRIER_CMD_GLOBAL))
		ksft_exit_skip(
			"sys_membarrier unsupported: CMD_GLOBAL not found.\n");

	ksft_test_result_pass("sys_membarrier available\n");
	return 0;
}
351