1 // SPDX-License-Identifier: GPL-2.0 2 // Copyright (C) 2020 ARM Limited 3 4 #define _GNU_SOURCE 5 6 #include <errno.h> 7 #include <pthread.h> 8 #include <stdint.h> 9 #include <stdio.h> 10 #include <stdlib.h> 11 #include <time.h> 12 #include <unistd.h> 13 #include <sys/auxv.h> 14 #include <sys/mman.h> 15 #include <sys/prctl.h> 16 #include <sys/types.h> 17 #include <sys/wait.h> 18 19 #include "kselftest.h" 20 #include "mte_common_util.h" 21 22 #include "mte_def.h" 23 24 #define NUM_ITERATIONS 1024 25 #define MAX_THREADS 5 26 #define THREAD_ITERATIONS 1000 27 28 void *execute_thread(void *x) 29 { 30 pid_t pid = *((pid_t *)x); 31 pid_t tid = gettid(); 32 uint64_t prctl_tag_mask; 33 uint64_t prctl_set; 34 uint64_t prctl_get; 35 uint64_t prctl_tcf; 36 37 srand(time(NULL) ^ (pid << 16) ^ (tid << 16)); 38 39 prctl_tag_mask = rand() & 0xffff; 40 41 if (prctl_tag_mask % 2) 42 prctl_tcf = PR_MTE_TCF_SYNC; 43 else 44 prctl_tcf = PR_MTE_TCF_ASYNC; 45 46 prctl_set = PR_TAGGED_ADDR_ENABLE | prctl_tcf | (prctl_tag_mask << PR_MTE_TAG_SHIFT); 47 48 for (int j = 0; j < THREAD_ITERATIONS; j++) { 49 if (prctl(PR_SET_TAGGED_ADDR_CTRL, prctl_set, 0, 0, 0)) { 50 perror("prctl() failed"); 51 goto fail; 52 } 53 54 prctl_get = prctl(PR_GET_TAGGED_ADDR_CTRL, 0, 0, 0, 0); 55 56 if (prctl_set != prctl_get) { 57 ksft_print_msg("Error: prctl_set: 0x%lx != prctl_get: 0x%lx\n", 58 prctl_set, prctl_get); 59 goto fail; 60 } 61 } 62 63 return (void *)KSFT_PASS; 64 65 fail: 66 return (void *)KSFT_FAIL; 67 } 68 69 int execute_test(pid_t pid) 70 { 71 pthread_t thread_id[MAX_THREADS]; 72 int thread_data[MAX_THREADS]; 73 74 for (int i = 0; i < MAX_THREADS; i++) 75 pthread_create(&thread_id[i], NULL, 76 execute_thread, (void *)&pid); 77 78 for (int i = 0; i < MAX_THREADS; i++) 79 pthread_join(thread_id[i], (void *)&thread_data[i]); 80 81 for (int i = 0; i < MAX_THREADS; i++) 82 if (thread_data[i] == KSFT_FAIL) 83 return KSFT_FAIL; 84 85 return KSFT_PASS; 86 } 87 88 int mte_gcr_fork_test(void) 89 { 90 pid_t pid; 91 int results[NUM_ITERATIONS]; 92 pid_t cpid; 93 int res; 94 95 for (int i = 0; i < NUM_ITERATIONS; i++) { 96 pid = fork(); 97 98 if (pid < 0) 99 return KSFT_FAIL; 100 101 if (pid == 0) { 102 cpid = getpid(); 103 104 res = execute_test(cpid); 105 106 exit(res); 107 } 108 } 109 110 for (int i = 0; i < NUM_ITERATIONS; i++) { 111 wait(&res); 112 113 if (WIFEXITED(res)) 114 results[i] = WEXITSTATUS(res); 115 else 116 --i; 117 } 118 119 for (int i = 0; i < NUM_ITERATIONS; i++) 120 if (results[i] == KSFT_FAIL) 121 return KSFT_FAIL; 122 123 return KSFT_PASS; 124 } 125 126 int main(int argc, char *argv[]) 127 { 128 int err; 129 130 err = mte_default_setup(); 131 if (err) 132 return err; 133 134 ksft_set_plan(1); 135 136 evaluate_test(mte_gcr_fork_test(), 137 "Verify that GCR_EL1 is set correctly on context switch\n"); 138 139 mte_restore_setup(); 140 ksft_print_cnts(); 141 142 return ksft_get_fail_cnt() == 0 ? KSFT_PASS : KSFT_FAIL; 143 } 144