1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (C) 2021 ARM Limited.
4  */
5 
6 #include <errno.h>
7 #include <stdbool.h>
8 #include <stddef.h>
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <string.h>
12 #include <unistd.h>
13 #include <sys/auxv.h>
14 #include <sys/prctl.h>
15 #include <asm/hwcap.h>
16 #include <asm/sigcontext.h>
17 #include <asm/unistd.h>
18 
19 #include "../../kselftest.h"
20 
21 #include "syscall-abi.h"
22 
/* Number of possible vector quanta, SVE_VQ_MIN..SVE_VQ_MAX inclusive. */
#define NUM_VL ((SVE_VQ_MAX - SVE_VQ_MIN) + 1)

/*
 * SME vector length passed to do_syscall() by tests that do not set
 * SVCR.  Zero unless sme_count_vls() detects SME support.
 */
static int default_sme_vl;

/*
 * Assembly helper, defined outside this file, which issues the syscall.
 * NOTE(review): presumably loads registers from the *_in buffers and
 * stores them to the *_out buffers around the svc; confirm in the asm.
 */
extern void do_syscall(int sve_vl, int sme_vl);
28 
/*
 * Fill @buf with pseudo-random data, one 32-bit word at a time.
 *
 * random() returns values in [0, 2^31 - 1] whatever the width of long,
 * so each result is simply truncated into a uint32_t.  Bytes past the
 * last whole 32-bit word of @size are left untouched.
 */
static void fill_random(void *buf, size_t size)
{
	uint32_t *word = buf;
	size_t remaining = size / sizeof(*word);

	while (remaining--)
		*word++ = random();
}
38 
39 /*
40  * We also repeat the test for several syscalls to try to expose different
41  * behaviour.
42  */
43 static struct syscall_cfg {
44 	int syscall_nr;
45 	const char *name;
46 } syscalls[] = {
47 	{ __NR_getpid,		"getpid()" },
48 	{ __NR_sched_yield,	"sched_yield()" },
49 };
50 
51 #define NUM_GPR 31
52 uint64_t gpr_in[NUM_GPR];
53 uint64_t gpr_out[NUM_GPR];
54 
55 static void setup_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
56 		      uint64_t svcr)
57 {
58 	fill_random(gpr_in, sizeof(gpr_in));
59 	gpr_in[8] = cfg->syscall_nr;
60 	memset(gpr_out, 0, sizeof(gpr_out));
61 }
62 
63 static int check_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl, uint64_t svcr)
64 {
65 	int errors = 0;
66 	int i;
67 
68 	/*
69 	 * GPR x0-x7 may be clobbered, and all others should be preserved.
70 	 */
71 	for (i = 9; i < ARRAY_SIZE(gpr_in); i++) {
72 		if (gpr_in[i] != gpr_out[i]) {
73 			ksft_print_msg("%s SVE VL %d mismatch in GPR %d: %llx != %llx\n",
74 				       cfg->name, sve_vl, i,
75 				       gpr_in[i], gpr_out[i]);
76 			errors++;
77 		}
78 	}
79 
80 	return errors;
81 }
82 
#define NUM_FPR 32
/* FP/SIMD register values before and after the syscall, two 64-bit words per register. */
uint64_t fpr_in[NUM_FPR * 2];
uint64_t fpr_out[NUM_FPR * 2];

/* Zero the capture buffer and randomise the values to be loaded. */
static void setup_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		      uint64_t svcr)
{
	memset(fpr_out, 0, sizeof(fpr_out));
	fill_random(fpr_in, sizeof(fpr_in));
}
93 
94 static int check_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
95 		     uint64_t svcr)
96 {
97 	int errors = 0;
98 	int i;
99 
100 	if (!sve_vl) {
101 		for (i = 0; i < ARRAY_SIZE(fpr_in); i++) {
102 			if (fpr_in[i] != fpr_out[i]) {
103 				ksft_print_msg("%s Q%d/%d mismatch %llx != %llx\n",
104 					       cfg->name,
105 					       i / 2, i % 2,
106 					       fpr_in[i], fpr_out[i]);
107 				errors++;
108 			}
109 		}
110 	}
111 
112 	return errors;
113 }
114 
115 static uint8_t z_zero[__SVE_ZREG_SIZE(SVE_VQ_MAX)];
116 uint8_t z_in[SVE_NUM_PREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
117 uint8_t z_out[SVE_NUM_PREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
118 
119 static void setup_z(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
120 		    uint64_t svcr)
121 {
122 	fill_random(z_in, sizeof(z_in));
123 	fill_random(z_out, sizeof(z_out));
124 }
125 
126 static int check_z(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
127 		   uint64_t svcr)
128 {
129 	size_t reg_size = sve_vl;
130 	int errors = 0;
131 	int i;
132 
133 	if (!sve_vl)
134 		return 0;
135 
136 	/*
137 	 * After a syscall the low 128 bits of the Z registers should
138 	 * be preserved and the rest be zeroed or preserved, except if
139 	 * we were in streaming mode in which case the low 128 bits may
140 	 * also be cleared by the transition out of streaming mode.
141 	 */
142 	for (i = 0; i < SVE_NUM_ZREGS; i++) {
143 		void *in = &z_in[reg_size * i];
144 		void *out = &z_out[reg_size * i];
145 
146 		if ((memcmp(in, out, SVE_VQ_BYTES) != 0) &&
147 		    !((svcr & SVCR_SM_MASK) &&
148 		      memcmp(z_zero, out, SVE_VQ_BYTES) == 0)) {
149 			ksft_print_msg("%s SVE VL %d Z%d low 128 bits changed\n",
150 				       cfg->name, sve_vl, i);
151 			errors++;
152 		}
153 	}
154 
155 	return errors;
156 }
157 
158 uint8_t p_in[SVE_NUM_PREGS * __SVE_PREG_SIZE(SVE_VQ_MAX)];
159 uint8_t p_out[SVE_NUM_PREGS * __SVE_PREG_SIZE(SVE_VQ_MAX)];
160 
161 static void setup_p(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
162 		    uint64_t svcr)
163 {
164 	fill_random(p_in, sizeof(p_in));
165 	fill_random(p_out, sizeof(p_out));
166 }
167 
168 static int check_p(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
169 		   uint64_t svcr)
170 {
171 	size_t reg_size = sve_vq_from_vl(sve_vl) * 2; /* 1 bit per VL byte */
172 
173 	int errors = 0;
174 	int i;
175 
176 	if (!sve_vl)
177 		return 0;
178 
179 	/* After a syscall the P registers should be preserved or zeroed */
180 	for (i = 0; i < SVE_NUM_PREGS * reg_size; i++)
181 		if (p_out[i] && (p_in[i] != p_out[i]))
182 			errors++;
183 	if (errors)
184 		ksft_print_msg("%s SVE VL %d predicate registers non-zero\n",
185 			       cfg->name, sve_vl);
186 
187 	return errors;
188 }
189 
190 uint8_t ffr_in[__SVE_PREG_SIZE(SVE_VQ_MAX)];
191 uint8_t ffr_out[__SVE_PREG_SIZE(SVE_VQ_MAX)];
192 
193 static void setup_ffr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
194 		      uint64_t svcr)
195 {
196 	/*
197 	 * If we are in streaming mode and do not have FA64 then FFR
198 	 * is unavailable.
199 	 */
200 	if ((svcr & SVCR_SM_MASK) &&
201 	    !(getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)) {
202 		memset(&ffr_in, 0, sizeof(ffr_in));
203 		return;
204 	}
205 
206 	/*
207 	 * It is only valid to set a contiguous set of bits starting
208 	 * at 0.  For now since we're expecting this to be cleared by
209 	 * a syscall just set all bits.
210 	 */
211 	memset(ffr_in, 0xff, sizeof(ffr_in));
212 	fill_random(ffr_out, sizeof(ffr_out));
213 }
214 
215 static int check_ffr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
216 		     uint64_t svcr)
217 {
218 	size_t reg_size = sve_vq_from_vl(sve_vl) * 2;  /* 1 bit per VL byte */
219 	int errors = 0;
220 	int i;
221 
222 	if (!sve_vl)
223 		return 0;
224 
225 	if ((svcr & SVCR_SM_MASK) &&
226 	    !(getauxval(AT_HWCAP2) & HWCAP2_SME_FA64))
227 		return 0;
228 
229 	/* After a syscall the P registers should be preserved or zeroed */
230 	for (i = 0; i < reg_size; i++)
231 		if (ffr_out[i] && (ffr_in[i] != ffr_out[i]))
232 			errors++;
233 	if (errors)
234 		ksft_print_msg("%s SVE VL %d FFR non-zero\n",
235 			       cfg->name, sve_vl);
236 
237 	return errors;
238 }
239 
/* SVCR value requested for the test, and the value observed afterwards. */
uint64_t svcr_in;
uint64_t svcr_out;

/* Record the SVCR configuration requested for this test. */
static void setup_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		    uint64_t svcr)
{
	svcr_in = svcr;
}
247 
248 static int check_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
249 		      uint64_t svcr)
250 {
251 	int errors = 0;
252 
253 	if (svcr_out & SVCR_SM_MASK) {
254 		ksft_print_msg("%s Still in SM, SVCR %llx\n",
255 			       cfg->name, svcr_out);
256 		errors++;
257 	}
258 
259 	if ((svcr_in & SVCR_ZA_MASK) != (svcr_out & SVCR_ZA_MASK)) {
260 		ksft_print_msg("%s PSTATE.ZA changed, SVCR %llx != %llx\n",
261 			       cfg->name, svcr_in, svcr_out);
262 		errors++;
263 	}
264 
265 	return errors;
266 }
267 
268 uint8_t za_in[SVE_NUM_PREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
269 uint8_t za_out[SVE_NUM_PREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
270 
271 static void setup_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
272 		     uint64_t svcr)
273 {
274 	fill_random(za_in, sizeof(za_in));
275 	memset(za_out, 0, sizeof(za_out));
276 }
277 
278 static int check_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
279 		    uint64_t svcr)
280 {
281 	size_t reg_size = sme_vl * sme_vl;
282 	int errors = 0;
283 
284 	if (!(svcr & SVCR_ZA_MASK))
285 		return 0;
286 
287 	if (memcmp(za_in, za_out, reg_size) != 0) {
288 		ksft_print_msg("SME VL %d ZA does not match\n", sme_vl);
289 		errors++;
290 	}
291 
292 	return errors;
293 }
294 
295 typedef void (*setup_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
296 			 uint64_t svcr);
297 typedef int (*check_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
298 			uint64_t svcr);
299 
300 /*
301  * Each set of registers has a setup function which is called before
302  * the syscall to fill values in a global variable for loading by the
303  * test code and a check function which validates that the results are
304  * as expected.  Vector lengths are passed everywhere, a vector length
305  * of 0 should be treated as do not test.
306  */
307 static struct {
308 	setup_fn setup;
309 	check_fn check;
310 } regset[] = {
311 	{ setup_gpr, check_gpr },
312 	{ setup_fpr, check_fpr },
313 	{ setup_z, check_z },
314 	{ setup_p, check_p },
315 	{ setup_ffr, check_ffr },
316 	{ setup_svcr, check_svcr },
317 	{ setup_za, check_za },
318 };
319 
320 static bool do_test(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
321 		    uint64_t svcr)
322 {
323 	int errors = 0;
324 	int i;
325 
326 	for (i = 0; i < ARRAY_SIZE(regset); i++)
327 		regset[i].setup(cfg, sve_vl, sme_vl, svcr);
328 
329 	do_syscall(sve_vl, sme_vl);
330 
331 	for (i = 0; i < ARRAY_SIZE(regset); i++)
332 		errors += regset[i].check(cfg, sve_vl, sme_vl, svcr);
333 
334 	return errors == 0;
335 }
336 
337 static void test_one_syscall(struct syscall_cfg *cfg)
338 {
339 	int sve_vq, sve_vl;
340 	int sme_vq, sme_vl;
341 
342 	/* FPSIMD only case */
343 	ksft_test_result(do_test(cfg, 0, default_sme_vl, 0),
344 			 "%s FPSIMD\n", cfg->name);
345 
346 	if (!(getauxval(AT_HWCAP) & HWCAP_SVE))
347 		return;
348 
349 	for (sve_vq = SVE_VQ_MAX; sve_vq > 0; --sve_vq) {
350 		sve_vl = prctl(PR_SVE_SET_VL, sve_vq * 16);
351 		if (sve_vl == -1)
352 			ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
353 					   strerror(errno), errno);
354 
355 		sve_vl &= PR_SVE_VL_LEN_MASK;
356 
357 		if (sve_vq != sve_vq_from_vl(sve_vl))
358 			sve_vq = sve_vq_from_vl(sve_vl);
359 
360 		ksft_test_result(do_test(cfg, sve_vl, default_sme_vl, 0),
361 				 "%s SVE VL %d\n", cfg->name, sve_vl);
362 
363 		if (!(getauxval(AT_HWCAP2) & HWCAP2_SME))
364 			continue;
365 
366 		for (sme_vq = SVE_VQ_MAX; sme_vq > 0; --sme_vq) {
367 			sme_vl = prctl(PR_SME_SET_VL, sme_vq * 16);
368 			if (sme_vl == -1)
369 				ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
370 						   strerror(errno), errno);
371 
372 			sme_vl &= PR_SME_VL_LEN_MASK;
373 
374 			if (sme_vq != sve_vq_from_vl(sme_vl))
375 				sme_vq = sve_vq_from_vl(sme_vl);
376 
377 			ksft_test_result(do_test(cfg, sve_vl, sme_vl,
378 						 SVCR_ZA_MASK | SVCR_SM_MASK),
379 					 "%s SVE VL %d/SME VL %d SM+ZA\n",
380 					 cfg->name, sve_vl, sme_vl);
381 			ksft_test_result(do_test(cfg, sve_vl, sme_vl,
382 						 SVCR_SM_MASK),
383 					 "%s SVE VL %d/SME VL %d SM\n",
384 					 cfg->name, sve_vl, sme_vl);
385 			ksft_test_result(do_test(cfg, sve_vl, sme_vl,
386 						 SVCR_ZA_MASK),
387 					 "%s SVE VL %d/SME VL %d ZA\n",
388 					 cfg->name, sve_vl, sme_vl);
389 		}
390 	}
391 }
392 
393 int sve_count_vls(void)
394 {
395 	unsigned int vq;
396 	int vl_count = 0;
397 	int vl;
398 
399 	if (!(getauxval(AT_HWCAP) & HWCAP_SVE))
400 		return 0;
401 
402 	/*
403 	 * Enumerate up to SVE_VQ_MAX vector lengths
404 	 */
405 	for (vq = SVE_VQ_MAX; vq > 0; --vq) {
406 		vl = prctl(PR_SVE_SET_VL, vq * 16);
407 		if (vl == -1)
408 			ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
409 					   strerror(errno), errno);
410 
411 		vl &= PR_SVE_VL_LEN_MASK;
412 
413 		if (vq != sve_vq_from_vl(vl))
414 			vq = sve_vq_from_vl(vl);
415 
416 		vl_count++;
417 	}
418 
419 	return vl_count;
420 }
421 
422 int sme_count_vls(void)
423 {
424 	unsigned int vq;
425 	int vl_count = 0;
426 	int vl;
427 
428 	if (!(getauxval(AT_HWCAP2) & HWCAP2_SME))
429 		return 0;
430 
431 	/* Ensure we configure a SME VL, used to flag if SVCR is set */
432 	default_sme_vl = 16;
433 
434 	/*
435 	 * Enumerate up to SVE_VQ_MAX vector lengths
436 	 */
437 	for (vq = SVE_VQ_MAX; vq > 0; --vq) {
438 		vl = prctl(PR_SME_SET_VL, vq * 16);
439 		if (vl == -1)
440 			ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
441 					   strerror(errno), errno);
442 
443 		vl &= PR_SME_VL_LEN_MASK;
444 
445 		if (vq != sve_vq_from_vl(vl))
446 			vq = sve_vq_from_vl(vl);
447 
448 		vl_count++;
449 	}
450 
451 	return vl_count;
452 }
453 
454 int main(void)
455 {
456 	int i;
457 	int tests = 1;  /* FPSIMD */
458 
459 	srandom(getpid());
460 
461 	ksft_print_header();
462 	tests += sve_count_vls();
463 	tests += (sve_count_vls() * sme_count_vls()) * 3;
464 	ksft_set_plan(ARRAY_SIZE(syscalls) * tests);
465 
466 	if (getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)
467 		ksft_print_msg("SME with FA64\n");
468 	else if (getauxval(AT_HWCAP2) & HWCAP2_SME)
469 		ksft_print_msg("SME without FA64\n");
470 
471 	for (i = 0; i < ARRAY_SIZE(syscalls); i++)
472 		test_one_syscall(&syscalls[i]);
473 
474 	ksft_print_cnts();
475 
476 	return 0;
477 }
478