1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * Copyright (C) 2022 ARM Limited.
4  */
5 #include <errno.h>
6 #include <stdbool.h>
7 #include <stddef.h>
8 #include <stdio.h>
9 #include <stdlib.h>
10 #include <string.h>
11 #include <unistd.h>
12 #include <sys/auxv.h>
13 #include <sys/prctl.h>
14 #include <sys/ptrace.h>
15 #include <sys/types.h>
16 #include <sys/uio.h>
17 #include <sys/wait.h>
18 #include <asm/sigcontext.h>
19 #include <asm/ptrace.h>
20 
21 #include "../../kselftest.h"
22 
enum {
	/* ksft results reported by test_tpidr(), whether run or skipped */
	EXPECTED_TESTS = 7,
	/* Registers probed in the NT_ARM_TLS regset: TPIDR, then TPIDR2 */
	MAX_TPIDRS = 2,
};
26 
27 static bool have_sme(void)
28 {
29 	return getauxval(AT_HWCAP2) & HWCAP2_SME;
30 }
31 
32 static void test_tpidr(pid_t child)
33 {
34 	uint64_t read_val[MAX_TPIDRS];
35 	uint64_t write_val[MAX_TPIDRS];
36 	struct iovec read_iov, write_iov;
37 	bool test_tpidr2 = false;
38 	int ret, i;
39 
40 	read_iov.iov_base = read_val;
41 	write_iov.iov_base = write_val;
42 
43 	/* Should be able to read a single TPIDR... */
44 	read_iov.iov_len = sizeof(uint64_t);
45 	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_TLS, &read_iov);
46 	ksft_test_result(ret == 0, "read_tpidr_one\n");
47 
48 	/* ...write a new value.. */
49 	write_iov.iov_len = sizeof(uint64_t);
50 	write_val[0] = read_val[0]++;
51 	ret = ptrace(PTRACE_SETREGSET, child, NT_ARM_TLS, &write_iov);
52 	ksft_test_result(ret == 0, "write_tpidr_one\n");
53 
54 	/* ...then read it back */
55 	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_TLS, &read_iov);
56 	ksft_test_result(ret == 0 && write_val[0] == read_val[0],
57 			 "verify_tpidr_one\n");
58 
59 	/* If we have TPIDR2 we should be able to read it */
60 	read_iov.iov_len = sizeof(read_val);
61 	ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_TLS, &read_iov);
62 	if (ret == 0) {
63 		/* If we have SME there should be two TPIDRs */
64 		if (read_iov.iov_len >= sizeof(read_val))
65 			test_tpidr2 = true;
66 
67 		if (have_sme() && test_tpidr2) {
68 			ksft_test_result(test_tpidr2, "count_tpidrs\n");
69 		} else {
70 			ksft_test_result(read_iov.iov_len % sizeof(uint64_t) == 0,
71 					 "count_tpidrs\n");
72 		}
73 	} else {
74 		ksft_test_result_fail("count_tpidrs\n");
75 	}
76 
77 	if (test_tpidr2) {
78 		/* Try to write new values to all known TPIDRs... */
79 		write_iov.iov_len = sizeof(write_val);
80 		for (i = 0; i < MAX_TPIDRS; i++)
81 			write_val[i] = read_val[i] + 1;
82 		ret = ptrace(PTRACE_SETREGSET, child, NT_ARM_TLS, &write_iov);
83 
84 		ksft_test_result(ret == 0 &&
85 				 write_iov.iov_len == sizeof(write_val),
86 				 "tpidr2_write\n");
87 
88 		/* ...then read them back */
89 		read_iov.iov_len = sizeof(read_val);
90 		ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_TLS, &read_iov);
91 
92 		if (have_sme()) {
93 			/* Should read back the written value */
94 			ksft_test_result(ret == 0 &&
95 					 read_iov.iov_len >= sizeof(read_val) &&
96 					 memcmp(read_val, write_val,
97 						sizeof(read_val)) == 0,
98 					 "tpidr2_read\n");
99 		} else {
100 			/* TPIDR2 should read as zero */
101 			ksft_test_result(ret == 0 &&
102 					 read_iov.iov_len >= sizeof(read_val) &&
103 					 read_val[0] == write_val[0] &&
104 					 read_val[1] == 0,
105 					 "tpidr2_read\n");
106 		}
107 
108 		/* Writing only TPIDR... */
109 		write_iov.iov_len = sizeof(uint64_t);
110 		memcpy(write_val, read_val, sizeof(read_val));
111 		write_val[0] += 1;
112 		ret = ptrace(PTRACE_SETREGSET, child, NT_ARM_TLS, &write_iov);
113 
114 		if (ret == 0) {
115 			/* ...should leave TPIDR2 untouched */
116 			read_iov.iov_len = sizeof(read_val);
117 			ret = ptrace(PTRACE_GETREGSET, child, NT_ARM_TLS,
118 				     &read_iov);
119 
120 			ksft_test_result(ret == 0 &&
121 					 read_iov.iov_len >= sizeof(read_val) &&
122 					 memcmp(read_val, write_val,
123 						sizeof(read_val)) == 0,
124 					 "write_tpidr_only\n");
125 		} else {
126 			ksft_test_result_fail("write_tpidr_only\n");
127 		}
128 	} else {
129 		ksft_test_result_skip("tpidr2_write\n");
130 		ksft_test_result_skip("tpidr2_read\n");
131 		ksft_test_result_skip("write_tpidr_only\n");
132 	}
133 }
134 
/*
 * Tracee side of the test.
 *
 * Asks to be traced by our parent, then stops itself with SIGSTOP so
 * the parent can inspect its registers. Returns EXIT_SUCCESS if it is
 * ever resumed. On failure it exits via ksft_exit_fail_msg(), reporting
 * the errno text.
 */
static int do_child(void)
{
	if (ptrace(PTRACE_TRACEME, -1, NULL, NULL))
		ksft_exit_fail_msg("PTRACE_TRACEME: %s\n", strerror(errno));

	if (raise(SIGSTOP))
		ksft_exit_fail_msg("raise(SIGSTOP): %s\n", strerror(errno));

	return EXIT_SUCCESS;
}
145 
/*
 * Tracer side of the test.
 *
 * Waits until @child reports the SIGSTOP it raised on itself. Any other
 * stop is resumed, re-injecting its stop signal, or no signal for a
 * group-stop. Once the child is stopped, runs test_tpidr() against it.
 *
 * @child: pid returned by fork() for the tracee.
 *
 * Returns EXIT_SUCCESS once test_tpidr() has run. Returns EXIT_FAILURE
 * if wait(), PTRACE_GETSIGINFO or PTRACE_CONT failed first, or if an
 * unexpected wait status was seen. The child is sent SIGKILL before
 * returning, unless a ptrace call reported ESRCH. If the child exits or
 * is killed before stopping, the whole program exits via
 * ksft_exit_fail_msg().
 *
 * NOTE(review): @child must be a real pid (> 0). The error path calls
 * kill(child, SIGKILL) unconditionally, and kill(-1, ...) would signal
 * every process the caller is allowed to signal.
 */
146 static int do_parent(pid_t child)
147 {
148 	int ret = EXIT_FAILURE;
149 	pid_t pid;
150 	int status;
151 	siginfo_t si;
152 
153 	/* Attach to the child: loop over its stops until its raise(SIGSTOP) shows up */
154 	while (1) {
155 		int sig;
156 
157 		pid = wait(&status);
158 		if (pid == -1) {
159 			perror("wait");
160 			goto error;
161 		}
162 
163 		/*
164 		 * This should never happen but it's hard to flag in
165 		 * the framework.
166 		 */
167 		if (pid != child)
168 			continue;
169 
170 		if (WIFEXITED(status) || WIFSIGNALED(status))
171 			ksft_exit_fail_msg("Child died unexpectedly\n");
172 
173 		if (!WIFSTOPPED(status))
174 			goto error;
175 
176 		sig = WSTOPSIG(status);
177 
178 		if (ptrace(PTRACE_GETSIGINFO, pid, NULL, &si)) {
			/* ESRCH: treat the child as gone and return without killing it */
179 			if (errno == ESRCH)
180 				goto disappeared;
181 
			/*
			 * PTRACE_GETSIGINFO fails with EINVAL for a group-stop
			 * rather than a signal-delivery-stop (see ptrace(2)).
			 * Resume such a stop without injecting a signal.
			 */
182 			if (errno == EINVAL) {
183 				sig = 0; /* bust group-stop */
184 				goto cont;
185 			}
186 
187 			ksft_test_result_fail("PTRACE_GETSIGINFO: %s\n",
188 					      strerror(errno));
189 			goto error;
190 		}
191 
		/*
		 * This is the stop we are waiting for: the child's own
		 * raise(SIGSTOP), expected to carry SI_TKILL and a sender
		 * pid equal to the child itself.
		 */
192 		if (sig == SIGSTOP && si.si_code == SI_TKILL &&
193 		    si.si_pid == pid)
194 			break;
195 
196 	cont:
		/* Any other stop: resume the child, re-injecting sig (0 for a group-stop) */
197 		if (ptrace(PTRACE_CONT, pid, NULL, sig)) {
198 			if (errno == ESRCH)
199 				goto disappeared;
200 
201 			ksft_test_result_fail("PTRACE_CONT: %s\n",
202 					      strerror(errno));
203 			goto error;
204 		}
205 	}
206 
207 	ksft_print_msg("Parent is %d, child is %d\n", getpid(), child);
208 
209 	test_tpidr(child);
210 
211 	ret = EXIT_SUCCESS;
212 
	/* Normal completion falls through here too, so the child is always killed */
213 error:
214 	kill(child, SIGKILL);
215 
	/* Jumped to directly when ptrace reported ESRCH for the child */
216 disappeared:
217 	return ret;
218 }
219 
220 int main(void)
221 {
222 	int ret = EXIT_SUCCESS;
223 	pid_t child;
224 
225 	srandom(getpid());
226 
227 	ksft_print_header();
228 
229 	ksft_set_plan(EXPECTED_TESTS);
230 
231 	child = fork();
232 	if (!child)
233 		return do_child();
234 
235 	if (do_parent(child))
236 		ret = EXIT_FAILURE;
237 
238 	ksft_print_cnts();
239 
240 	return ret;
241 }
242