1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3 * Copyright (C) 2021 ARM Limited.
4 */
5 #include <errno.h>
6 #include <stdbool.h>
7 #include <stddef.h>
8 #include <stdio.h>
9 #include <stdlib.h>
10 #include <string.h>
11 #include <unistd.h>
12 #include <sys/auxv.h>
13 #include <sys/prctl.h>
14 #include <sys/ptrace.h>
15 #include <sys/types.h>
16 #include <sys/uio.h>
17 #include <sys/wait.h>
18 #include <asm/sigcontext.h>
19 #include <asm/ptrace.h>
20
21 #include "../../kselftest.h"
22
/*
 * <linux/elf.h> and <sys/auxv.h> don't like each other, so define the
 * regset note types used with PTRACE_{GET,SET}REGSET here if the system
 * headers did not already provide them.
 */
#ifndef NT_ARM_ZA
#define NT_ARM_ZA 0x40c
#endif
#ifndef NT_ARM_ZT
#define NT_ARM_ZT 0x40d
#endif

/* Number of ksft_test_result() calls made by the tests run in do_parent() */
#define EXPECTED_TESTS 3

/*
 * SME vector length obtained via prctl(PR_SME_GET_VL) in main(); used as
 * the vl field of the ZA headers written to the child.
 */
static int sme_vl;
34
/*
 * Fill @buf with @size bytes of pseudo-random data, one random() result
 * (truncated to a byte) per element. The sequence is reproducible for a
 * given srandom() seed.
 *
 * The index is a size_t so it matches @size and avoids a signed/unsigned
 * comparison.
 */
static void fill_buf(char *buf, size_t size)
{
	size_t i;

	for (i = 0; i < size; i++)
		buf[i] = random();
}
42
/*
 * Child side of the test: ask to be traced by our parent, then stop
 * ourselves with SIGSTOP so the parent can find us in a known stopped
 * state and exercise the ZA/ZT regsets against us.
 *
 * Any failure here is fatal for the whole test run. The format strings
 * include "%s" so the strerror() text is actually printed, and they end
 * with a newline.
 */
static int do_child(void)
{
	if (ptrace(PTRACE_TRACEME, -1, NULL, NULL))
		ksft_exit_fail_msg("PTRACE_TRACEME: %s\n", strerror(errno));

	if (raise(SIGSTOP))
		ksft_exit_fail_msg("raise(SIGSTOP): %s\n", strerror(errno));

	return EXIT_SUCCESS;
}
53
get_za(pid_t pid,void ** buf,size_t * size)54 static struct user_za_header *get_za(pid_t pid, void **buf, size_t *size)
55 {
56 struct user_za_header *za;
57 void *p;
58 size_t sz = sizeof(*za);
59 struct iovec iov;
60
61 while (1) {
62 if (*size < sz) {
63 p = realloc(*buf, sz);
64 if (!p) {
65 errno = ENOMEM;
66 goto error;
67 }
68
69 *buf = p;
70 *size = sz;
71 }
72
73 iov.iov_base = *buf;
74 iov.iov_len = sz;
75 if (ptrace(PTRACE_GETREGSET, pid, NT_ARM_ZA, &iov))
76 goto error;
77
78 za = *buf;
79 if (za->size <= sz)
80 break;
81
82 sz = za->size;
83 }
84
85 return za;
86
87 error:
88 return NULL;
89 }
90
set_za(pid_t pid,const struct user_za_header * za)91 static int set_za(pid_t pid, const struct user_za_header *za)
92 {
93 struct iovec iov;
94
95 iov.iov_base = (void *)za;
96 iov.iov_len = za->size;
97 return ptrace(PTRACE_SETREGSET, pid, NT_ARM_ZA, &iov);
98 }
99
get_zt(pid_t pid,char zt[ZT_SIG_REG_BYTES])100 static int get_zt(pid_t pid, char zt[ZT_SIG_REG_BYTES])
101 {
102 struct iovec iov;
103
104 iov.iov_base = zt;
105 iov.iov_len = ZT_SIG_REG_BYTES;
106 return ptrace(PTRACE_GETREGSET, pid, NT_ARM_ZT, &iov);
107 }
108
109
set_zt(pid_t pid,const char zt[ZT_SIG_REG_BYTES])110 static int set_zt(pid_t pid, const char zt[ZT_SIG_REG_BYTES])
111 {
112 struct iovec iov;
113
114 iov.iov_base = (void *)zt;
115 iov.iov_len = ZT_SIG_REG_BYTES;
116 return ptrace(PTRACE_SETREGSET, pid, NT_ARM_ZT, &iov);
117 }
118
119 /* Reading with ZA disabled returns all zeros */
ptrace_za_disabled_read_zt(pid_t child)120 static void ptrace_za_disabled_read_zt(pid_t child)
121 {
122 struct user_za_header za;
123 char zt[ZT_SIG_REG_BYTES];
124 int ret, i;
125 bool fail = false;
126
127 /* Disable PSTATE.ZA using the ZA interface */
128 memset(&za, 0, sizeof(za));
129 za.vl = sme_vl;
130 za.size = sizeof(za);
131
132 ret = set_za(child, &za);
133 if (ret != 0) {
134 ksft_print_msg("Failed to disable ZA\n");
135 fail = true;
136 }
137
138 /* Read back ZT */
139 ret = get_zt(child, zt);
140 if (ret != 0) {
141 ksft_print_msg("Failed to read ZT\n");
142 fail = true;
143 }
144
145 for (i = 0; i < ARRAY_SIZE(zt); i++) {
146 if (zt[i]) {
147 ksft_print_msg("zt[%d]: 0x%x != 0\n", i, zt[i]);
148 fail = true;
149 }
150 }
151
152 ksft_test_result(!fail, "ptrace_za_disabled_read_zt\n");
153 }
154
155 /* Writing then reading ZT should return the data written */
ptrace_set_get_zt(pid_t child)156 static void ptrace_set_get_zt(pid_t child)
157 {
158 char zt_in[ZT_SIG_REG_BYTES];
159 char zt_out[ZT_SIG_REG_BYTES];
160 int ret, i;
161 bool fail = false;
162
163 fill_buf(zt_in, sizeof(zt_in));
164
165 ret = set_zt(child, zt_in);
166 if (ret != 0) {
167 ksft_print_msg("Failed to set ZT\n");
168 fail = true;
169 }
170
171 ret = get_zt(child, zt_out);
172 if (ret != 0) {
173 ksft_print_msg("Failed to read ZT\n");
174 fail = true;
175 }
176
177 for (i = 0; i < ARRAY_SIZE(zt_in); i++) {
178 if (zt_in[i] != zt_out[i]) {
179 ksft_print_msg("zt[%d]: 0x%x != 0x%x\n", i,
180 zt_in[i], zt_out[i]);
181 fail = true;
182 }
183 }
184
185 ksft_test_result(!fail, "ptrace_set_get_zt\n");
186 }
187
188 /* Writing ZT should set PSTATE.ZA */
ptrace_enable_za_via_zt(pid_t child)189 static void ptrace_enable_za_via_zt(pid_t child)
190 {
191 struct user_za_header za_in;
192 struct user_za_header *za_out;
193 char zt[ZT_SIG_REG_BYTES];
194 char *za_data;
195 size_t za_out_size;
196 int ret, i, vq;
197 bool fail = false;
198
199 /* Disable PSTATE.ZA using the ZA interface */
200 memset(&za_in, 0, sizeof(za_in));
201 za_in.vl = sme_vl;
202 za_in.size = sizeof(za_in);
203
204 ret = set_za(child, &za_in);
205 if (ret != 0) {
206 ksft_print_msg("Failed to disable ZA\n");
207 fail = true;
208 }
209
210 /* Write ZT */
211 fill_buf(zt, sizeof(zt));
212 ret = set_zt(child, zt);
213 if (ret != 0) {
214 ksft_print_msg("Failed to set ZT\n");
215 fail = true;
216 }
217
218 /* Read back ZA and check for register data */
219 za_out = NULL;
220 za_out_size = 0;
221 if (get_za(child, (void **)&za_out, &za_out_size)) {
222 /* Should have an unchanged VL */
223 if (za_out->vl != sme_vl) {
224 ksft_print_msg("VL changed from %d to %d\n",
225 sme_vl, za_out->vl);
226 fail = true;
227 }
228 vq = __sve_vq_from_vl(za_out->vl);
229 za_data = (char *)za_out + ZA_PT_ZA_OFFSET;
230
231 /* Should have register data */
232 if (za_out->size < ZA_PT_SIZE(vq)) {
233 ksft_print_msg("ZA data less than expected: %u < %u\n",
234 za_out->size, ZA_PT_SIZE(vq));
235 fail = true;
236 vq = 0;
237 }
238
239 /* That register data should be non-zero */
240 for (i = 0; i < ZA_PT_ZA_SIZE(vq); i++) {
241 if (za_data[i]) {
242 ksft_print_msg("ZA byte %d is %x\n",
243 i, za_data[i]);
244 fail = true;
245 }
246 }
247 } else {
248 ksft_print_msg("Failed to read ZA\n");
249 fail = true;
250 }
251
252 ksft_test_result(!fail, "ptrace_enable_za_via_zt\n");
253 }
254
/*
 * Parent side of the test: wait until @child is stopped by its own
 * SIGSTOP, run the ZT ptrace tests against it, then kill it.
 *
 * Stops other than that SIGSTOP are resumed with PTRACE_CONT, re-injecting
 * the stop signal. For a group-stop, where PTRACE_GETSIGINFO fails with
 * EINVAL, the child is resumed with signal 0 instead.
 *
 * Returns EXIT_SUCCESS once the tests have been run; individual results
 * are reported through ksft_test_result(). Returns EXIT_FAILURE if the
 * attach loop fails. On both of those paths the child is sent SIGKILL.
 * If ptrace reports ESRCH (child gone), we return without the kill.
 * If the child exits or is killed while we wait, the whole run aborts via
 * ksft_exit_fail_msg().
 */
static int do_parent(pid_t child)
{
	int ret = EXIT_FAILURE;
	pid_t pid;
	int status;
	siginfo_t si;

	/* Attach to the child */
	while (1) {
		int sig;

		pid = wait(&status);
		if (pid == -1) {
			perror("wait");
			goto error;
		}

		/*
		 * This should never happen but it's hard to flag in
		 * the framework.
		 */
		if (pid != child)
			continue;

		if (WIFEXITED(status) || WIFSIGNALED(status))
			ksft_exit_fail_msg("Child died unexpectedly\n");

		if (!WIFSTOPPED(status))
			goto error;

		sig = WSTOPSIG(status);

		if (ptrace(PTRACE_GETSIGINFO, pid, NULL, &si)) {
			if (errno == ESRCH)
				goto disappeared;

			if (errno == EINVAL) {
				sig = 0; /* bust group-stop */
				goto cont;
			}

			ksft_test_result_fail("PTRACE_GETSIGINFO: %s\n",
					      strerror(errno));
			goto error;
		}

		/*
		 * A SIGSTOP the child sent to itself (si_pid == child,
		 * SI_TKILL), presumably the raise() in do_child(): the
		 * child is now parked and ready for the tests.
		 */
		if (sig == SIGSTOP && si.si_code == SI_TKILL &&
		    si.si_pid == pid)
			break;

	cont:
		/* Not the stop we are waiting for: let the child continue */
		if (ptrace(PTRACE_CONT, pid, NULL, sig)) {
			if (errno == ESRCH)
				goto disappeared;

			ksft_test_result_fail("PTRACE_CONT: %s\n",
					      strerror(errno));
			goto error;
		}
	}

	ksft_print_msg("Parent is %d, child is %d\n", getpid(), child);

	ptrace_za_disabled_read_zt(child);
	ptrace_set_get_zt(child);
	ptrace_enable_za_via_zt(child);

	ret = EXIT_SUCCESS;

	/* Success falls through here too: the child is always killed */
error:
	kill(child, SIGKILL);

disappeared:
	return ret;
}
330
main(void)331 int main(void)
332 {
333 int ret = EXIT_SUCCESS;
334 pid_t child;
335
336 srandom(getpid());
337
338 ksft_print_header();
339
340 if (!(getauxval(AT_HWCAP2) & HWCAP2_SME2)) {
341 ksft_set_plan(1);
342 ksft_exit_skip("SME2 not available\n");
343 }
344
345 /* We need a valid SME VL to enable/disable ZA */
346 sme_vl = prctl(PR_SME_GET_VL);
347 if (sme_vl == -1) {
348 ksft_set_plan(1);
349 ksft_exit_skip("Failed to read SME VL: %d (%s)\n",
350 errno, strerror(errno));
351 }
352
353 ksft_set_plan(EXPECTED_TESTS);
354
355 child = fork();
356 if (!child)
357 return do_child();
358
359 if (do_parent(child))
360 ret = EXIT_FAILURE;
361
362 ksft_print_cnts();
363
364 return ret;
365 }
366