1 // SPDX-License-Identifier: GPL-2.0 2 // Copyright (C) 2020 ARM Limited 3 4 #define _GNU_SOURCE 5 6 #include <errno.h> 7 #include <pthread.h> 8 #include <stdint.h> 9 #include <stdio.h> 10 #include <stdlib.h> 11 #include <time.h> 12 #include <unistd.h> 13 #include <sys/auxv.h> 14 #include <sys/mman.h> 15 #include <sys/prctl.h> 16 #include <sys/types.h> 17 #include <sys/wait.h> 18 19 #include "kselftest.h" 20 #include "mte_common_util.h" 21 22 #define PR_SET_TAGGED_ADDR_CTRL 55 23 #define PR_GET_TAGGED_ADDR_CTRL 56 24 # define PR_TAGGED_ADDR_ENABLE (1UL << 0) 25 # define PR_MTE_TCF_SHIFT 1 26 # define PR_MTE_TCF_NONE (0UL << PR_MTE_TCF_SHIFT) 27 # define PR_MTE_TCF_SYNC (1UL << PR_MTE_TCF_SHIFT) 28 # define PR_MTE_TCF_ASYNC (2UL << PR_MTE_TCF_SHIFT) 29 # define PR_MTE_TCF_MASK (3UL << PR_MTE_TCF_SHIFT) 30 # define PR_MTE_TAG_SHIFT 3 31 # define PR_MTE_TAG_MASK (0xffffUL << PR_MTE_TAG_SHIFT) 32 33 #include "mte_def.h" 34 35 #define NUM_ITERATIONS 1024 36 #define MAX_THREADS 5 37 #define THREAD_ITERATIONS 1000 38 39 void *execute_thread(void *x) 40 { 41 pid_t pid = *((pid_t *)x); 42 pid_t tid = gettid(); 43 uint64_t prctl_tag_mask; 44 uint64_t prctl_set; 45 uint64_t prctl_get; 46 uint64_t prctl_tcf; 47 48 srand(time(NULL) ^ (pid << 16) ^ (tid << 16)); 49 50 prctl_tag_mask = rand() & 0xffff; 51 52 if (prctl_tag_mask % 2) 53 prctl_tcf = PR_MTE_TCF_SYNC; 54 else 55 prctl_tcf = PR_MTE_TCF_ASYNC; 56 57 prctl_set = PR_TAGGED_ADDR_ENABLE | prctl_tcf | (prctl_tag_mask << PR_MTE_TAG_SHIFT); 58 59 for (int j = 0; j < THREAD_ITERATIONS; j++) { 60 if (prctl(PR_SET_TAGGED_ADDR_CTRL, prctl_set, 0, 0, 0)) { 61 perror("prctl() failed"); 62 goto fail; 63 } 64 65 prctl_get = prctl(PR_GET_TAGGED_ADDR_CTRL, 0, 0, 0, 0); 66 67 if (prctl_set != prctl_get) { 68 ksft_print_msg("Error: prctl_set: 0x%lx != prctl_get: 0x%lx\n", 69 prctl_set, prctl_get); 70 goto fail; 71 } 72 } 73 74 return (void *)KSFT_PASS; 75 76 fail: 77 return (void *)KSFT_FAIL; 78 } 79 80 int execute_test(pid_t pid) 81 { 82 pthread_t thread_id[MAX_THREADS]; 83 int thread_data[MAX_THREADS]; 84 85 for (int i = 0; i < MAX_THREADS; i++) 86 pthread_create(&thread_id[i], NULL, 87 execute_thread, (void *)&pid); 88 89 for (int i = 0; i < MAX_THREADS; i++) 90 pthread_join(thread_id[i], (void *)&thread_data[i]); 91 92 for (int i = 0; i < MAX_THREADS; i++) 93 if (thread_data[i] == KSFT_FAIL) 94 return KSFT_FAIL; 95 96 return KSFT_PASS; 97 } 98 99 int mte_gcr_fork_test(void) 100 { 101 pid_t pid; 102 int results[NUM_ITERATIONS]; 103 pid_t cpid; 104 int res; 105 106 for (int i = 0; i < NUM_ITERATIONS; i++) { 107 pid = fork(); 108 109 if (pid < 0) 110 return KSFT_FAIL; 111 112 if (pid == 0) { 113 cpid = getpid(); 114 115 res = execute_test(cpid); 116 117 exit(res); 118 } 119 } 120 121 for (int i = 0; i < NUM_ITERATIONS; i++) { 122 wait(&res); 123 124 if (WIFEXITED(res)) 125 results[i] = WEXITSTATUS(res); 126 else 127 --i; 128 } 129 130 for (int i = 0; i < NUM_ITERATIONS; i++) 131 if (results[i] == KSFT_FAIL) 132 return KSFT_FAIL; 133 134 return KSFT_PASS; 135 } 136 137 int main(int argc, char *argv[]) 138 { 139 int err; 140 141 err = mte_default_setup(); 142 if (err) 143 return err; 144 145 ksft_set_plan(1); 146 147 evaluate_test(mte_gcr_fork_test(), 148 "Verify that GCR_EL1 is set correctly on context switch\n"); 149 150 mte_restore_setup(); 151 ksft_print_cnts(); 152 153 return ksft_get_fail_cnt() == 0 ? KSFT_PASS : KSFT_FAIL; 154 } 155