1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (C) 2021 ARM Limited.
4  */
5 
6 #include <errno.h>
7 #include <stdbool.h>
8 #include <stddef.h>
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <string.h>
12 #include <unistd.h>
13 #include <sys/auxv.h>
14 #include <sys/prctl.h>
15 #include <asm/hwcap.h>
16 #include <asm/sigcontext.h>
17 #include <asm/unistd.h>
18 
19 #include "../../kselftest.h"
20 
21 #include "syscall-abi.h"
22 
/*
 * SME vector length handed to do_test() by the cases that do not
 * iterate over SME vector lengths.  Set by sme_count_vls(); stays 0
 * when SME is not reported in the hwcaps.
 */
static int default_sme_vl;

/*
 * Vector lengths, in bytes, collected by sve_count_vls() and
 * sme_count_vls() together with the number of valid entries.
 */
static int sve_vl_count;
static unsigned int sve_vls[SVE_VQ_MAX];
static int sme_vl_count;
static unsigned int sme_vls[SVE_VQ_MAX];

/*
 * Defined outside this file.  NOTE(review): presumably the assembly
 * helper which loads the *_in buffers, issues the syscall and stores
 * the resulting state in the *_out buffers - confirm against its source.
 */
extern void do_syscall(int sve_vl, int sme_vl);
31 
/*
 * Fill @size bytes starting at @buf with pseudo-random data.
 *
 * random() returns a value in the range 0..2^31-1 regardless of the
 * size of long, so each 32 bit word receives 31 random bits and its
 * top bit is always clear.
 *
 * The data is copied in with memcpy() so that @buf does not have to be
 * an object of type uint32_t (callers pass uint64_t and uint8_t
 * arrays) and so that a @size which is not a multiple of four still
 * has its trailing bytes filled instead of silently left untouched.
 */
static void fill_random(void *buf, size_t size)
{
	uint8_t *bytes = buf;
	size_t offset = 0;

	while (offset < size) {
		uint32_t word = random();
		size_t chunk = size - offset;

		if (chunk > sizeof(word))
			chunk = sizeof(word);

		memcpy(bytes + offset, &word, chunk);
		offset += chunk;
	}
}
41 
/*
 * Syscalls to exercise.  The complete set of tests is repeated for
 * every entry to try to expose differences in behaviour between
 * syscalls.
 */
static struct syscall_cfg {
	int syscall_nr;		/* stored in gpr_in[8] by setup_gpr() */
	const char *name;	/* label used in the test output */
} syscalls[] = {
	{
		.syscall_nr	= __NR_getpid,
		.name		= "getpid()",
	},
	{
		.syscall_nr	= __NR_sched_yield,
		.name		= "sched_yield()",
	},
};
53 
/*
 * General purpose register images, one 64 bit slot per register with
 * slot N corresponding to xN.  gpr_in is filled by setup_gpr() and
 * gpr_out is compared against it by check_gpr().
 * NOTE(review): these are not static, presumably because do_syscall()
 * refers to them by name - confirm against its source.
 */
#define NUM_GPR 31
uint64_t gpr_in[NUM_GPR];
uint64_t gpr_out[NUM_GPR];
57 
58 static void setup_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
59 		      uint64_t svcr)
60 {
61 	fill_random(gpr_in, sizeof(gpr_in));
62 	gpr_in[8] = cfg->syscall_nr;
63 	memset(gpr_out, 0, sizeof(gpr_out));
64 }
65 
66 static int check_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl, uint64_t svcr)
67 {
68 	int errors = 0;
69 	int i;
70 
71 	/*
72 	 * GPR x0-x7 may be clobbered, and all others should be preserved.
73 	 */
74 	for (i = 9; i < ARRAY_SIZE(gpr_in); i++) {
75 		if (gpr_in[i] != gpr_out[i]) {
76 			ksft_print_msg("%s SVE VL %d mismatch in GPR %d: %llx != %llx\n",
77 				       cfg->name, sve_vl, i,
78 				       gpr_in[i], gpr_out[i]);
79 			errors++;
80 		}
81 	}
82 
83 	return errors;
84 }
85 
/*
 * FPSIMD register images.  Each 128 bit Q register occupies two
 * consecutive 64 bit slots.  fpr_zero is not written by this file and
 * is the all-zeroes reference used by check_fpr().
 */
#define NUM_FPR 32
uint64_t fpr_in[NUM_FPR * 2];
uint64_t fpr_out[NUM_FPR * 2];
uint64_t fpr_zero[NUM_FPR * 2];
90 
91 static void setup_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
92 		      uint64_t svcr)
93 {
94 	fill_random(fpr_in, sizeof(fpr_in));
95 	memset(fpr_out, 0, sizeof(fpr_out));
96 }
97 
98 static int check_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
99 		     uint64_t svcr)
100 {
101 	int errors = 0;
102 	int i;
103 
104 	if (!sve_vl && !(svcr & SVCR_SM_MASK)) {
105 		for (i = 0; i < ARRAY_SIZE(fpr_in); i++) {
106 			if (fpr_in[i] != fpr_out[i]) {
107 				ksft_print_msg("%s Q%d/%d mismatch %llx != %llx\n",
108 					       cfg->name,
109 					       i / 2, i % 2,
110 					       fpr_in[i], fpr_out[i]);
111 				errors++;
112 			}
113 		}
114 	}
115 
116 	/*
117 	 * In streaming mode the whole register set should be cleared
118 	 * by the transition out of streaming mode.
119 	 */
120 	if (svcr & SVCR_SM_MASK) {
121 		if (memcmp(fpr_zero, fpr_out, sizeof(fpr_out)) != 0) {
122 			ksft_print_msg("%s FPSIMD registers non-zero exiting SM\n",
123 				       cfg->name);
124 			errors++;
125 		}
126 	}
127 
128 	return errors;
129 }
130 
/*
 * Size in bytes of the low 128 bits of a Z register, the part which
 * check_z() expects a syscall to preserve outside streaming mode.
 */
#define SVE_Z_SHARED_BYTES (128 / 8)

/* All-zeroes reference, the size of one Z register at the maximum VL */
static uint8_t z_zero[__SVE_ZREG_SIZE(SVE_VQ_MAX)];
/*
 * Z register images, sized for the maximum VL.  check_z() indexes
 * them with a stride of the SVE vector length under test.
 */
uint8_t z_in[SVE_NUM_ZREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
uint8_t z_out[SVE_NUM_ZREGS * __SVE_ZREG_SIZE(SVE_VQ_MAX)];
136 
/*
 * Prepare the Z register images.  Unlike the GPR and FPSIMD setup the
 * output image is filled with random data rather than zeroes;
 * check_z() expects parts of it to read back as zero after the
 * syscall.
 */
static void setup_z(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		    uint64_t svcr)
{
	fill_random(z_in, sizeof(z_in));
	fill_random(z_out, sizeof(z_out));
}
143 
144 static int check_z(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
145 		   uint64_t svcr)
146 {
147 	size_t reg_size = sve_vl;
148 	int errors = 0;
149 	int i;
150 
151 	if (!sve_vl)
152 		return 0;
153 
154 	for (i = 0; i < SVE_NUM_ZREGS; i++) {
155 		uint8_t *in = &z_in[reg_size * i];
156 		uint8_t *out = &z_out[reg_size * i];
157 
158 		if (svcr & SVCR_SM_MASK) {
159 			/*
160 			 * In streaming mode the whole register should
161 			 * be cleared by the transition out of
162 			 * streaming mode.
163 			 */
164 			if (memcmp(z_zero, out, reg_size) != 0) {
165 				ksft_print_msg("%s SVE VL %d Z%d non-zero\n",
166 					       cfg->name, sve_vl, i);
167 				errors++;
168 			}
169 		} else {
170 			/*
171 			 * For standard SVE the low 128 bits should be
172 			 * preserved and any additional bits cleared.
173 			 */
174 			if (memcmp(in, out, SVE_Z_SHARED_BYTES) != 0) {
175 				ksft_print_msg("%s SVE VL %d Z%d low 128 bits changed\n",
176 					       cfg->name, sve_vl, i);
177 				errors++;
178 			}
179 
180 			if (reg_size > SVE_Z_SHARED_BYTES &&
181 			    (memcmp(z_zero, out + SVE_Z_SHARED_BYTES,
182 				    reg_size - SVE_Z_SHARED_BYTES) != 0)) {
183 				ksft_print_msg("%s SVE VL %d Z%d high bits non-zero\n",
184 					       cfg->name, sve_vl, i);
185 				errors++;
186 			}
187 		}
188 	}
189 
190 	return errors;
191 }
192 
/* Predicate register images, sized for the maximum vector length */
uint8_t p_in[SVE_NUM_PREGS * __SVE_PREG_SIZE(SVE_VQ_MAX)];
uint8_t p_out[SVE_NUM_PREGS * __SVE_PREG_SIZE(SVE_VQ_MAX)];
195 
/*
 * Prepare the predicate register images.  Both images are filled with
 * random data; check_p() expects the output image to read back as
 * zero after the syscall.
 */
static void setup_p(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		    uint64_t svcr)
{
	fill_random(p_in, sizeof(p_in));
	fill_random(p_out, sizeof(p_out));
}
202 
203 static int check_p(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
204 		   uint64_t svcr)
205 {
206 	size_t reg_size = sve_vq_from_vl(sve_vl) * 2; /* 1 bit per VL byte */
207 
208 	int errors = 0;
209 	int i;
210 
211 	if (!sve_vl)
212 		return 0;
213 
214 	/* After a syscall the P registers should be zeroed */
215 	for (i = 0; i < SVE_NUM_PREGS * reg_size; i++)
216 		if (p_out[i])
217 			errors++;
218 	if (errors)
219 		ksft_print_msg("%s SVE VL %d predicate registers non-zero\n",
220 			       cfg->name, sve_vl);
221 
222 	return errors;
223 }
224 
/* FFR images, each the size of one predicate register at the maximum VL */
uint8_t ffr_in[__SVE_PREG_SIZE(SVE_VQ_MAX)];
uint8_t ffr_out[__SVE_PREG_SIZE(SVE_VQ_MAX)];
227 
228 static void setup_ffr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
229 		      uint64_t svcr)
230 {
231 	/*
232 	 * If we are in streaming mode and do not have FA64 then FFR
233 	 * is unavailable.
234 	 */
235 	if ((svcr & SVCR_SM_MASK) &&
236 	    !(getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)) {
237 		memset(&ffr_in, 0, sizeof(ffr_in));
238 		return;
239 	}
240 
241 	/*
242 	 * It is only valid to set a contiguous set of bits starting
243 	 * at 0.  For now since we're expecting this to be cleared by
244 	 * a syscall just set all bits.
245 	 */
246 	memset(ffr_in, 0xff, sizeof(ffr_in));
247 	fill_random(ffr_out, sizeof(ffr_out));
248 }
249 
250 static int check_ffr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
251 		     uint64_t svcr)
252 {
253 	size_t reg_size = sve_vq_from_vl(sve_vl) * 2;  /* 1 bit per VL byte */
254 	int errors = 0;
255 	int i;
256 
257 	if (!sve_vl)
258 		return 0;
259 
260 	if ((svcr & SVCR_SM_MASK) &&
261 	    !(getauxval(AT_HWCAP2) & HWCAP2_SME_FA64))
262 		return 0;
263 
264 	/* After a syscall FFR should be zeroed */
265 	for (i = 0; i < reg_size; i++)
266 		if (ffr_out[i])
267 			errors++;
268 	if (errors)
269 		ksft_print_msg("%s SVE VL %d FFR non-zero\n",
270 			       cfg->name, sve_vl);
271 
272 	return errors;
273 }
274 
/*
 * SVCR value requested for the test (written by setup_svcr()) and the
 * value examined by check_svcr() after the syscall.
 */
uint64_t svcr_in, svcr_out;
276 
/*
 * Record the SVCR value requested for this test in svcr_in.  svcr_out
 * is not touched here.
 */
static void setup_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		    uint64_t svcr)
{
	svcr_in = svcr;
}
282 
283 static int check_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
284 		      uint64_t svcr)
285 {
286 	int errors = 0;
287 
288 	if (svcr_out & SVCR_SM_MASK) {
289 		ksft_print_msg("%s Still in SM, SVCR %llx\n",
290 			       cfg->name, svcr_out);
291 		errors++;
292 	}
293 
294 	if ((svcr_in & SVCR_ZA_MASK) != (svcr_out & SVCR_ZA_MASK)) {
295 		ksft_print_msg("%s PSTATE.ZA changed, SVCR %llx != %llx\n",
296 			       cfg->name, svcr_in, svcr_out);
297 		errors++;
298 	}
299 
300 	return errors;
301 }
302 
/* ZA images, sized for the maximum vector length */
uint8_t za_in[ZA_SIG_REGS_SIZE(SVE_VQ_MAX)];
uint8_t za_out[ZA_SIG_REGS_SIZE(SVE_VQ_MAX)];
305 
306 static void setup_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
307 		     uint64_t svcr)
308 {
309 	fill_random(za_in, sizeof(za_in));
310 	memset(za_out, 0, sizeof(za_out));
311 }
312 
313 static int check_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
314 		    uint64_t svcr)
315 {
316 	size_t reg_size = sme_vl * sme_vl;
317 	int errors = 0;
318 
319 	if (!(svcr & SVCR_ZA_MASK))
320 		return 0;
321 
322 	if (memcmp(za_in, za_out, reg_size) != 0) {
323 		ksft_print_msg("SME VL %d ZA does not match\n", sme_vl);
324 		errors++;
325 	}
326 
327 	return errors;
328 }
329 
/* ZT register images, aligned to 16 bytes */
uint8_t zt_in[ZT_SIG_REG_BYTES] __attribute__((aligned(16)));
uint8_t zt_out[ZT_SIG_REG_BYTES] __attribute__((aligned(16)));
332 
333 static void setup_zt(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
334 		     uint64_t svcr)
335 {
336 	fill_random(zt_in, sizeof(zt_in));
337 	memset(zt_out, 0, sizeof(zt_out));
338 }
339 
340 static int check_zt(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
341 		    uint64_t svcr)
342 {
343 	int errors = 0;
344 
345 	if (!(getauxval(AT_HWCAP2) & HWCAP2_SME2))
346 		return 0;
347 
348 	if (!(svcr & SVCR_ZA_MASK))
349 		return 0;
350 
351 	if (memcmp(zt_in, zt_out, sizeof(zt_in)) != 0) {
352 		ksft_print_msg("SME VL %d ZT does not match\n", sme_vl);
353 		errors++;
354 	}
355 
356 	return errors;
357 }
358 
/*
 * Signatures shared by all register set handlers: a setup function is
 * run before the syscall and a check function, which returns a count
 * of the problems it found, is run after it.
 */
typedef void (*setup_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
			 uint64_t svcr);
typedef int (*check_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
			uint64_t svcr);
363 
364 /*
365  * Each set of registers has a setup function which is called before
366  * the syscall to fill values in a global variable for loading by the
367  * test code and a check function which validates that the results are
368  * as expected.  Vector lengths are passed everywhere, a vector length
369  * of 0 should be treated as do not test.
370  */
371 static struct {
372 	setup_fn setup;
373 	check_fn check;
374 } regset[] = {
375 	{ setup_gpr, check_gpr },
376 	{ setup_fpr, check_fpr },
377 	{ setup_z, check_z },
378 	{ setup_p, check_p },
379 	{ setup_ffr, check_ffr },
380 	{ setup_svcr, check_svcr },
381 	{ setup_za, check_za },
382 	{ setup_zt, check_zt },
383 };
384 
385 static bool do_test(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
386 		    uint64_t svcr)
387 {
388 	int errors = 0;
389 	int i;
390 
391 	for (i = 0; i < ARRAY_SIZE(regset); i++)
392 		regset[i].setup(cfg, sve_vl, sme_vl, svcr);
393 
394 	do_syscall(sve_vl, sme_vl);
395 
396 	for (i = 0; i < ARRAY_SIZE(regset); i++)
397 		errors += regset[i].check(cfg, sve_vl, sme_vl, svcr);
398 
399 	return errors == 0;
400 }
401 
402 static void test_one_syscall(struct syscall_cfg *cfg)
403 {
404 	int sve, sme;
405 	int ret;
406 
407 	/* FPSIMD only case */
408 	ksft_test_result(do_test(cfg, 0, default_sme_vl, 0),
409 			 "%s FPSIMD\n", cfg->name);
410 
411 	for (sve = 0; sve < sve_vl_count; sve++) {
412 		ret = prctl(PR_SVE_SET_VL, sve_vls[sve]);
413 		if (ret == -1)
414 			ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
415 					   strerror(errno), errno);
416 
417 		ksft_test_result(do_test(cfg, sve_vls[sve], default_sme_vl, 0),
418 				 "%s SVE VL %d\n", cfg->name, sve_vls[sve]);
419 
420 		for (sme = 0; sme < sme_vl_count; sme++) {
421 			ret = prctl(PR_SME_SET_VL, sme_vls[sme]);
422 			if (ret == -1)
423 				ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
424 						   strerror(errno), errno);
425 
426 			ksft_test_result(do_test(cfg, sve_vls[sve],
427 						 sme_vls[sme],
428 						 SVCR_ZA_MASK | SVCR_SM_MASK),
429 					 "%s SVE VL %d/SME VL %d SM+ZA\n",
430 					 cfg->name, sve_vls[sve],
431 					 sme_vls[sme]);
432 			ksft_test_result(do_test(cfg, sve_vls[sve],
433 						 sme_vls[sme], SVCR_SM_MASK),
434 					 "%s SVE VL %d/SME VL %d SM\n",
435 					 cfg->name, sve_vls[sve],
436 					 sme_vls[sme]);
437 			ksft_test_result(do_test(cfg, sve_vls[sve],
438 						 sme_vls[sme], SVCR_ZA_MASK),
439 					 "%s SVE VL %d/SME VL %d ZA\n",
440 					 cfg->name, sve_vls[sve],
441 					 sme_vls[sme]);
442 		}
443 	}
444 
445 	for (sme = 0; sme < sme_vl_count; sme++) {
446 		ret = prctl(PR_SME_SET_VL, sme_vls[sme]);
447 		if (ret == -1)
448 			ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
449 						   strerror(errno), errno);
450 
451 		ksft_test_result(do_test(cfg, 0, sme_vls[sme],
452 					 SVCR_ZA_MASK | SVCR_SM_MASK),
453 				 "%s SME VL %d SM+ZA\n",
454 				 cfg->name, sme_vls[sme]);
455 		ksft_test_result(do_test(cfg, 0, sme_vls[sme], SVCR_SM_MASK),
456 				 "%s SME VL %d SM\n",
457 				 cfg->name, sme_vls[sme]);
458 		ksft_test_result(do_test(cfg, 0, sme_vls[sme], SVCR_ZA_MASK),
459 				 "%s SME VL %d ZA\n",
460 				 cfg->name, sme_vls[sme]);
461 	}
462 }
463 
464 void sve_count_vls(void)
465 {
466 	unsigned int vq;
467 	int vl;
468 
469 	if (!(getauxval(AT_HWCAP) & HWCAP_SVE))
470 		return;
471 
472 	/*
473 	 * Enumerate up to SVE_VQ_MAX vector lengths
474 	 */
475 	for (vq = SVE_VQ_MAX; vq > 0; vq /= 2) {
476 		vl = prctl(PR_SVE_SET_VL, vq * 16);
477 		if (vl == -1)
478 			ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
479 					   strerror(errno), errno);
480 
481 		vl &= PR_SVE_VL_LEN_MASK;
482 
483 		if (vq != sve_vq_from_vl(vl))
484 			vq = sve_vq_from_vl(vl);
485 
486 		sve_vls[sve_vl_count++] = vl;
487 	}
488 }
489 
490 void sme_count_vls(void)
491 {
492 	unsigned int vq;
493 	int vl;
494 
495 	if (!(getauxval(AT_HWCAP2) & HWCAP2_SME))
496 		return;
497 
498 	/*
499 	 * Enumerate up to SVE_VQ_MAX vector lengths
500 	 */
501 	for (vq = SVE_VQ_MAX; vq > 0; vq /= 2) {
502 		vl = prctl(PR_SME_SET_VL, vq * 16);
503 		if (vl == -1)
504 			ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
505 					   strerror(errno), errno);
506 
507 		vl &= PR_SME_VL_LEN_MASK;
508 
509 		/* Found lowest VL */
510 		if (sve_vq_from_vl(vl) > vq)
511 			break;
512 
513 		if (vq != sve_vq_from_vl(vl))
514 			vq = sve_vq_from_vl(vl);
515 
516 		sme_vls[sme_vl_count++] = vl;
517 	}
518 
519 	/* Ensure we configure a SME VL, used to flag if SVCR is set */
520 	default_sme_vl = sme_vls[0];
521 }
522 
523 int main(void)
524 {
525 	int i;
526 	int tests = 1;  /* FPSIMD */
527 	int sme_ver;
528 
529 	srandom(getpid());
530 
531 	ksft_print_header();
532 
533 	sve_count_vls();
534 	sme_count_vls();
535 
536 	tests += sve_vl_count;
537 	tests += sme_vl_count * 3;
538 	tests += (sve_vl_count * sme_vl_count) * 3;
539 	ksft_set_plan(ARRAY_SIZE(syscalls) * tests);
540 
541 	if (getauxval(AT_HWCAP2) & HWCAP2_SME2)
542 		sme_ver = 2;
543 	else
544 		sme_ver = 1;
545 
546 	if (getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)
547 		ksft_print_msg("SME%d with FA64\n", sme_ver);
548 	else if (getauxval(AT_HWCAP2) & HWCAP2_SME)
549 		ksft_print_msg("SME%d without FA64\n", sme_ver);
550 
551 	for (i = 0; i < ARRAY_SIZE(syscalls); i++)
552 		test_one_syscall(&syscalls[i]);
553 
554 	ksft_print_cnts();
555 
556 	return 0;
557 }
558