1 // SPDX-License-Identifier: GPL-2.0 2 // Copyright (C) 2020 ARM Limited 3 4 #define _GNU_SOURCE 5 6 #include <errno.h> 7 #include <signal.h> 8 #include <stdio.h> 9 #include <stdlib.h> 10 #include <string.h> 11 #include <ucontext.h> 12 #include <sys/wait.h> 13 14 #include "kselftest.h" 15 #include "mte_common_util.h" 16 #include "mte_def.h" 17 18 #define BUFFER_SIZE (5 * MT_GRANULE_SIZE) 19 #define RUNS (MT_TAG_COUNT * 2) 20 #define MTE_LAST_TAG_MASK (0x7FFF) 21 22 static int verify_mte_pointer_validity(char *ptr, int mode) 23 { 24 mte_initialize_current_context(mode, (uintptr_t)ptr, BUFFER_SIZE); 25 /* Check the validity of the tagged pointer */ 26 memset((void *)ptr, '1', BUFFER_SIZE); 27 mte_wait_after_trig(); 28 if (cur_mte_cxt.fault_valid) 29 return KSFT_FAIL; 30 /* Proceed further for nonzero tags */ 31 if (!MT_FETCH_TAG((uintptr_t)ptr)) 32 return KSFT_PASS; 33 mte_initialize_current_context(mode, (uintptr_t)ptr, BUFFER_SIZE + 1); 34 /* Check the validity outside the range */ 35 ptr[BUFFER_SIZE] = '2'; 36 mte_wait_after_trig(); 37 if (!cur_mte_cxt.fault_valid) 38 return KSFT_FAIL; 39 else 40 return KSFT_PASS; 41 } 42 43 static int check_single_included_tags(int mem_type, int mode) 44 { 45 char *ptr; 46 int tag, run, result = KSFT_PASS; 47 48 ptr = (char *)mte_allocate_memory(BUFFER_SIZE + MT_GRANULE_SIZE, mem_type, 0, false); 49 if (check_allocated_memory(ptr, BUFFER_SIZE + MT_GRANULE_SIZE, 50 mem_type, false) != KSFT_PASS) 51 return KSFT_FAIL; 52 53 for (tag = 0; (tag < MT_TAG_COUNT) && (result == KSFT_PASS); tag++) { 54 mte_switch_mode(mode, MT_INCLUDE_VALID_TAG(tag)); 55 /* Try to catch a excluded tag by a number of tries. */ 56 for (run = 0; (run < RUNS) && (result == KSFT_PASS); run++) { 57 ptr = (char *)mte_insert_tags(ptr, BUFFER_SIZE); 58 /* Check tag value */ 59 if (MT_FETCH_TAG((uintptr_t)ptr) == tag) { 60 ksft_print_msg("FAIL: wrong tag = 0x%x with include mask=0x%x\n", 61 MT_FETCH_TAG((uintptr_t)ptr), 62 MT_INCLUDE_VALID_TAG(tag)); 63 result = KSFT_FAIL; 64 break; 65 } 66 result = verify_mte_pointer_validity(ptr, mode); 67 } 68 } 69 mte_free_memory_tag_range((void *)ptr, BUFFER_SIZE, mem_type, 0, MT_GRANULE_SIZE); 70 return result; 71 } 72 73 static int check_multiple_included_tags(int mem_type, int mode) 74 { 75 char *ptr; 76 int tag, run, result = KSFT_PASS; 77 unsigned long excl_mask = 0; 78 79 ptr = (char *)mte_allocate_memory(BUFFER_SIZE + MT_GRANULE_SIZE, mem_type, 0, false); 80 if (check_allocated_memory(ptr, BUFFER_SIZE + MT_GRANULE_SIZE, 81 mem_type, false) != KSFT_PASS) 82 return KSFT_FAIL; 83 84 for (tag = 0; (tag < MT_TAG_COUNT - 1) && (result == KSFT_PASS); tag++) { 85 excl_mask |= 1 << tag; 86 mte_switch_mode(mode, MT_INCLUDE_VALID_TAGS(excl_mask)); 87 /* Try to catch a excluded tag by a number of tries. */ 88 for (run = 0; (run < RUNS) && (result == KSFT_PASS); run++) { 89 ptr = (char *)mte_insert_tags(ptr, BUFFER_SIZE); 90 /* Check tag value */ 91 if (MT_FETCH_TAG((uintptr_t)ptr) < tag) { 92 ksft_print_msg("FAIL: wrong tag = 0x%x with include mask=0x%x\n", 93 MT_FETCH_TAG((uintptr_t)ptr), 94 MT_INCLUDE_VALID_TAGS(excl_mask)); 95 result = KSFT_FAIL; 96 break; 97 } 98 result = verify_mte_pointer_validity(ptr, mode); 99 } 100 } 101 mte_free_memory_tag_range((void *)ptr, BUFFER_SIZE, mem_type, 0, MT_GRANULE_SIZE); 102 return result; 103 } 104 105 static int check_all_included_tags(int mem_type, int mode) 106 { 107 char *ptr; 108 int run, result = KSFT_PASS; 109 110 ptr = (char *)mte_allocate_memory(BUFFER_SIZE + MT_GRANULE_SIZE, mem_type, 0, false); 111 if (check_allocated_memory(ptr, BUFFER_SIZE + MT_GRANULE_SIZE, 112 mem_type, false) != KSFT_PASS) 113 return KSFT_FAIL; 114 115 mte_switch_mode(mode, MT_INCLUDE_TAG_MASK); 116 /* Try to catch a excluded tag by a number of tries. */ 117 for (run = 0; (run < RUNS) && (result == KSFT_PASS); run++) { 118 ptr = (char *)mte_insert_tags(ptr, BUFFER_SIZE); 119 /* 120 * Here tag byte can be between 0x0 to 0xF (full allowed range) 121 * so no need to match so just verify if it is writable. 122 */ 123 result = verify_mte_pointer_validity(ptr, mode); 124 } 125 mte_free_memory_tag_range((void *)ptr, BUFFER_SIZE, mem_type, 0, MT_GRANULE_SIZE); 126 return result; 127 } 128 129 static int check_none_included_tags(int mem_type, int mode) 130 { 131 char *ptr; 132 int run; 133 134 ptr = (char *)mte_allocate_memory(BUFFER_SIZE, mem_type, 0, false); 135 if (check_allocated_memory(ptr, BUFFER_SIZE, mem_type, false) != KSFT_PASS) 136 return KSFT_FAIL; 137 138 mte_switch_mode(mode, MT_EXCLUDE_TAG_MASK); 139 /* Try to catch a excluded tag by a number of tries. */ 140 for (run = 0; run < RUNS; run++) { 141 ptr = (char *)mte_insert_tags(ptr, BUFFER_SIZE); 142 /* Here all tags exluded so tag value generated should be 0 */ 143 if (MT_FETCH_TAG((uintptr_t)ptr)) { 144 ksft_print_msg("FAIL: included tag value found\n"); 145 mte_free_memory((void *)ptr, BUFFER_SIZE, mem_type, true); 146 return KSFT_FAIL; 147 } 148 mte_initialize_current_context(mode, (uintptr_t)ptr, BUFFER_SIZE); 149 /* Check the write validity of the untagged pointer */ 150 memset((void *)ptr, '1', BUFFER_SIZE); 151 mte_wait_after_trig(); 152 if (cur_mte_cxt.fault_valid) 153 break; 154 } 155 mte_free_memory((void *)ptr, BUFFER_SIZE, mem_type, false); 156 if (cur_mte_cxt.fault_valid) 157 return KSFT_FAIL; 158 else 159 return KSFT_PASS; 160 } 161 162 int main(int argc, char *argv[]) 163 { 164 int err; 165 166 err = mte_default_setup(); 167 if (err) 168 return err; 169 170 /* Register SIGSEGV handler */ 171 mte_register_signal(SIGSEGV, mte_default_handler); 172 173 /* Set test plan */ 174 ksft_set_plan(4); 175 176 evaluate_test(check_single_included_tags(USE_MMAP, MTE_SYNC_ERR), 177 "Check an included tag value with sync mode\n"); 178 evaluate_test(check_multiple_included_tags(USE_MMAP, MTE_SYNC_ERR), 179 "Check different included tags value with sync mode\n"); 180 evaluate_test(check_none_included_tags(USE_MMAP, MTE_SYNC_ERR), 181 "Check none included tags value with sync mode\n"); 182 evaluate_test(check_all_included_tags(USE_MMAP, MTE_SYNC_ERR), 183 "Check all included tags value with sync mode\n"); 184 185 mte_restore_setup(); 186 ksft_print_cnts(); 187 return ksft_get_fail_cnt() == 0 ? KSFT_PASS : KSFT_FAIL; 188 } 189