1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (C) 2021 ARM Limited.
4  */
5 
6 #include <errno.h>
7 #include <stdbool.h>
8 #include <stddef.h>
9 #include <stdio.h>
10 #include <stdlib.h>
11 #include <string.h>
12 #include <unistd.h>
13 #include <sys/auxv.h>
14 #include <sys/prctl.h>
15 #include <asm/hwcap.h>
16 #include <asm/sigcontext.h>
17 #include <asm/unistd.h>
18 
19 #include "../../kselftest.h"
20 
21 #include "syscall-abi.h"
22 
/*
 * Use the architectural maximum VQ rather than the kernel's SVE_VQ_MAX.
 * The kernel value is much larger than anything expressible in the
 * architecture, and sizing the buffers from it creates a *lot* of
 * overhead filling them (especially ZA) on emulated platforms.
 */
#define ARCH_SVE_VQ_MAX 16

/* SVE vector lengths (in bytes) found by sve_count_vls(), largest first */
static int sve_vl_count;
static unsigned int sve_vls[ARCH_SVE_VQ_MAX];

/* SME vector lengths (in bytes) found by sme_count_vls(), largest first */
static int sme_vl_count;
static unsigned int sme_vls[ARCH_SVE_VQ_MAX];

/*
 * SME VL handed to do_syscall() when a test does not pick a specific
 * one.  Set by sme_count_vls(); stays 0 when SME is not supported.
 */
static int default_sme_vl;

/*
 * Implemented outside this file: loads the state prepared in the *_in
 * globals, makes the syscall and saves the resulting state into the
 * *_out globals (see the comment above regset[]).  A VL of 0 means the
 * corresponding register state is not tested.
 */
extern void do_syscall(int sve_vl, int sme_vl);
39 
fill_random(void * buf,size_t size)40 static void fill_random(void *buf, size_t size)
41 {
42 	int i;
43 	uint32_t *lbuf = buf;
44 
45 	/* random() returns a 32 bit number regardless of the size of long */
46 	for (i = 0; i < size / sizeof(uint32_t); i++)
47 		lbuf[i] = random();
48 }
49 
/*
 * Every test configuration is run once per syscall listed here, in the
 * hope that different syscalls expose different behaviour.
 */
static struct syscall_cfg {
	int syscall_nr;		/* placed in gpr_in[8] (x8) by setup_gpr() */
	const char *name;	/* label used in test output */
} syscalls[] = {
	{ .syscall_nr = __NR_getpid,		.name = "getpid()" },
	{ .syscall_nr = __NR_sched_yield,	.name = "sched_yield()" },
};
61 
62 #define NUM_GPR 31
63 uint64_t gpr_in[NUM_GPR];
64 uint64_t gpr_out[NUM_GPR];
65 
setup_gpr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)66 static void setup_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
67 		      uint64_t svcr)
68 {
69 	fill_random(gpr_in, sizeof(gpr_in));
70 	gpr_in[8] = cfg->syscall_nr;
71 	memset(gpr_out, 0, sizeof(gpr_out));
72 }
73 
check_gpr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)74 static int check_gpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl, uint64_t svcr)
75 {
76 	int errors = 0;
77 	int i;
78 
79 	/*
80 	 * GPR x0-x7 may be clobbered, and all others should be preserved.
81 	 */
82 	for (i = 9; i < ARRAY_SIZE(gpr_in); i++) {
83 		if (gpr_in[i] != gpr_out[i]) {
84 			ksft_print_msg("%s SVE VL %d mismatch in GPR %d: %llx != %llx\n",
85 				       cfg->name, sve_vl, i,
86 				       gpr_in[i], gpr_out[i]);
87 			errors++;
88 		}
89 	}
90 
91 	return errors;
92 }
93 
/*
 * FPSIMD V0-V31 as pairs of 64-bit halves (check_fpr() reports index i
 * as Q(i / 2) half (i % 2)).  fpr_zero has no initialiser and is used
 * as the all-zero reference.
 */
#define NUM_FPR 32
uint64_t fpr_in[NUM_FPR * 2];
uint64_t fpr_out[NUM_FPR * 2];
uint64_t fpr_zero[NUM_FPR * 2];

/* Clear the saved FPSIMD state and randomise the values to load. */
static void setup_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		      uint64_t svcr)
{
	memset(fpr_out, 0, sizeof(fpr_out));
	fill_random(fpr_in, sizeof(fpr_in));
}
105 
check_fpr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)106 static int check_fpr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
107 		     uint64_t svcr)
108 {
109 	int errors = 0;
110 	int i;
111 
112 	if (!sve_vl && !(svcr & SVCR_SM_MASK)) {
113 		for (i = 0; i < ARRAY_SIZE(fpr_in); i++) {
114 			if (fpr_in[i] != fpr_out[i]) {
115 				ksft_print_msg("%s Q%d/%d mismatch %llx != %llx\n",
116 					       cfg->name,
117 					       i / 2, i % 2,
118 					       fpr_in[i], fpr_out[i]);
119 				errors++;
120 			}
121 		}
122 	}
123 
124 	/*
125 	 * In streaming mode the whole register set should be cleared
126 	 * by the transition out of streaming mode.
127 	 */
128 	if (svcr & SVCR_SM_MASK) {
129 		if (memcmp(fpr_zero, fpr_out, sizeof(fpr_out)) != 0) {
130 			ksft_print_msg("%s FPSIMD registers non-zero exiting SM\n",
131 				       cfg->name);
132 			errors++;
133 		}
134 	}
135 
136 	return errors;
137 }
138 
139 #define SVE_Z_SHARED_BYTES (128 / 8)
140 
141 static uint8_t z_zero[__SVE_ZREG_SIZE(ARCH_SVE_VQ_MAX)];
142 uint8_t z_in[SVE_NUM_ZREGS * __SVE_ZREG_SIZE(ARCH_SVE_VQ_MAX)];
143 uint8_t z_out[SVE_NUM_ZREGS * __SVE_ZREG_SIZE(ARCH_SVE_VQ_MAX)];
144 
setup_z(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)145 static void setup_z(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
146 		    uint64_t svcr)
147 {
148 	fill_random(z_in, sizeof(z_in));
149 	fill_random(z_out, sizeof(z_out));
150 }
151 
check_z(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)152 static int check_z(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
153 		   uint64_t svcr)
154 {
155 	size_t reg_size = sve_vl;
156 	int errors = 0;
157 	int i;
158 
159 	if (!sve_vl)
160 		return 0;
161 
162 	for (i = 0; i < SVE_NUM_ZREGS; i++) {
163 		uint8_t *in = &z_in[reg_size * i];
164 		uint8_t *out = &z_out[reg_size * i];
165 
166 		if (svcr & SVCR_SM_MASK) {
167 			/*
168 			 * In streaming mode the whole register should
169 			 * be cleared by the transition out of
170 			 * streaming mode.
171 			 */
172 			if (memcmp(z_zero, out, reg_size) != 0) {
173 				ksft_print_msg("%s SVE VL %d Z%d non-zero\n",
174 					       cfg->name, sve_vl, i);
175 				errors++;
176 			}
177 		} else {
178 			/*
179 			 * For standard SVE the low 128 bits should be
180 			 * preserved and any additional bits cleared.
181 			 */
182 			if (memcmp(in, out, SVE_Z_SHARED_BYTES) != 0) {
183 				ksft_print_msg("%s SVE VL %d Z%d low 128 bits changed\n",
184 					       cfg->name, sve_vl, i);
185 				errors++;
186 			}
187 
188 			if (reg_size > SVE_Z_SHARED_BYTES &&
189 			    (memcmp(z_zero, out + SVE_Z_SHARED_BYTES,
190 				    reg_size - SVE_Z_SHARED_BYTES) != 0)) {
191 				ksft_print_msg("%s SVE VL %d Z%d high bits non-zero\n",
192 					       cfg->name, sve_vl, i);
193 				errors++;
194 			}
195 		}
196 	}
197 
198 	return errors;
199 }
200 
201 uint8_t p_in[SVE_NUM_PREGS * __SVE_PREG_SIZE(ARCH_SVE_VQ_MAX)];
202 uint8_t p_out[SVE_NUM_PREGS * __SVE_PREG_SIZE(ARCH_SVE_VQ_MAX)];
203 
setup_p(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)204 static void setup_p(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
205 		    uint64_t svcr)
206 {
207 	fill_random(p_in, sizeof(p_in));
208 	fill_random(p_out, sizeof(p_out));
209 }
210 
check_p(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)211 static int check_p(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
212 		   uint64_t svcr)
213 {
214 	size_t reg_size = sve_vq_from_vl(sve_vl) * 2; /* 1 bit per VL byte */
215 
216 	int errors = 0;
217 	int i;
218 
219 	if (!sve_vl)
220 		return 0;
221 
222 	/* After a syscall the P registers should be zeroed */
223 	for (i = 0; i < SVE_NUM_PREGS * reg_size; i++)
224 		if (p_out[i])
225 			errors++;
226 	if (errors)
227 		ksft_print_msg("%s SVE VL %d predicate registers non-zero\n",
228 			       cfg->name, sve_vl);
229 
230 	return errors;
231 }
232 
233 uint8_t ffr_in[__SVE_PREG_SIZE(ARCH_SVE_VQ_MAX)];
234 uint8_t ffr_out[__SVE_PREG_SIZE(ARCH_SVE_VQ_MAX)];
235 
setup_ffr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)236 static void setup_ffr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
237 		      uint64_t svcr)
238 {
239 	/*
240 	 * If we are in streaming mode and do not have FA64 then FFR
241 	 * is unavailable.
242 	 */
243 	if ((svcr & SVCR_SM_MASK) &&
244 	    !(getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)) {
245 		memset(&ffr_in, 0, sizeof(ffr_in));
246 		return;
247 	}
248 
249 	/*
250 	 * It is only valid to set a contiguous set of bits starting
251 	 * at 0.  For now since we're expecting this to be cleared by
252 	 * a syscall just set all bits.
253 	 */
254 	memset(ffr_in, 0xff, sizeof(ffr_in));
255 	fill_random(ffr_out, sizeof(ffr_out));
256 }
257 
check_ffr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)258 static int check_ffr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
259 		     uint64_t svcr)
260 {
261 	size_t reg_size = sve_vq_from_vl(sve_vl) * 2;  /* 1 bit per VL byte */
262 	int errors = 0;
263 	int i;
264 
265 	if (!sve_vl)
266 		return 0;
267 
268 	if ((svcr & SVCR_SM_MASK) &&
269 	    !(getauxval(AT_HWCAP2) & HWCAP2_SME_FA64))
270 		return 0;
271 
272 	/* After a syscall FFR should be zeroed */
273 	for (i = 0; i < reg_size; i++)
274 		if (ffr_out[i])
275 			errors++;
276 	if (errors)
277 		ksft_print_msg("%s SVE VL %d FFR non-zero\n",
278 			       cfg->name, sve_vl);
279 
280 	return errors;
281 }
282 
/* SVCR requested for the test (in) and as read back after the syscall (out) */
uint64_t svcr_in;
uint64_t svcr_out;

/* Record the SVCR bits (SM/ZA) this test runs with. */
static void setup_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
		    uint64_t svcr)
{
	svcr_in = svcr;
}
290 
check_svcr(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)291 static int check_svcr(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
292 		      uint64_t svcr)
293 {
294 	int errors = 0;
295 
296 	if (svcr_out & SVCR_SM_MASK) {
297 		ksft_print_msg("%s Still in SM, SVCR %llx\n",
298 			       cfg->name, svcr_out);
299 		errors++;
300 	}
301 
302 	if ((svcr_in & SVCR_ZA_MASK) != (svcr_out & SVCR_ZA_MASK)) {
303 		ksft_print_msg("%s PSTATE.ZA changed, SVCR %llx != %llx\n",
304 			       cfg->name, svcr_in, svcr_out);
305 		errors++;
306 	}
307 
308 	return errors;
309 }
310 
311 uint8_t za_in[ZA_SIG_REGS_SIZE(ARCH_SVE_VQ_MAX)];
312 uint8_t za_out[ZA_SIG_REGS_SIZE(ARCH_SVE_VQ_MAX)];
313 
setup_za(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)314 static void setup_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
315 		     uint64_t svcr)
316 {
317 	fill_random(za_in, sizeof(za_in));
318 	memset(za_out, 0, sizeof(za_out));
319 }
320 
check_za(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)321 static int check_za(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
322 		    uint64_t svcr)
323 {
324 	size_t reg_size = sme_vl * sme_vl;
325 	int errors = 0;
326 
327 	if (!(svcr & SVCR_ZA_MASK))
328 		return 0;
329 
330 	if (memcmp(za_in, za_out, reg_size) != 0) {
331 		ksft_print_msg("SME VL %d ZA does not match\n", sme_vl);
332 		errors++;
333 	}
334 
335 	return errors;
336 }
337 
338 uint8_t zt_in[ZT_SIG_REG_BYTES] __attribute__((aligned(16)));
339 uint8_t zt_out[ZT_SIG_REG_BYTES] __attribute__((aligned(16)));
340 
setup_zt(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)341 static void setup_zt(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
342 		     uint64_t svcr)
343 {
344 	fill_random(zt_in, sizeof(zt_in));
345 	memset(zt_out, 0, sizeof(zt_out));
346 }
347 
check_zt(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)348 static int check_zt(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
349 		    uint64_t svcr)
350 {
351 	int errors = 0;
352 
353 	if (!(getauxval(AT_HWCAP2) & HWCAP2_SME2))
354 		return 0;
355 
356 	if (!(svcr & SVCR_ZA_MASK))
357 		return 0;
358 
359 	if (memcmp(zt_in, zt_out, sizeof(zt_in)) != 0) {
360 		ksft_print_msg("SME VL %d ZT does not match\n", sme_vl);
361 		errors++;
362 	}
363 
364 	return errors;
365 }
366 
367 typedef void (*setup_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
368 			 uint64_t svcr);
369 typedef int (*check_fn)(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
370 			uint64_t svcr);
371 
372 /*
373  * Each set of registers has a setup function which is called before
374  * the syscall to fill values in a global variable for loading by the
375  * test code and a check function which validates that the results are
376  * as expected.  Vector lengths are passed everywhere, a vector length
377  * of 0 should be treated as do not test.
378  */
379 static struct {
380 	setup_fn setup;
381 	check_fn check;
382 } regset[] = {
383 	{ setup_gpr, check_gpr },
384 	{ setup_fpr, check_fpr },
385 	{ setup_z, check_z },
386 	{ setup_p, check_p },
387 	{ setup_ffr, check_ffr },
388 	{ setup_svcr, check_svcr },
389 	{ setup_za, check_za },
390 	{ setup_zt, check_zt },
391 };
392 
do_test(struct syscall_cfg * cfg,int sve_vl,int sme_vl,uint64_t svcr)393 static bool do_test(struct syscall_cfg *cfg, int sve_vl, int sme_vl,
394 		    uint64_t svcr)
395 {
396 	int errors = 0;
397 	int i;
398 
399 	for (i = 0; i < ARRAY_SIZE(regset); i++)
400 		regset[i].setup(cfg, sve_vl, sme_vl, svcr);
401 
402 	do_syscall(sve_vl, sme_vl);
403 
404 	for (i = 0; i < ARRAY_SIZE(regset); i++)
405 		errors += regset[i].check(cfg, sve_vl, sme_vl, svcr);
406 
407 	return errors == 0;
408 }
409 
/*
 * Run every register-state configuration for one syscall and report
 * one kselftest result per configuration, in this order:
 *   - FPSIMD only (sve_vl 0, svcr 0)
 *   - for each SVE VL: SVE only, then for each SME VL: SM+ZA, SM, ZA
 *   - for each SME VL with sve_vl 0: SM+ZA, SM, ZA
 * The number of results must match the per-syscall count planned in
 * main().  If a vector length cannot be configured, the whole program
 * exits via ksft_exit_fail_msg().
 */
static void test_one_syscall(struct syscall_cfg *cfg)
{
	int sve, sme;
	int ret;

	/*
	 * FPSIMD only case.  default_sme_vl (0 without SME) is passed even
	 * though svcr is 0; per sme_count_vls() a non-zero SME VL is
	 * presumably what tells do_syscall() to handle SVCR.
	 */
	ksft_test_result(do_test(cfg, 0, default_sme_vl, 0),
			 "%s FPSIMD\n", cfg->name);

	for (sve = 0; sve < sve_vl_count; sve++) {
		ret = prctl(PR_SVE_SET_VL, sve_vls[sve]);
		if (ret == -1)
			ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
					   strerror(errno), errno);

		/* SVE without streaming mode or ZA */
		ksft_test_result(do_test(cfg, sve_vls[sve], default_sme_vl, 0),
				 "%s SVE VL %d\n", cfg->name, sve_vls[sve]);

		/* Every SME VL combined with this SVE VL */
		for (sme = 0; sme < sme_vl_count; sme++) {
			ret = prctl(PR_SME_SET_VL, sme_vls[sme]);
			if (ret == -1)
				ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
						   strerror(errno), errno);

			ksft_test_result(do_test(cfg, sve_vls[sve],
						 sme_vls[sme],
						 SVCR_ZA_MASK | SVCR_SM_MASK),
					 "%s SVE VL %d/SME VL %d SM+ZA\n",
					 cfg->name, sve_vls[sve],
					 sme_vls[sme]);
			ksft_test_result(do_test(cfg, sve_vls[sve],
						 sme_vls[sme], SVCR_SM_MASK),
					 "%s SVE VL %d/SME VL %d SM\n",
					 cfg->name, sve_vls[sve],
					 sme_vls[sme]);
			ksft_test_result(do_test(cfg, sve_vls[sve],
						 sme_vls[sme], SVCR_ZA_MASK),
					 "%s SVE VL %d/SME VL %d ZA\n",
					 cfg->name, sve_vls[sve],
					 sme_vls[sme]);
		}
	}

	/* SME with sve_vl 0, so the SVE-only checks are skipped */
	for (sme = 0; sme < sme_vl_count; sme++) {
		ret = prctl(PR_SME_SET_VL, sme_vls[sme]);
		if (ret == -1)
			ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
						   strerror(errno), errno);

		ksft_test_result(do_test(cfg, 0, sme_vls[sme],
					 SVCR_ZA_MASK | SVCR_SM_MASK),
				 "%s SME VL %d SM+ZA\n",
				 cfg->name, sme_vls[sme]);
		ksft_test_result(do_test(cfg, 0, sme_vls[sme], SVCR_SM_MASK),
				 "%s SME VL %d SM\n",
				 cfg->name, sme_vls[sme]);
		ksft_test_result(do_test(cfg, 0, sme_vls[sme], SVCR_ZA_MASK),
				 "%s SME VL %d ZA\n",
				 cfg->name, sme_vls[sme]);
	}
}
471 
sve_count_vls(void)472 void sve_count_vls(void)
473 {
474 	unsigned int vq;
475 	int vl;
476 
477 	if (!(getauxval(AT_HWCAP) & HWCAP_SVE))
478 		return;
479 
480 	/*
481 	 * Enumerate up to ARCH_SVE_VQ_MAX vector lengths
482 	 */
483 	for (vq = ARCH_SVE_VQ_MAX; vq > 0; vq /= 2) {
484 		vl = prctl(PR_SVE_SET_VL, vq * 16);
485 		if (vl == -1)
486 			ksft_exit_fail_msg("PR_SVE_SET_VL failed: %s (%d)\n",
487 					   strerror(errno), errno);
488 
489 		vl &= PR_SVE_VL_LEN_MASK;
490 
491 		if (vq != sve_vq_from_vl(vl))
492 			vq = sve_vq_from_vl(vl);
493 
494 		sve_vls[sve_vl_count++] = vl;
495 	}
496 }
497 
sme_count_vls(void)498 void sme_count_vls(void)
499 {
500 	unsigned int vq;
501 	int vl;
502 
503 	if (!(getauxval(AT_HWCAP2) & HWCAP2_SME))
504 		return;
505 
506 	/*
507 	 * Enumerate up to ARCH_SVE_VQ_MAX vector lengths
508 	 */
509 	for (vq = ARCH_SVE_VQ_MAX; vq > 0; vq /= 2) {
510 		vl = prctl(PR_SME_SET_VL, vq * 16);
511 		if (vl == -1)
512 			ksft_exit_fail_msg("PR_SME_SET_VL failed: %s (%d)\n",
513 					   strerror(errno), errno);
514 
515 		vl &= PR_SME_VL_LEN_MASK;
516 
517 		/* Found lowest VL */
518 		if (sve_vq_from_vl(vl) > vq)
519 			break;
520 
521 		if (vq != sve_vq_from_vl(vl))
522 			vq = sve_vq_from_vl(vl);
523 
524 		sme_vls[sme_vl_count++] = vl;
525 	}
526 
527 	/* Ensure we configure a SME VL, used to flag if SVCR is set */
528 	default_sme_vl = sme_vls[0];
529 }
530 
main(void)531 int main(void)
532 {
533 	int i;
534 	int tests = 1;  /* FPSIMD */
535 	int sme_ver;
536 
537 	srandom(getpid());
538 
539 	ksft_print_header();
540 
541 	sve_count_vls();
542 	sme_count_vls();
543 
544 	tests += sve_vl_count;
545 	tests += sme_vl_count * 3;
546 	tests += (sve_vl_count * sme_vl_count) * 3;
547 	ksft_set_plan(ARRAY_SIZE(syscalls) * tests);
548 
549 	if (getauxval(AT_HWCAP2) & HWCAP2_SME2)
550 		sme_ver = 2;
551 	else
552 		sme_ver = 1;
553 
554 	if (getauxval(AT_HWCAP2) & HWCAP2_SME_FA64)
555 		ksft_print_msg("SME%d with FA64\n", sme_ver);
556 	else if (getauxval(AT_HWCAP2) & HWCAP2_SME)
557 		ksft_print_msg("SME%d without FA64\n", sme_ver);
558 
559 	for (i = 0; i < ARRAY_SIZE(syscalls); i++)
560 		test_one_syscall(&syscalls[i]);
561 
562 	ksft_print_cnts();
563 
564 	return 0;
565 }
566