// SPDX-License-Identifier: GPL-2.0
// Copyright (C) 2020 ARM Limited

#include <fcntl.h>
#include <sched.h>
#include <signal.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>

#include <linux/auxvec.h>
#include <sys/auxv.h>
#include <sys/mman.h>
#include <sys/prctl.h>

#include <asm/hwcap.h>

#include "kselftest.h"
#include "mte_common_util.h"
#include "mte_def.h"

#define INIT_BUFFER_SIZE	256

struct mte_fault_cxt cur_mte_cxt;
static unsigned int mte_cur_mode;
static unsigned int mte_cur_pstate_tco;

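/*
 * Default signal handler used by the tests. It records whether the fault
 * that was actually delivered matches the one described in cur_mte_cxt
 * (si_code and a faulting address within [trig_addr, trig_addr + trig_range]).
 * For precise (synchronous) faults and SIGBUS the faulting instruction is
 * skipped by advancing the PC past it so that the test can carry on.
 */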
void mte_default_handler(int signum, siginfo_t *si, void *uc)
{
	unsigned long addr = (unsigned long)si->si_addr;

	if (signum == SIGSEGV) {
#ifdef DEBUG
		ksft_print_msg("INFO: SIGSEGV signal at pc=%lx, fault addr=%lx, si_code=%x\n",
				((ucontext_t *)uc)->uc_mcontext.pc, addr, si->si_code);
#endif
		if (si->si_code == SEGV_MTEAERR) {
			if (cur_mte_cxt.trig_si_code == si->si_code)
				cur_mte_cxt.fault_valid = true;
			else
				ksft_print_msg("Got unexpected SEGV_MTEAERR at pc=%lx, fault addr=%lx\n",
					       ((ucontext_t *)uc)->uc_mcontext.pc,
					       addr);
			return;
		}
		/* Compare the context for precise error */
		else if (si->si_code == SEGV_MTESERR) {
			if (cur_mte_cxt.trig_si_code == si->si_code &&
			    ((cur_mte_cxt.trig_range >= 0 &&
			      addr >= MT_CLEAR_TAG(cur_mte_cxt.trig_addr) &&
			      addr <= (MT_CLEAR_TAG(cur_mte_cxt.trig_addr) + cur_mte_cxt.trig_range)) ||
			     (cur_mte_cxt.trig_range < 0 &&
			      addr <= MT_CLEAR_TAG(cur_mte_cxt.trig_addr) &&
			      addr >= (MT_CLEAR_TAG(cur_mte_cxt.trig_addr) + cur_mte_cxt.trig_range)))) {
				cur_mte_cxt.fault_valid = true;
				/* Adjust the pc by 4 to skip the faulting instruction */
				((ucontext_t *)uc)->uc_mcontext.pc += 4;
			} else {
				ksft_print_msg("Invalid MTE synchronous exception caught!\n");
				exit(1);
			}
		} else {
			ksft_print_msg("Unknown SIGSEGV exception caught!\n");
			exit(1);
		}
	} else if (signum == SIGBUS) {
		ksft_print_msg("INFO: SIGBUS signal at pc=%lx, fault addr=%lx, si_code=%x\n",
				((ucontext_t *)uc)->uc_mcontext.pc, addr, si->si_code);
		if ((cur_mte_cxt.trig_range >= 0 &&
		     addr >= MT_CLEAR_TAG(cur_mte_cxt.trig_addr) &&
		     addr <= (MT_CLEAR_TAG(cur_mte_cxt.trig_addr) + cur_mte_cxt.trig_range)) ||
		    (cur_mte_cxt.trig_range < 0 &&
		     addr <= MT_CLEAR_TAG(cur_mte_cxt.trig_addr) &&
		     addr >= (MT_CLEAR_TAG(cur_mte_cxt.trig_addr) + cur_mte_cxt.trig_range))) {
			cur_mte_cxt.fault_valid = true;
			/* Adjust the pc by 4 to skip the faulting instruction */
			((ucontext_t *)uc)->uc_mcontext.pc += 4;
		}
	}
}

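/* Install a SA_SIGINFO handler for the given signal. */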
void mte_register_signal(int signal, void (*handler)(int, siginfo_t *, void *))
{
	struct sigaction sa;

	sa.sa_sigaction = handler;
	sa.sa_flags = SA_SIGINFO;
	sigemptyset(&sa.sa_mask);
	sigaction(signal, &sa, NULL);
}

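/*
 * Enter the kernel after triggering an access so that a pending
 * asynchronous tag check fault (SIGSEGV with SEGV_MTEAERR) can be
 * delivered before the caller inspects cur_mte_cxt.fault_valid.
 */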
void mte_wait_after_trig(void)
{
	sched_yield();
}

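/*
 * Tag 'size' bytes at ptr: derive a random logical tag for the pointer and
 * store it as the allocation tag over the granule aligned range.
 * Returns the tagged pointer, or NULL if ptr is NULL or misaligned.
 */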
void *mte_insert_tags(void *ptr, size_t size)
{
	void *tag_ptr;
	int align_size;

	if (!ptr || (unsigned long)(ptr) & MT_ALIGN_GRANULE) {
		ksft_print_msg("FAIL: Addr=%lx: invalid\n", (unsigned long)ptr);
		return NULL;
	}
	align_size = MT_ALIGN_UP(size);
	tag_ptr = mte_insert_random_tag(ptr);
	mte_set_tag_address_range(tag_ptr, align_size);
	return tag_ptr;
}

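/*
 * Reset the allocation tags of the granule aligned range back to zero.
 * The logical tag in the pointer is stripped before the range is cleared.
 */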
void mte_clear_tags(void *ptr, size_t size)
{
	if (!ptr || (unsigned long)(ptr) & MT_ALIGN_GRANULE) {
		ksft_print_msg("FAIL: Addr=%lx: invalid\n", (unsigned long)ptr);
		return;
	}
	size = MT_ALIGN_UP(size);
	ptr = (void *)MT_CLEAR_TAG((unsigned long)ptr);
	mte_clear_tag_address_range(ptr, size);
}

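/*
 * Common allocation helper. Allocates 'size' usable bytes plus optional
 * extra ranges before and after it, either with malloc() or with an
 * (optionally file backed) mapping that gets PROT_MTE directly at mmap()
 * time (USE_MMAP) or via a later mprotect() call (USE_MPROTECT). When
 * 'tags' is true the usable part of the buffer is tagged before return.
 */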
static void *__mte_allocate_memory_range(size_t size, int mem_type, int mapping,
					 size_t range_before, size_t range_after,
					 bool tags, int fd)
{
	void *ptr;
	int prot_flag, map_flag;
	size_t entire_size = size + range_before + range_after;

	if (mem_type != USE_MALLOC && mem_type != USE_MMAP &&
	    mem_type != USE_MPROTECT) {
		ksft_print_msg("FAIL: Invalid allocate request\n");
		return NULL;
	}
	if (mem_type == USE_MALLOC)
		return malloc(entire_size) + range_before;

	prot_flag = PROT_READ | PROT_WRITE;
	if (mem_type == USE_MMAP)
		prot_flag |= PROT_MTE;

	map_flag = mapping;
	if (fd == -1)
		map_flag = MAP_ANONYMOUS | map_flag;
	if (!(mapping & MAP_SHARED))
		map_flag |= MAP_PRIVATE;
	ptr = mmap(NULL, entire_size, prot_flag, map_flag, fd, 0);
	if (ptr == MAP_FAILED) {
		ksft_print_msg("FAIL: mmap allocation\n");
		return NULL;
	}
	if (mem_type == USE_MPROTECT) {
		if (mprotect(ptr, entire_size, prot_flag | PROT_MTE)) {
			munmap(ptr, entire_size);
			ksft_print_msg("FAIL: mprotect PROT_MTE property\n");
			return NULL;
		}
	}
	if (tags)
		ptr = mte_insert_tags(ptr + range_before, size);
	return ptr;
}

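/*
 * Allocate 'size' tagged bytes preceded by 'range_before' and followed by
 * 'range_after' untagged bytes within the same anonymous allocation.
 */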
void *mte_allocate_memory_tag_range(size_t size, int mem_type, int mapping,
				    size_t range_before, size_t range_after)
{
	return __mte_allocate_memory_range(size, mem_type, mapping, range_before,
					   range_after, true, -1);
}

void *mte_allocate_memory(size_t size, int mem_type, int mapping, bool tags)
{
	return __mte_allocate_memory_range(size, mem_type, mapping, 0, 0, tags, -1);
}

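/*
 * Allocate a (possibly tagged) file backed buffer. The file referred to by
 * 'fd' is first filled up to 'size' bytes so that the whole mapping is
 * backed by data.
 */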
void *mte_allocate_file_memory(size_t size, int mem_type, int mapping, bool tags, int fd)
{
	int index;
	char buffer[INIT_BUFFER_SIZE];

	if (mem_type != USE_MPROTECT && mem_type != USE_MMAP) {
		ksft_print_msg("FAIL: Invalid mmap file request\n");
		return NULL;
	}
	/* Initialize the file for mappable size */
	lseek(fd, 0, SEEK_SET);
	for (index = INIT_BUFFER_SIZE; index < size; index += INIT_BUFFER_SIZE) {
		if (write(fd, buffer, INIT_BUFFER_SIZE) != INIT_BUFFER_SIZE) {
			perror("initialising buffer");
			return NULL;
		}
	}
	index -= INIT_BUFFER_SIZE;
	if (write(fd, buffer, size - index) != size - index) {
		perror("initialising buffer");
		return NULL;
	}
	return __mte_allocate_memory_range(size, mem_type, mapping, 0, 0, tags, fd);
}

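/*
 * File backed variant of mte_allocate_memory_tag_range(): the file is sized
 * to cover the usable buffer plus both surrounding ranges before mapping.
 */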
void *mte_allocate_file_memory_tag_range(size_t size, int mem_type, int mapping,
					 size_t range_before, size_t range_after, int fd)
{
	int index;
	char buffer[INIT_BUFFER_SIZE];
	int map_size = size + range_before + range_after;

	if (mem_type != USE_MPROTECT && mem_type != USE_MMAP) {
		ksft_print_msg("FAIL: Invalid mmap file request\n");
		return NULL;
	}
	/* Initialize the file for mappable size */
	lseek(fd, 0, SEEK_SET);
	for (index = INIT_BUFFER_SIZE; index < map_size; index += INIT_BUFFER_SIZE)
		if (write(fd, buffer, INIT_BUFFER_SIZE) != INIT_BUFFER_SIZE) {
			perror("initialising buffer");
			return NULL;
		}
	index -= INIT_BUFFER_SIZE;
	if (write(fd, buffer, map_size - index) != map_size - index) {
		perror("initialising buffer");
		return NULL;
	}
	return __mte_allocate_memory_range(size, mem_type, mapping, range_before,
					   range_after, true, fd);
}

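/*
 * Common free helper: releases the whole allocation (including the ranges
 * before and after the usable buffer) and, for mapped memory, optionally
 * clears the tags first.
 */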
static void __mte_free_memory_range(void *ptr, size_t size, int mem_type,
				    size_t range_before, size_t range_after, bool tags)
{
	switch (mem_type) {
	case USE_MALLOC:
		free(ptr - range_before);
		break;
	case USE_MMAP:
	case USE_MPROTECT:
		if (tags)
			mte_clear_tags(ptr, size);
		munmap(ptr - range_before, size + range_before + range_after);
		break;
	default:
		ksft_print_msg("FAIL: Invalid free request\n");
		break;
	}
}

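/*
 * Counterparts of the allocation helpers above; 'size', 'mem_type' and the
 * ranges must match the values used at allocation time.
 */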
void mte_free_memory_tag_range(void *ptr, size_t size, int mem_type,
			       size_t range_before, size_t range_after)
{
	__mte_free_memory_range(ptr, size, mem_type, range_before, range_after, true);
}

void mte_free_memory(void *ptr, size_t size, int mem_type, bool tags)
{
	__mte_free_memory_range(ptr, size, mem_type, 0, 0, tags);
}

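/*
 * Describe the fault the test is about to trigger so that
 * mte_default_handler() can tell an expected tag check fault apart from an
 * unexpected one. A negative 'range' describes an access below 'ptr'.
 */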
void mte_initialize_current_context(int mode, uintptr_t ptr, ssize_t range)
{
	cur_mte_cxt.fault_valid = false;
	cur_mte_cxt.trig_addr = ptr;
	cur_mte_cxt.trig_range = range;
	if (mode == MTE_SYNC_ERR)
		cur_mte_cxt.trig_si_code = SEGV_MTESERR;
	else if (mode == MTE_ASYNC_ERR)
		cur_mte_cxt.trig_si_code = SEGV_MTEAERR;
	else
		cur_mte_cxt.trig_si_code = 0;
}

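/*
 * Configure the MTE tag check fault mode (none/sync/async) and the tag
 * inclusion mask for the current thread via PR_SET_TAGGED_ADDR_CTRL.
 *
 * A typical test sequence built on these helpers (illustrative sketch only,
 * not lifted from an existing test) looks like:
 *
 *	mte_register_signal(SIGSEGV, mte_default_handler);
 *	mte_switch_mode(MTE_SYNC_ERR, MTE_ALLOW_NON_ZERO_TAG);
 *	ptr = mte_allocate_memory(size, USE_MMAP, 0, true);
 *	mte_initialize_current_context(MTE_SYNC_ERR, (uintptr_t)ptr, size);
 *	... perform the faulting access ...
 *	mte_wait_after_trig();
 *	... check cur_mte_cxt.fault_valid ...
 *	mte_free_memory(ptr, size, USE_MMAP, true);
 */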
int mte_switch_mode(int mte_option, unsigned long incl_mask)
{
	unsigned long en = 0;

	switch (mte_option) {
	case MTE_NONE_ERR:
	case MTE_SYNC_ERR:
	case MTE_ASYNC_ERR:
		break;
	default:
		ksft_print_msg("FAIL: Invalid MTE option %x\n", mte_option);
		return -EINVAL;
	}

	if (incl_mask > MTE_ALLOW_NON_ZERO_TAG) {
		ksft_print_msg("FAIL: Invalid incl_mask %lx\n", incl_mask);
		return -EINVAL;
	}
	en = PR_TAGGED_ADDR_ENABLE;
	if (mte_option == MTE_SYNC_ERR)
		en |= PR_MTE_TCF_SYNC;
	else if (mte_option == MTE_ASYNC_ERR)
		en |= PR_MTE_TCF_ASYNC;
	else if (mte_option == MTE_NONE_ERR)
		en |= PR_MTE_TCF_NONE;

	en |= (incl_mask << PR_MTE_TAG_SHIFT);
	/* Enable address tagging ABI, mte error reporting mode and tag inclusion mask. */
	if (prctl(PR_SET_TAGGED_ADDR_CTRL, en, 0, 0, 0) != 0) {
		ksft_print_msg("FAIL: prctl PR_SET_TAGGED_ADDR_CTRL for mte mode\n");
		return -EINVAL;
	}
	return 0;
}

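/*
 * Per test initialisation: skip if the CPU/kernel does not advertise MTE,
 * remember the current tag check fault mode and PSTATE.TCO state so they
 * can be restored by mte_restore_setup(), and disable PSTATE.TCO so that
 * tag checks are not suppressed while the test runs.
 */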
int mte_default_setup(void)
{
	unsigned long hwcaps2 = getauxval(AT_HWCAP2);
	unsigned long en = 0;
	int ret;

	if (!(hwcaps2 & HWCAP2_MTE)) {
		ksft_print_msg("SKIP: MTE features unavailable\n");
		return KSFT_SKIP;
	}
	/* Get current mte mode */
	ret = prctl(PR_GET_TAGGED_ADDR_CTRL, en, 0, 0, 0);
	if (ret < 0) {
		ksft_print_msg("FAIL: prctl PR_GET_TAGGED_ADDR_CTRL failed with error %d\n", ret);
		return KSFT_FAIL;
	}
	if (ret & PR_MTE_TCF_SYNC)
		mte_cur_mode = MTE_SYNC_ERR;
	else if (ret & PR_MTE_TCF_ASYNC)
		mte_cur_mode = MTE_ASYNC_ERR;
	else if (ret & PR_MTE_TCF_NONE)
		mte_cur_mode = MTE_NONE_ERR;

	mte_cur_pstate_tco = mte_get_pstate_tco();
	/* Disable PSTATE.TCO */
	mte_disable_pstate_tco();
	return 0;
}

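/*
 * Undo mte_default_setup(): restore the saved tag check fault mode and the
 * original PSTATE.TCO state.
 */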
void mte_restore_setup(void)
{
	mte_switch_mode(mte_cur_mode, MTE_ALLOW_NON_ZERO_TAG);
	if (mte_cur_pstate_tco == MT_PSTATE_TCO_EN)
		mte_enable_pstate_tco();
	else if (mte_cur_pstate_tco == MT_PSTATE_TCO_DIS)
		mte_disable_pstate_tco();
}

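/*
 * Create an unlinked temporary file on tmpfs (/dev/shm) for the file backed
 * mapping tests and return its descriptor, or 0 on failure.
 */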
int create_temp_file(void)
{
	int fd;
	char filename[] = "/dev/shm/tmp_XXXXXX";

	/* Create a file in the tmpfs filesystem */
	fd = mkstemp(&filename[0]);
	if (fd == -1) {
		perror(filename);
		ksft_print_msg("FAIL: Unable to open temporary file\n");
		return 0;
	}
	unlink(&filename[0]);
	return fd;
}