1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (C) 2020 ARM Limited
3 
4 #define _GNU_SOURCE
5 
6 #include <errno.h>
7 #include <signal.h>
8 #include <stdio.h>
9 #include <stdlib.h>
10 #include <string.h>
11 #include <ucontext.h>
12 #include <sys/wait.h>
13 
14 #include "kselftest.h"
15 #include "mte_common_util.h"
16 #include "mte_def.h"
17 
/* Length of the region that each test re-tags and writes: 5 * MT_GRANULE_SIZE. */
#define BUFFER_SIZE		(5 * MT_GRANULE_SIZE)
/* Tag-generation attempts made per include mask: twice the number of tags. */
#define RUNS			(MT_TAG_COUNT * 2)
/* NOTE(review): not referenced anywhere in this file; candidate for removal. */
#define MTE_LAST_TAG_MASK	(0x7FFF)
21 
verify_mte_pointer_validity(char * ptr,int mode)22 static int verify_mte_pointer_validity(char *ptr, int mode)
23 {
24 	mte_initialize_current_context(mode, (uintptr_t)ptr, BUFFER_SIZE);
25 	/* Check the validity of the tagged pointer */
26 	memset(ptr, '1', BUFFER_SIZE);
27 	mte_wait_after_trig();
28 	if (cur_mte_cxt.fault_valid) {
29 		ksft_print_msg("Unexpected fault recorded for %p-%p in mode %x\n",
30 			       ptr, ptr + BUFFER_SIZE, mode);
31 		return KSFT_FAIL;
32 	}
33 	/* Proceed further for nonzero tags */
34 	if (!MT_FETCH_TAG((uintptr_t)ptr))
35 		return KSFT_PASS;
36 	mte_initialize_current_context(mode, (uintptr_t)ptr, BUFFER_SIZE + 1);
37 	/* Check the validity outside the range */
38 	ptr[BUFFER_SIZE] = '2';
39 	mte_wait_after_trig();
40 	if (!cur_mte_cxt.fault_valid) {
41 		ksft_print_msg("No valid fault recorded for %p in mode %x\n",
42 			       ptr, mode);
43 		return KSFT_FAIL;
44 	} else {
45 		return KSFT_PASS;
46 	}
47 }
48 
check_single_included_tags(int mem_type,int mode)49 static int check_single_included_tags(int mem_type, int mode)
50 {
51 	char *ptr;
52 	int tag, run, ret, result = KSFT_PASS;
53 
54 	ptr = mte_allocate_memory(BUFFER_SIZE + MT_GRANULE_SIZE, mem_type, 0, false);
55 	if (check_allocated_memory(ptr, BUFFER_SIZE + MT_GRANULE_SIZE,
56 				   mem_type, false) != KSFT_PASS)
57 		return KSFT_FAIL;
58 
59 	for (tag = 0; (tag < MT_TAG_COUNT) && (result == KSFT_PASS); tag++) {
60 		ret = mte_switch_mode(mode, MT_INCLUDE_VALID_TAG(tag));
61 		if (ret != 0)
62 			result = KSFT_FAIL;
63 		/* Try to catch a excluded tag by a number of tries. */
64 		for (run = 0; (run < RUNS) && (result == KSFT_PASS); run++) {
65 			ptr = mte_insert_tags(ptr, BUFFER_SIZE);
66 			/* Check tag value */
67 			if (MT_FETCH_TAG((uintptr_t)ptr) == tag) {
68 				ksft_print_msg("FAIL: wrong tag = 0x%x with include mask=0x%x\n",
69 					       MT_FETCH_TAG((uintptr_t)ptr),
70 					       MT_INCLUDE_VALID_TAG(tag));
71 				result = KSFT_FAIL;
72 				break;
73 			}
74 			result = verify_mte_pointer_validity(ptr, mode);
75 		}
76 	}
77 	mte_free_memory_tag_range(ptr, BUFFER_SIZE, mem_type, 0, MT_GRANULE_SIZE);
78 	return result;
79 }
80 
check_multiple_included_tags(int mem_type,int mode)81 static int check_multiple_included_tags(int mem_type, int mode)
82 {
83 	char *ptr;
84 	int tag, run, result = KSFT_PASS;
85 	unsigned long excl_mask = 0;
86 
87 	ptr = mte_allocate_memory(BUFFER_SIZE + MT_GRANULE_SIZE, mem_type, 0, false);
88 	if (check_allocated_memory(ptr, BUFFER_SIZE + MT_GRANULE_SIZE,
89 				   mem_type, false) != KSFT_PASS)
90 		return KSFT_FAIL;
91 
92 	for (tag = 0; (tag < MT_TAG_COUNT - 1) && (result == KSFT_PASS); tag++) {
93 		excl_mask |= 1 << tag;
94 		mte_switch_mode(mode, MT_INCLUDE_VALID_TAGS(excl_mask));
95 		/* Try to catch a excluded tag by a number of tries. */
96 		for (run = 0; (run < RUNS) && (result == KSFT_PASS); run++) {
97 			ptr = mte_insert_tags(ptr, BUFFER_SIZE);
98 			/* Check tag value */
99 			if (MT_FETCH_TAG((uintptr_t)ptr) < tag) {
100 				ksft_print_msg("FAIL: wrong tag = 0x%x with include mask=0x%x\n",
101 					       MT_FETCH_TAG((uintptr_t)ptr),
102 					       MT_INCLUDE_VALID_TAGS(excl_mask));
103 				result = KSFT_FAIL;
104 				break;
105 			}
106 			result = verify_mte_pointer_validity(ptr, mode);
107 		}
108 	}
109 	mte_free_memory_tag_range(ptr, BUFFER_SIZE, mem_type, 0, MT_GRANULE_SIZE);
110 	return result;
111 }
112 
check_all_included_tags(int mem_type,int mode)113 static int check_all_included_tags(int mem_type, int mode)
114 {
115 	char *ptr;
116 	int run, ret, result = KSFT_PASS;
117 
118 	ptr = mte_allocate_memory(BUFFER_SIZE + MT_GRANULE_SIZE, mem_type, 0, false);
119 	if (check_allocated_memory(ptr, BUFFER_SIZE + MT_GRANULE_SIZE,
120 				   mem_type, false) != KSFT_PASS)
121 		return KSFT_FAIL;
122 
123 	ret = mte_switch_mode(mode, MT_INCLUDE_TAG_MASK);
124 	if (ret != 0)
125 		return KSFT_FAIL;
126 	/* Try to catch a excluded tag by a number of tries. */
127 	for (run = 0; (run < RUNS) && (result == KSFT_PASS); run++) {
128 		ptr = (char *)mte_insert_tags(ptr, BUFFER_SIZE);
129 		/*
130 		 * Here tag byte can be between 0x0 to 0xF (full allowed range)
131 		 * so no need to match so just verify if it is writable.
132 		 */
133 		result = verify_mte_pointer_validity(ptr, mode);
134 	}
135 	mte_free_memory_tag_range(ptr, BUFFER_SIZE, mem_type, 0, MT_GRANULE_SIZE);
136 	return result;
137 }
138 
check_none_included_tags(int mem_type,int mode)139 static int check_none_included_tags(int mem_type, int mode)
140 {
141 	char *ptr;
142 	int run, ret;
143 
144 	ptr = mte_allocate_memory(BUFFER_SIZE, mem_type, 0, false);
145 	if (check_allocated_memory(ptr, BUFFER_SIZE, mem_type, false) != KSFT_PASS)
146 		return KSFT_FAIL;
147 
148 	ret = mte_switch_mode(mode, MT_EXCLUDE_TAG_MASK);
149 	if (ret != 0)
150 		return KSFT_FAIL;
151 	/* Try to catch a excluded tag by a number of tries. */
152 	for (run = 0; run < RUNS; run++) {
153 		ptr = (char *)mte_insert_tags(ptr, BUFFER_SIZE);
154 		/* Here all tags exluded so tag value generated should be 0 */
155 		if (MT_FETCH_TAG((uintptr_t)ptr)) {
156 			ksft_print_msg("FAIL: included tag value found\n");
157 			mte_free_memory((void *)ptr, BUFFER_SIZE, mem_type, true);
158 			return KSFT_FAIL;
159 		}
160 		mte_initialize_current_context(mode, (uintptr_t)ptr, BUFFER_SIZE);
161 		/* Check the write validity of the untagged pointer */
162 		memset(ptr, '1', BUFFER_SIZE);
163 		mte_wait_after_trig();
164 		if (cur_mte_cxt.fault_valid)
165 			break;
166 	}
167 	mte_free_memory(ptr, BUFFER_SIZE, mem_type, false);
168 	if (cur_mte_cxt.fault_valid)
169 		return KSFT_FAIL;
170 	else
171 		return KSFT_PASS;
172 }
173 
main(int argc,char * argv[])174 int main(int argc, char *argv[])
175 {
176 	int err;
177 
178 	err = mte_default_setup();
179 	if (err)
180 		return err;
181 
182 	/* Register SIGSEGV handler */
183 	mte_register_signal(SIGSEGV, mte_default_handler);
184 
185 	/* Set test plan */
186 	ksft_set_plan(4);
187 
188 	evaluate_test(check_single_included_tags(USE_MMAP, MTE_SYNC_ERR),
189 		      "Check an included tag value with sync mode\n");
190 	evaluate_test(check_multiple_included_tags(USE_MMAP, MTE_SYNC_ERR),
191 		      "Check different included tags value with sync mode\n");
192 	evaluate_test(check_none_included_tags(USE_MMAP, MTE_SYNC_ERR),
193 		      "Check none included tags value with sync mode\n");
194 	evaluate_test(check_all_included_tags(USE_MMAP, MTE_SYNC_ERR),
195 		      "Check all included tags value with sync mode\n");
196 
197 	mte_restore_setup();
198 	ksft_print_cnts();
199 	return ksft_get_fail_cnt() == 0 ? KSFT_PASS : KSFT_FAIL;
200 }
201