1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (C) 2021 ARM Limited.
4  */
5 #include <errno.h>
6 #include <stdbool.h>
7 #include <stddef.h>
8 #include <stdio.h>
9 #include <stdlib.h>
10 #include <string.h>
11 #include <unistd.h>
12 #include <sys/auxv.h>
13 #include <sys/prctl.h>
14 #include <sys/ptrace.h>
15 #include <sys/types.h>
16 #include <sys/uio.h>
17 #include <sys/wait.h>
18 #include <asm/sigcontext.h>
19 #include <asm/ptrace.h>
20 
21 #include "../../kselftest.h"
22 
/*
 * <linux/elf.h> and <sys/auxv.h> don't like each other, so define the
 * regset note types used with PTRACE_GETREGSET/PTRACE_SETREGSET here
 * when no other header has already provided them.
 */
#ifndef NT_ARM_ZA
#define NT_ARM_ZA 0x40c
#endif
#ifndef NT_ARM_ZT
#define NT_ARM_ZT 0x40d
#endif

/* Number of ksft results do_parent() reports; passed to ksft_set_plan() */
#define EXPECTED_TESTS 3

/* SME vector length, read once in main() via prctl(PR_SME_GET_VL) */
static int sme_vl;
34 
/*
 * Fill @buf with @size bytes of pseudo-random data taken from random().
 * The byte sequence is reproducible for a given srandom() seed.
 */
static void fill_buf(char *buf, size_t size)
{
	/*
	 * Index with size_t so it matches @size: an int would be compared
	 * signed-vs-unsigned and could overflow for very large buffers.
	 */
	size_t i;

	for (i = 0; i < size; i++)
		buf[i] = random();
}
42 
/*
 * Child side of the test: request tracing by the parent, then stop with
 * SIGSTOP so the parent can recognise the stop and run its checks.
 *
 * Returns EXIT_SUCCESS if execution continues after the stop; on failure
 * of either step it exits through ksft_exit_fail_msg() with the reason.
 */
static int do_child(void)
{
	if (ptrace(PTRACE_TRACEME, -1, NULL, NULL))
		ksft_exit_fail_msg("PTRACE_TRACEME: %s\n", strerror(errno));

	if (raise(SIGSTOP))
		ksft_exit_fail_msg("raise(SIGSTOP): %s\n", strerror(errno));

	return EXIT_SUCCESS;
}
53 
54 static struct user_za_header *get_za(pid_t pid, void **buf, size_t *size)
55 {
56 	struct user_za_header *za;
57 	void *p;
58 	size_t sz = sizeof(*za);
59 	struct iovec iov;
60 
61 	while (1) {
62 		if (*size < sz) {
63 			p = realloc(*buf, sz);
64 			if (!p) {
65 				errno = ENOMEM;
66 				goto error;
67 			}
68 
69 			*buf = p;
70 			*size = sz;
71 		}
72 
73 		iov.iov_base = *buf;
74 		iov.iov_len = sz;
75 		if (ptrace(PTRACE_GETREGSET, pid, NT_ARM_ZA, &iov))
76 			goto error;
77 
78 		za = *buf;
79 		if (za->size <= sz)
80 			break;
81 
82 		sz = za->size;
83 	}
84 
85 	return za;
86 
87 error:
88 	return NULL;
89 }
90 
91 static int set_za(pid_t pid, const struct user_za_header *za)
92 {
93 	struct iovec iov;
94 
95 	iov.iov_base = (void *)za;
96 	iov.iov_len = za->size;
97 	return ptrace(PTRACE_SETREGSET, pid, NT_ARM_ZA, &iov);
98 }
99 
100 static int get_zt(pid_t pid, char zt[ZT_SIG_REG_BYTES])
101 {
102 	struct iovec iov;
103 
104 	iov.iov_base = zt;
105 	iov.iov_len = ZT_SIG_REG_BYTES;
106 	return ptrace(PTRACE_GETREGSET, pid, NT_ARM_ZT, &iov);
107 }
108 
109 
110 static int set_zt(pid_t pid, const char zt[ZT_SIG_REG_BYTES])
111 {
112 	struct iovec iov;
113 
114 	iov.iov_base = (void *)zt;
115 	iov.iov_len = ZT_SIG_REG_BYTES;
116 	return ptrace(PTRACE_SETREGSET, pid, NT_ARM_ZT, &iov);
117 }
118 
119 /* Reading with ZA disabled returns all zeros */
120 static void ptrace_za_disabled_read_zt(pid_t child)
121 {
122 	struct user_za_header za;
123 	char zt[ZT_SIG_REG_BYTES];
124 	int ret, i;
125 	bool fail = false;
126 
127 	/* Disable PSTATE.ZA using the ZA interface */
128 	memset(&za, 0, sizeof(za));
129 	za.vl = sme_vl;
130 	za.size = sizeof(za);
131 
132 	ret = set_za(child, &za);
133 	if (ret != 0) {
134 		ksft_print_msg("Failed to disable ZA\n");
135 		fail = true;
136 	}
137 
138 	/* Read back ZT */
139 	ret = get_zt(child, zt);
140 	if (ret != 0) {
141 		ksft_print_msg("Failed to read ZT\n");
142 		fail = true;
143 	}
144 
145 	for (i = 0; i < ARRAY_SIZE(zt); i++) {
146 		if (zt[i]) {
147 			ksft_print_msg("zt[%d]: 0x%x != 0\n", i, zt[i]);
148 			fail = true;
149 		}
150 	}
151 
152 	ksft_test_result(!fail, "ptrace_za_disabled_read_zt\n");
153 }
154 
155 /* Writing then reading ZT should return the data written */
156 static void ptrace_set_get_zt(pid_t child)
157 {
158 	char zt_in[ZT_SIG_REG_BYTES];
159 	char zt_out[ZT_SIG_REG_BYTES];
160 	int ret, i;
161 	bool fail = false;
162 
163 	fill_buf(zt_in, sizeof(zt_in));
164 
165 	ret = set_zt(child, zt_in);
166 	if (ret != 0) {
167 		ksft_print_msg("Failed to set ZT\n");
168 		fail = true;
169 	}
170 
171 	ret = get_zt(child, zt_out);
172 	if (ret != 0) {
173 		ksft_print_msg("Failed to read ZT\n");
174 		fail = true;
175 	}
176 
177 	for (i = 0; i < ARRAY_SIZE(zt_in); i++) {
178 		if (zt_in[i] != zt_out[i]) {
179 			ksft_print_msg("zt[%d]: 0x%x != 0x%x\n", i,
180 				       zt_in[i], zt_out[i]);
181 			fail = true;
182 		}
183 	}
184 
185 	ksft_test_result(!fail, "ptrace_set_get_zt\n");
186 }
187 
188 /* Writing ZT should set PSTATE.ZA */
189 static void ptrace_enable_za_via_zt(pid_t child)
190 {
191 	struct user_za_header za_in;
192 	struct user_za_header *za_out;
193 	char zt[ZT_SIG_REG_BYTES];
194 	char *za_data;
195 	size_t za_out_size;
196 	int ret, i, vq;
197 	bool fail = false;
198 
199 	/* Disable PSTATE.ZA using the ZA interface */
200 	memset(&za_in, 0, sizeof(za_in));
201 	za_in.vl = sme_vl;
202 	za_in.size = sizeof(za_in);
203 
204 	ret = set_za(child, &za_in);
205 	if (ret != 0) {
206 		ksft_print_msg("Failed to disable ZA\n");
207 		fail = true;
208 	}
209 
210 	/* Write ZT */
211 	fill_buf(zt, sizeof(zt));
212 	ret = set_zt(child, zt);
213 	if (ret != 0) {
214 		ksft_print_msg("Failed to set ZT\n");
215 		fail = true;
216 	}
217 
218 	/* Read back ZA and check for register data */
219 	za_out = NULL;
220 	za_out_size = 0;
221 	if (get_za(child, (void **)&za_out, &za_out_size)) {
222 		/* Should have an unchanged VL */
223 		if (za_out->vl != sme_vl) {
224 			ksft_print_msg("VL changed from %d to %d\n",
225 				       sme_vl, za_out->vl);
226 			fail = true;
227 		}
228 		vq = __sve_vq_from_vl(za_out->vl);
229 		za_data = (char *)za_out + ZA_PT_ZA_OFFSET;
230 
231 		/* Should have register data */
232 		if (za_out->size < ZA_PT_SIZE(vq)) {
233 			ksft_print_msg("ZA data less than expected: %u < %u\n",
234 				       za_out->size, ZA_PT_SIZE(vq));
235 			fail = true;
236 			vq = 0;
237 		}
238 
239 		/* That register data should be non-zero */
240 		for (i = 0; i < ZA_PT_ZA_SIZE(vq); i++) {
241 			if (za_data[i]) {
242 				ksft_print_msg("ZA byte %d is %x\n",
243 					       i, za_data[i]);
244 				fail = true;
245 			}
246 		}
247 	} else {
248 		ksft_print_msg("Failed to read ZA\n");
249 		fail = true;
250 	}
251 
252 	ksft_test_result(!fail, "ptrace_enable_za_via_zt\n");
253 }
254 
/*
 * Parent side of the test: wait for the traced child to report the
 * SIGSTOP it raised on itself, then run the ZT ptrace tests against it.
 *
 * Other stops seen while waiting are resumed with PTRACE_CONT, forwarding
 * the stop signal. On the normal and "error" paths the child is sent
 * SIGKILL; paths where ptrace() reported ESRCH (child gone) skip the kill.
 *
 * Returns EXIT_SUCCESS once the tests have run, EXIT_FAILURE if the
 * child could not be brought to its initial stop. Individual test
 * outcomes are reported through ksft_test_result() and do not affect the
 * return value.
 */
static int do_parent(pid_t child)
{
	int ret = EXIT_FAILURE;
	pid_t pid;
	int status;
	siginfo_t si;

	/* Attach to the child */
	while (1) {
		int sig;

		pid = wait(&status);
		if (pid == -1) {
			perror("wait");
			goto error;
		}

		/*
		 * This should never happen but it's hard to flag in
		 * the framework.
		 */
		if (pid != child)
			continue;

		if (WIFEXITED(status) || WIFSIGNALED(status))
			ksft_exit_fail_msg("Child died unexpectedly\n");

		if (!WIFSTOPPED(status))
			goto error;

		sig = WSTOPSIG(status);

		if (ptrace(PTRACE_GETSIGINFO, pid, NULL, &si)) {
			if (errno == ESRCH)
				goto disappeared;

			/*
			 * EINVAL here means a group-stop rather than a
			 * signal-delivery stop: resume without injecting
			 * a signal.
			 */
			if (errno == EINVAL) {
				sig = 0; /* bust group-stop */
				goto cont;
			}

			ksft_test_result_fail("PTRACE_GETSIGINFO: %s\n",
					      strerror(errno));
			goto error;
		}

		/* The SIGSTOP the child raised on itself: it is ready */
		if (sig == SIGSTOP && si.si_code == SI_TKILL &&
		    si.si_pid == pid)
			break;

	cont:
		/* Not the expected stop: resume the child, forwarding sig */
		if (ptrace(PTRACE_CONT, pid, NULL, sig)) {
			if (errno == ESRCH)
				goto disappeared;

			ksft_test_result_fail("PTRACE_CONT: %s\n",
					      strerror(errno));
			goto error;
		}
	}

	ksft_print_msg("Parent is %d, child is %d\n", getpid(), child);

	ptrace_za_disabled_read_zt(child);
	ptrace_set_get_zt(child);
	ptrace_enable_za_via_zt(child);

	ret = EXIT_SUCCESS;

error:
	kill(child, SIGKILL);

disappeared:
	return ret;
}
330 
331 int main(void)
332 {
333 	int ret = EXIT_SUCCESS;
334 	pid_t child;
335 
336 	srandom(getpid());
337 
338 	ksft_print_header();
339 
340 	if (!(getauxval(AT_HWCAP2) & HWCAP2_SME2)) {
341 		ksft_set_plan(1);
342 		ksft_exit_skip("SME2 not available\n");
343 	}
344 
345 	/* We need a valid SME VL to enable/disable ZA */
346 	sme_vl = prctl(PR_SME_GET_VL);
347 	if (sme_vl == -1) {
348 		ksft_set_plan(1);
349 		ksft_exit_skip("Failed to read SME VL: %d (%s)\n",
350 			       errno, strerror(errno));
351 	}
352 
353 	ksft_set_plan(EXPECTED_TESTS);
354 
355 	child = fork();
356 	if (!child)
357 		return do_child();
358 
359 	if (do_parent(child))
360 		ret = EXIT_FAILURE;
361 
362 	ksft_print_cnts();
363 
364 	return ret;
365 }
366