1 // SPDX-License-Identifier: GPL-2.0 2 3 /* 4 * Test key rotation for TFO. 5 * New keys are 'rotated' in two steps: 6 * 1) Add new key as the 'backup' key 'behind' the primary key 7 * 2) Make new key the primary by swapping the backup and primary keys 8 * 9 * The rotation is done in stages using multiple sockets bound 10 * to the same port via SO_REUSEPORT. This simulates key rotation 11 * behind say a load balancer. We verify that across the rotation 12 * there are no cases in which a cookie is not accepted by verifying 13 * that TcpExtTCPFastOpenPassiveFail remains 0. 14 */ 15 #define _GNU_SOURCE 16 #include <arpa/inet.h> 17 #include <errno.h> 18 #include <error.h> 19 #include <stdbool.h> 20 #include <stdio.h> 21 #include <stdlib.h> 22 #include <string.h> 23 #include <sys/epoll.h> 24 #include <unistd.h> 25 #include <netinet/tcp.h> 26 #include <fcntl.h> 27 #include <time.h> 28 29 #ifndef TCP_FASTOPEN_KEY 30 #define TCP_FASTOPEN_KEY 33 31 #endif 32 33 #define N_LISTEN 10 34 #define PROC_FASTOPEN_KEY "/proc/sys/net/ipv4/tcp_fastopen_key" 35 #define KEY_LENGTH 16 36 37 #ifndef ARRAY_SIZE 38 #define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0])) 39 #endif 40 41 static bool do_ipv6; 42 static bool do_sockopt; 43 static bool do_rotate; 44 static int key_len = KEY_LENGTH; 45 static int rcv_fds[N_LISTEN]; 46 static int proc_fd; 47 static const char *IP4_ADDR = "127.0.0.1"; 48 static const char *IP6_ADDR = "::1"; 49 static const int PORT = 8891; 50 51 static void get_keys(int fd, uint32_t *keys) 52 { 53 char buf[128]; 54 socklen_t len = KEY_LENGTH * 2; 55 56 if (do_sockopt) { 57 if (getsockopt(fd, SOL_TCP, TCP_FASTOPEN_KEY, keys, &len)) 58 error(1, errno, "Unable to get key"); 59 return; 60 } 61 lseek(proc_fd, 0, SEEK_SET); 62 if (read(proc_fd, buf, sizeof(buf)) <= 0) 63 error(1, errno, "Unable to read %s", PROC_FASTOPEN_KEY); 64 if (sscanf(buf, "%x-%x-%x-%x,%x-%x-%x-%x", keys, keys + 1, keys + 2, 65 keys + 3, keys + 4, keys + 5, keys + 6, keys + 7) != 8) 66 error(1, 0, "Unable to parse %s", PROC_FASTOPEN_KEY); 67 } 68 69 static void set_keys(int fd, uint32_t *keys) 70 { 71 char buf[128]; 72 73 if (do_sockopt) { 74 if (setsockopt(fd, SOL_TCP, TCP_FASTOPEN_KEY, keys, 75 key_len)) 76 error(1, errno, "Unable to set key"); 77 return; 78 } 79 if (do_rotate) 80 snprintf(buf, 128, "%08x-%08x-%08x-%08x,%08x-%08x-%08x-%08x", 81 keys[0], keys[1], keys[2], keys[3], keys[4], keys[5], 82 keys[6], keys[7]); 83 else 84 snprintf(buf, 128, "%08x-%08x-%08x-%08x", 85 keys[0], keys[1], keys[2], keys[3]); 86 lseek(proc_fd, 0, SEEK_SET); 87 if (write(proc_fd, buf, sizeof(buf)) <= 0) 88 error(1, errno, "Unable to write %s", PROC_FASTOPEN_KEY); 89 } 90 91 static void build_rcv_fd(int family, int proto, int *rcv_fds) 92 { 93 struct sockaddr_in addr4 = {0}; 94 struct sockaddr_in6 addr6 = {0}; 95 struct sockaddr *addr; 96 int opt = 1, i, sz; 97 int qlen = 100; 98 uint32_t keys[8]; 99 100 switch (family) { 101 case AF_INET: 102 addr4.sin_family = family; 103 addr4.sin_addr.s_addr = htonl(INADDR_ANY); 104 addr4.sin_port = htons(PORT); 105 sz = sizeof(addr4); 106 addr = (struct sockaddr *)&addr4; 107 break; 108 case AF_INET6: 109 addr6.sin6_family = AF_INET6; 110 addr6.sin6_addr = in6addr_any; 111 addr6.sin6_port = htons(PORT); 112 sz = sizeof(addr6); 113 addr = (struct sockaddr *)&addr6; 114 break; 115 default: 116 error(1, 0, "Unsupported family %d", family); 117 /* clang does not recognize error() above as terminating 118 * the program, so it complains that saddr, sz are 119 * not initialized when this code path is taken. Silence it. 120 */ 121 return; 122 } 123 for (i = 0; i < ARRAY_SIZE(keys); i++) 124 keys[i] = rand(); 125 for (i = 0; i < N_LISTEN; i++) { 126 rcv_fds[i] = socket(family, proto, 0); 127 if (rcv_fds[i] < 0) 128 error(1, errno, "failed to create receive socket"); 129 if (setsockopt(rcv_fds[i], SOL_SOCKET, SO_REUSEPORT, &opt, 130 sizeof(opt))) 131 error(1, errno, "failed to set SO_REUSEPORT"); 132 if (bind(rcv_fds[i], addr, sz)) 133 error(1, errno, "failed to bind receive socket"); 134 if (setsockopt(rcv_fds[i], SOL_TCP, TCP_FASTOPEN, &qlen, 135 sizeof(qlen))) 136 error(1, errno, "failed to set TCP_FASTOPEN"); 137 set_keys(rcv_fds[i], keys); 138 if (proto == SOCK_STREAM && listen(rcv_fds[i], 10)) 139 error(1, errno, "failed to listen on receive port"); 140 } 141 } 142 143 static int connect_and_send(int family, int proto) 144 { 145 struct sockaddr_in saddr4 = {0}; 146 struct sockaddr_in daddr4 = {0}; 147 struct sockaddr_in6 saddr6 = {0}; 148 struct sockaddr_in6 daddr6 = {0}; 149 struct sockaddr *saddr, *daddr; 150 int fd, sz, ret; 151 char data[1]; 152 153 switch (family) { 154 case AF_INET: 155 saddr4.sin_family = AF_INET; 156 saddr4.sin_addr.s_addr = htonl(INADDR_ANY); 157 saddr4.sin_port = 0; 158 159 daddr4.sin_family = AF_INET; 160 if (!inet_pton(family, IP4_ADDR, &daddr4.sin_addr.s_addr)) 161 error(1, errno, "inet_pton failed: %s", IP4_ADDR); 162 daddr4.sin_port = htons(PORT); 163 164 sz = sizeof(saddr4); 165 saddr = (struct sockaddr *)&saddr4; 166 daddr = (struct sockaddr *)&daddr4; 167 break; 168 case AF_INET6: 169 saddr6.sin6_family = AF_INET6; 170 saddr6.sin6_addr = in6addr_any; 171 172 daddr6.sin6_family = AF_INET6; 173 if (!inet_pton(family, IP6_ADDR, &daddr6.sin6_addr)) 174 error(1, errno, "inet_pton failed: %s", IP6_ADDR); 175 daddr6.sin6_port = htons(PORT); 176 177 sz = sizeof(saddr6); 178 saddr = (struct sockaddr *)&saddr6; 179 daddr = (struct sockaddr *)&daddr6; 180 break; 181 default: 182 error(1, 0, "Unsupported family %d", family); 183 /* clang does not recognize error() above as terminating 184 * the program, so it complains that saddr, daddr, sz are 185 * not initialized when this code path is taken. Silence it. 186 */ 187 return -1; 188 } 189 fd = socket(family, proto, 0); 190 if (fd < 0) 191 error(1, errno, "failed to create send socket"); 192 if (bind(fd, saddr, sz)) 193 error(1, errno, "failed to bind send socket"); 194 data[0] = 'a'; 195 ret = sendto(fd, data, 1, MSG_FASTOPEN, daddr, sz); 196 if (ret != 1) 197 error(1, errno, "failed to sendto"); 198 199 return fd; 200 } 201 202 static bool is_listen_fd(int fd) 203 { 204 int i; 205 206 for (i = 0; i < N_LISTEN; i++) { 207 if (rcv_fds[i] == fd) 208 return true; 209 } 210 return false; 211 } 212 213 static void rotate_key(int fd) 214 { 215 static int iter; 216 static uint32_t new_key[4]; 217 uint32_t keys[8]; 218 uint32_t tmp_key[4]; 219 int i; 220 221 if (iter < N_LISTEN) { 222 /* first set new key as backups */ 223 if (iter == 0) { 224 for (i = 0; i < ARRAY_SIZE(new_key); i++) 225 new_key[i] = rand(); 226 } 227 get_keys(fd, keys); 228 memcpy(keys + 4, new_key, KEY_LENGTH); 229 set_keys(fd, keys); 230 } else { 231 /* swap the keys */ 232 get_keys(fd, keys); 233 memcpy(tmp_key, keys + 4, KEY_LENGTH); 234 memcpy(keys + 4, keys, KEY_LENGTH); 235 memcpy(keys, tmp_key, KEY_LENGTH); 236 set_keys(fd, keys); 237 } 238 if (++iter >= (N_LISTEN * 2)) 239 iter = 0; 240 } 241 242 static void run_one_test(int family) 243 { 244 struct epoll_event ev; 245 int i, send_fd; 246 int n_loops = 10000; 247 int rotate_key_fd = 0; 248 int key_rotate_interval = 50; 249 int fd, epfd; 250 char buf[1]; 251 252 build_rcv_fd(family, SOCK_STREAM, rcv_fds); 253 epfd = epoll_create(1); 254 if (epfd < 0) 255 error(1, errno, "failed to create epoll"); 256 ev.events = EPOLLIN; 257 for (i = 0; i < N_LISTEN; i++) { 258 ev.data.fd = rcv_fds[i]; 259 if (epoll_ctl(epfd, EPOLL_CTL_ADD, rcv_fds[i], &ev)) 260 error(1, errno, "failed to register sock epoll"); 261 } 262 while (n_loops--) { 263 send_fd = connect_and_send(family, SOCK_STREAM); 264 if (do_rotate && ((n_loops % key_rotate_interval) == 0)) { 265 rotate_key(rcv_fds[rotate_key_fd]); 266 if (++rotate_key_fd >= N_LISTEN) 267 rotate_key_fd = 0; 268 } 269 while (1) { 270 i = epoll_wait(epfd, &ev, 1, -1); 271 if (i < 0) 272 error(1, errno, "epoll_wait failed"); 273 if (is_listen_fd(ev.data.fd)) { 274 fd = accept(ev.data.fd, NULL, NULL); 275 if (fd < 0) 276 error(1, errno, "failed to accept"); 277 ev.data.fd = fd; 278 if (epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &ev)) 279 error(1, errno, "failed epoll add"); 280 continue; 281 } 282 i = recv(ev.data.fd, buf, sizeof(buf), 0); 283 if (i != 1) 284 error(1, errno, "failed recv data"); 285 if (epoll_ctl(epfd, EPOLL_CTL_DEL, ev.data.fd, NULL)) 286 error(1, errno, "failed epoll del"); 287 close(ev.data.fd); 288 break; 289 } 290 close(send_fd); 291 } 292 for (i = 0; i < N_LISTEN; i++) 293 close(rcv_fds[i]); 294 } 295 296 static void parse_opts(int argc, char **argv) 297 { 298 int c; 299 300 while ((c = getopt(argc, argv, "46sr")) != -1) { 301 switch (c) { 302 case '4': 303 do_ipv6 = false; 304 break; 305 case '6': 306 do_ipv6 = true; 307 break; 308 case 's': 309 do_sockopt = true; 310 break; 311 case 'r': 312 do_rotate = true; 313 key_len = KEY_LENGTH * 2; 314 break; 315 default: 316 error(1, 0, "%s: parse error", argv[0]); 317 } 318 } 319 } 320 321 int main(int argc, char **argv) 322 { 323 parse_opts(argc, argv); 324 proc_fd = open(PROC_FASTOPEN_KEY, O_RDWR); 325 if (proc_fd < 0) 326 error(1, errno, "Unable to open %s", PROC_FASTOPEN_KEY); 327 srand(time(NULL)); 328 if (do_ipv6) 329 run_one_test(AF_INET6); 330 else 331 run_one_test(AF_INET); 332 close(proc_fd); 333 fprintf(stderr, "PASS\n"); 334 return 0; 335 } 336