1 // SPDX-License-Identifier: GPL-2.0 2 #include <test_progs.h> 3 #include "cgroup_helpers.h" 4 5 struct tcp_rtt_storage { 6 __u32 invoked; 7 __u32 dsack_dups; 8 __u32 delivered; 9 __u32 delivered_ce; 10 __u32 icsk_retransmits; 11 }; 12 13 static void send_byte(int fd) 14 { 15 char b = 0x55; 16 17 if (CHECK_FAIL(write(fd, &b, sizeof(b)) != 1)) 18 perror("Failed to send single byte"); 19 } 20 21 static int wait_for_ack(int fd, int retries) 22 { 23 struct tcp_info info; 24 socklen_t optlen; 25 int i, err; 26 27 for (i = 0; i < retries; i++) { 28 optlen = sizeof(info); 29 err = getsockopt(fd, SOL_TCP, TCP_INFO, &info, &optlen); 30 if (err < 0) { 31 log_err("Failed to lookup TCP stats"); 32 return err; 33 } 34 35 if (info.tcpi_unacked == 0) 36 return 0; 37 38 usleep(10); 39 } 40 41 log_err("Did not receive ACK"); 42 return -1; 43 } 44 45 static int verify_sk(int map_fd, int client_fd, const char *msg, __u32 invoked, 46 __u32 dsack_dups, __u32 delivered, __u32 delivered_ce, 47 __u32 icsk_retransmits) 48 { 49 int err = 0; 50 struct tcp_rtt_storage val; 51 52 if (CHECK_FAIL(bpf_map_lookup_elem(map_fd, &client_fd, &val) < 0)) { 53 perror("Failed to read socket storage"); 54 return -1; 55 } 56 57 if (val.invoked != invoked) { 58 log_err("%s: unexpected bpf_tcp_sock.invoked %d != %d", 59 msg, val.invoked, invoked); 60 err++; 61 } 62 63 if (val.dsack_dups != dsack_dups) { 64 log_err("%s: unexpected bpf_tcp_sock.dsack_dups %d != %d", 65 msg, val.dsack_dups, dsack_dups); 66 err++; 67 } 68 69 if (val.delivered != delivered) { 70 log_err("%s: unexpected bpf_tcp_sock.delivered %d != %d", 71 msg, val.delivered, delivered); 72 err++; 73 } 74 75 if (val.delivered_ce != delivered_ce) { 76 log_err("%s: unexpected bpf_tcp_sock.delivered_ce %d != %d", 77 msg, val.delivered_ce, delivered_ce); 78 err++; 79 } 80 81 if (val.icsk_retransmits != icsk_retransmits) { 82 log_err("%s: unexpected bpf_tcp_sock.icsk_retransmits %d != %d", 83 msg, val.icsk_retransmits, icsk_retransmits); 84 err++; 85 } 86 87 return err; 88 } 89 90 static int connect_to_server(int server_fd) 91 { 92 struct sockaddr_storage addr; 93 socklen_t len = sizeof(addr); 94 int fd; 95 96 fd = socket(AF_INET, SOCK_STREAM, 0); 97 if (fd < 0) { 98 log_err("Failed to create client socket"); 99 return -1; 100 } 101 102 if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) { 103 log_err("Failed to get server addr"); 104 goto out; 105 } 106 107 if (connect(fd, (const struct sockaddr *)&addr, len) < 0) { 108 log_err("Fail to connect to server"); 109 goto out; 110 } 111 112 return fd; 113 114 out: 115 close(fd); 116 return -1; 117 } 118 119 static int run_test(int cgroup_fd, int server_fd) 120 { 121 struct bpf_prog_load_attr attr = { 122 .prog_type = BPF_PROG_TYPE_SOCK_OPS, 123 .file = "./tcp_rtt.o", 124 .expected_attach_type = BPF_CGROUP_SOCK_OPS, 125 }; 126 struct bpf_object *obj; 127 struct bpf_map *map; 128 int client_fd; 129 int prog_fd; 130 int map_fd; 131 int err; 132 133 err = bpf_prog_load_xattr(&attr, &obj, &prog_fd); 134 if (err) { 135 log_err("Failed to load BPF object"); 136 return -1; 137 } 138 139 map = bpf_map__next(NULL, obj); 140 map_fd = bpf_map__fd(map); 141 142 err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0); 143 if (err) { 144 log_err("Failed to attach BPF program"); 145 goto close_bpf_object; 146 } 147 148 client_fd = connect_to_server(server_fd); 149 if (client_fd < 0) { 150 err = -1; 151 goto close_bpf_object; 152 } 153 154 err += verify_sk(map_fd, client_fd, "syn-ack", 155 /*invoked=*/1, 156 /*dsack_dups=*/0, 157 /*delivered=*/1, 158 /*delivered_ce=*/0, 159 /*icsk_retransmits=*/0); 160 161 send_byte(client_fd); 162 if (wait_for_ack(client_fd, 100) < 0) { 163 err = -1; 164 goto close_client_fd; 165 } 166 167 168 err += verify_sk(map_fd, client_fd, "first payload byte", 169 /*invoked=*/2, 170 /*dsack_dups=*/0, 171 /*delivered=*/2, 172 /*delivered_ce=*/0, 173 /*icsk_retransmits=*/0); 174 175 close_client_fd: 176 close(client_fd); 177 178 close_bpf_object: 179 bpf_object__close(obj); 180 return err; 181 } 182 183 static int start_server(void) 184 { 185 struct sockaddr_in addr = { 186 .sin_family = AF_INET, 187 .sin_addr.s_addr = htonl(INADDR_LOOPBACK), 188 }; 189 int fd; 190 191 fd = socket(AF_INET, SOCK_STREAM, 0); 192 if (fd < 0) { 193 log_err("Failed to create server socket"); 194 return -1; 195 } 196 197 if (bind(fd, (const struct sockaddr *)&addr, sizeof(addr)) < 0) { 198 log_err("Failed to bind socket"); 199 close(fd); 200 return -1; 201 } 202 203 return fd; 204 } 205 206 static pthread_mutex_t server_started_mtx = PTHREAD_MUTEX_INITIALIZER; 207 static pthread_cond_t server_started = PTHREAD_COND_INITIALIZER; 208 209 static void *server_thread(void *arg) 210 { 211 struct sockaddr_storage addr; 212 socklen_t len = sizeof(addr); 213 int fd = *(int *)arg; 214 int client_fd; 215 int err; 216 217 err = listen(fd, 1); 218 219 pthread_mutex_lock(&server_started_mtx); 220 pthread_cond_signal(&server_started); 221 pthread_mutex_unlock(&server_started_mtx); 222 223 if (CHECK_FAIL(err < 0)) { 224 perror("Failed to listed on socket"); 225 return NULL; 226 } 227 228 client_fd = accept(fd, (struct sockaddr *)&addr, &len); 229 if (CHECK_FAIL(client_fd < 0)) { 230 perror("Failed to accept client"); 231 return NULL; 232 } 233 234 /* Wait for the next connection (that never arrives) 235 * to keep this thread alive to prevent calling 236 * close() on client_fd. 237 */ 238 if (CHECK_FAIL(accept(fd, (struct sockaddr *)&addr, &len) >= 0)) { 239 perror("Unexpected success in second accept"); 240 return NULL; 241 } 242 243 close(client_fd); 244 245 return NULL; 246 } 247 248 void test_tcp_rtt(void) 249 { 250 int server_fd, cgroup_fd; 251 pthread_t tid; 252 253 cgroup_fd = test__join_cgroup("/tcp_rtt"); 254 if (CHECK_FAIL(cgroup_fd < 0)) 255 return; 256 257 server_fd = start_server(); 258 if (CHECK_FAIL(server_fd < 0)) 259 goto close_cgroup_fd; 260 261 if (CHECK_FAIL(pthread_create(&tid, NULL, server_thread, 262 (void *)&server_fd))) 263 goto close_cgroup_fd; 264 265 pthread_mutex_lock(&server_started_mtx); 266 pthread_cond_wait(&server_started, &server_started_mtx); 267 pthread_mutex_unlock(&server_started_mtx); 268 269 CHECK_FAIL(run_test(cgroup_fd, server_fd)); 270 close(server_fd); 271 close_cgroup_fd: 272 close(cgroup_fd); 273 } 274