1 // SPDX-License-Identifier: GPL-2.0 2 3 /* 4 * Test key rotation for TFO. 5 * New keys are 'rotated' in two steps: 6 * 1) Add new key as the 'backup' key 'behind' the primary key 7 * 2) Make new key the primary by swapping the backup and primary keys 8 * 9 * The rotation is done in stages using multiple sockets bound 10 * to the same port via SO_REUSEPORT. This simulates key rotation 11 * behind say a load balancer. We verify that across the rotation 12 * there are no cases in which a cookie is not accepted by verifying 13 * that TcpExtTCPFastOpenPassiveFail remains 0. 14 */ 15 #define _GNU_SOURCE 16 #include <arpa/inet.h> 17 #include <errno.h> 18 #include <error.h> 19 #include <stdbool.h> 20 #include <stdio.h> 21 #include <stdlib.h> 22 #include <string.h> 23 #include <sys/epoll.h> 24 #include <unistd.h> 25 #include <netinet/tcp.h> 26 #include <fcntl.h> 27 #include <time.h> 28 29 #include "../kselftest.h" 30 31 #ifndef TCP_FASTOPEN_KEY 32 #define TCP_FASTOPEN_KEY 33 33 #endif 34 35 #define N_LISTEN 10 36 #define PROC_FASTOPEN_KEY "/proc/sys/net/ipv4/tcp_fastopen_key" 37 #define KEY_LENGTH 16 38 39 static bool do_ipv6; 40 static bool do_sockopt; 41 static bool do_rotate; 42 static int key_len = KEY_LENGTH; 43 static int rcv_fds[N_LISTEN]; 44 static int proc_fd; 45 static const char *IP4_ADDR = "127.0.0.1"; 46 static const char *IP6_ADDR = "::1"; 47 static const int PORT = 8891; 48 49 static void get_keys(int fd, uint32_t *keys) 50 { 51 char buf[128]; 52 socklen_t len = KEY_LENGTH * 2; 53 54 if (do_sockopt) { 55 if (getsockopt(fd, SOL_TCP, TCP_FASTOPEN_KEY, keys, &len)) 56 error(1, errno, "Unable to get key"); 57 return; 58 } 59 lseek(proc_fd, 0, SEEK_SET); 60 if (read(proc_fd, buf, sizeof(buf)) <= 0) 61 error(1, errno, "Unable to read %s", PROC_FASTOPEN_KEY); 62 if (sscanf(buf, "%x-%x-%x-%x,%x-%x-%x-%x", keys, keys + 1, keys + 2, 63 keys + 3, keys + 4, keys + 5, keys + 6, keys + 7) != 8) 64 error(1, 0, "Unable to parse %s", PROC_FASTOPEN_KEY); 65 } 66 67 static void set_keys(int fd, uint32_t *keys) 68 { 69 char buf[128]; 70 71 if (do_sockopt) { 72 if (setsockopt(fd, SOL_TCP, TCP_FASTOPEN_KEY, keys, 73 key_len)) 74 error(1, errno, "Unable to set key"); 75 return; 76 } 77 if (do_rotate) 78 snprintf(buf, 128, "%08x-%08x-%08x-%08x,%08x-%08x-%08x-%08x", 79 keys[0], keys[1], keys[2], keys[3], keys[4], keys[5], 80 keys[6], keys[7]); 81 else 82 snprintf(buf, 128, "%08x-%08x-%08x-%08x", 83 keys[0], keys[1], keys[2], keys[3]); 84 lseek(proc_fd, 0, SEEK_SET); 85 if (write(proc_fd, buf, sizeof(buf)) <= 0) 86 error(1, errno, "Unable to write %s", PROC_FASTOPEN_KEY); 87 } 88 89 static void build_rcv_fd(int family, int proto, int *rcv_fds) 90 { 91 struct sockaddr_in addr4 = {0}; 92 struct sockaddr_in6 addr6 = {0}; 93 struct sockaddr *addr; 94 int opt = 1, i, sz; 95 int qlen = 100; 96 uint32_t keys[8]; 97 98 switch (family) { 99 case AF_INET: 100 addr4.sin_family = family; 101 addr4.sin_addr.s_addr = htonl(INADDR_ANY); 102 addr4.sin_port = htons(PORT); 103 sz = sizeof(addr4); 104 addr = (struct sockaddr *)&addr4; 105 break; 106 case AF_INET6: 107 addr6.sin6_family = AF_INET6; 108 addr6.sin6_addr = in6addr_any; 109 addr6.sin6_port = htons(PORT); 110 sz = sizeof(addr6); 111 addr = (struct sockaddr *)&addr6; 112 break; 113 default: 114 error(1, 0, "Unsupported family %d", family); 115 /* clang does not recognize error() above as terminating 116 * the program, so it complains that saddr, sz are 117 * not initialized when this code path is taken. Silence it. 118 */ 119 return; 120 } 121 for (i = 0; i < ARRAY_SIZE(keys); i++) 122 keys[i] = rand(); 123 for (i = 0; i < N_LISTEN; i++) { 124 rcv_fds[i] = socket(family, proto, 0); 125 if (rcv_fds[i] < 0) 126 error(1, errno, "failed to create receive socket"); 127 if (setsockopt(rcv_fds[i], SOL_SOCKET, SO_REUSEPORT, &opt, 128 sizeof(opt))) 129 error(1, errno, "failed to set SO_REUSEPORT"); 130 if (bind(rcv_fds[i], addr, sz)) 131 error(1, errno, "failed to bind receive socket"); 132 if (setsockopt(rcv_fds[i], SOL_TCP, TCP_FASTOPEN, &qlen, 133 sizeof(qlen))) 134 error(1, errno, "failed to set TCP_FASTOPEN"); 135 set_keys(rcv_fds[i], keys); 136 if (proto == SOCK_STREAM && listen(rcv_fds[i], 10)) 137 error(1, errno, "failed to listen on receive port"); 138 } 139 } 140 141 static int connect_and_send(int family, int proto) 142 { 143 struct sockaddr_in saddr4 = {0}; 144 struct sockaddr_in daddr4 = {0}; 145 struct sockaddr_in6 saddr6 = {0}; 146 struct sockaddr_in6 daddr6 = {0}; 147 struct sockaddr *saddr, *daddr; 148 int fd, sz, ret; 149 char data[1]; 150 151 switch (family) { 152 case AF_INET: 153 saddr4.sin_family = AF_INET; 154 saddr4.sin_addr.s_addr = htonl(INADDR_ANY); 155 saddr4.sin_port = 0; 156 157 daddr4.sin_family = AF_INET; 158 if (!inet_pton(family, IP4_ADDR, &daddr4.sin_addr.s_addr)) 159 error(1, errno, "inet_pton failed: %s", IP4_ADDR); 160 daddr4.sin_port = htons(PORT); 161 162 sz = sizeof(saddr4); 163 saddr = (struct sockaddr *)&saddr4; 164 daddr = (struct sockaddr *)&daddr4; 165 break; 166 case AF_INET6: 167 saddr6.sin6_family = AF_INET6; 168 saddr6.sin6_addr = in6addr_any; 169 170 daddr6.sin6_family = AF_INET6; 171 if (!inet_pton(family, IP6_ADDR, &daddr6.sin6_addr)) 172 error(1, errno, "inet_pton failed: %s", IP6_ADDR); 173 daddr6.sin6_port = htons(PORT); 174 175 sz = sizeof(saddr6); 176 saddr = (struct sockaddr *)&saddr6; 177 daddr = (struct sockaddr *)&daddr6; 178 break; 179 default: 180 error(1, 0, "Unsupported family %d", family); 181 /* clang does not recognize error() above as terminating 182 * the program, so it complains that saddr, daddr, sz are 183 * not initialized when this code path is taken. Silence it. 184 */ 185 return -1; 186 } 187 fd = socket(family, proto, 0); 188 if (fd < 0) 189 error(1, errno, "failed to create send socket"); 190 if (bind(fd, saddr, sz)) 191 error(1, errno, "failed to bind send socket"); 192 data[0] = 'a'; 193 ret = sendto(fd, data, 1, MSG_FASTOPEN, daddr, sz); 194 if (ret != 1) 195 error(1, errno, "failed to sendto"); 196 197 return fd; 198 } 199 200 static bool is_listen_fd(int fd) 201 { 202 int i; 203 204 for (i = 0; i < N_LISTEN; i++) { 205 if (rcv_fds[i] == fd) 206 return true; 207 } 208 return false; 209 } 210 211 static void rotate_key(int fd) 212 { 213 static int iter; 214 static uint32_t new_key[4]; 215 uint32_t keys[8]; 216 uint32_t tmp_key[4]; 217 int i; 218 219 if (iter < N_LISTEN) { 220 /* first set new key as backups */ 221 if (iter == 0) { 222 for (i = 0; i < ARRAY_SIZE(new_key); i++) 223 new_key[i] = rand(); 224 } 225 get_keys(fd, keys); 226 memcpy(keys + 4, new_key, KEY_LENGTH); 227 set_keys(fd, keys); 228 } else { 229 /* swap the keys */ 230 get_keys(fd, keys); 231 memcpy(tmp_key, keys + 4, KEY_LENGTH); 232 memcpy(keys + 4, keys, KEY_LENGTH); 233 memcpy(keys, tmp_key, KEY_LENGTH); 234 set_keys(fd, keys); 235 } 236 if (++iter >= (N_LISTEN * 2)) 237 iter = 0; 238 } 239 240 static void run_one_test(int family) 241 { 242 struct epoll_event ev; 243 int i, send_fd; 244 int n_loops = 10000; 245 int rotate_key_fd = 0; 246 int key_rotate_interval = 50; 247 int fd, epfd; 248 char buf[1]; 249 250 build_rcv_fd(family, SOCK_STREAM, rcv_fds); 251 epfd = epoll_create(1); 252 if (epfd < 0) 253 error(1, errno, "failed to create epoll"); 254 ev.events = EPOLLIN; 255 for (i = 0; i < N_LISTEN; i++) { 256 ev.data.fd = rcv_fds[i]; 257 if (epoll_ctl(epfd, EPOLL_CTL_ADD, rcv_fds[i], &ev)) 258 error(1, errno, "failed to register sock epoll"); 259 } 260 while (n_loops--) { 261 send_fd = connect_and_send(family, SOCK_STREAM); 262 if (do_rotate && ((n_loops % key_rotate_interval) == 0)) { 263 rotate_key(rcv_fds[rotate_key_fd]); 264 if (++rotate_key_fd >= N_LISTEN) 265 rotate_key_fd = 0; 266 } 267 while (1) { 268 i = epoll_wait(epfd, &ev, 1, -1); 269 if (i < 0) 270 error(1, errno, "epoll_wait failed"); 271 if (is_listen_fd(ev.data.fd)) { 272 fd = accept(ev.data.fd, NULL, NULL); 273 if (fd < 0) 274 error(1, errno, "failed to accept"); 275 ev.data.fd = fd; 276 if (epoll_ctl(epfd, EPOLL_CTL_ADD, fd, &ev)) 277 error(1, errno, "failed epoll add"); 278 continue; 279 } 280 i = recv(ev.data.fd, buf, sizeof(buf), 0); 281 if (i != 1) 282 error(1, errno, "failed recv data"); 283 if (epoll_ctl(epfd, EPOLL_CTL_DEL, ev.data.fd, NULL)) 284 error(1, errno, "failed epoll del"); 285 close(ev.data.fd); 286 break; 287 } 288 close(send_fd); 289 } 290 for (i = 0; i < N_LISTEN; i++) 291 close(rcv_fds[i]); 292 } 293 294 static void parse_opts(int argc, char **argv) 295 { 296 int c; 297 298 while ((c = getopt(argc, argv, "46sr")) != -1) { 299 switch (c) { 300 case '4': 301 do_ipv6 = false; 302 break; 303 case '6': 304 do_ipv6 = true; 305 break; 306 case 's': 307 do_sockopt = true; 308 break; 309 case 'r': 310 do_rotate = true; 311 key_len = KEY_LENGTH * 2; 312 break; 313 default: 314 error(1, 0, "%s: parse error", argv[0]); 315 } 316 } 317 } 318 319 int main(int argc, char **argv) 320 { 321 parse_opts(argc, argv); 322 proc_fd = open(PROC_FASTOPEN_KEY, O_RDWR); 323 if (proc_fd < 0) 324 error(1, errno, "Unable to open %s", PROC_FASTOPEN_KEY); 325 srand(time(NULL)); 326 if (do_ipv6) 327 run_one_test(AF_INET6); 328 else 329 run_one_test(AF_INET); 330 close(proc_fd); 331 fprintf(stderr, "PASS\n"); 332 return 0; 333 } 334