1 // SPDX-License-Identifier: GPL-2.0 2 #include <test_progs.h> 3 #include "cgroup_helpers.h" 4 5 struct tcp_rtt_storage { 6 __u32 invoked; 7 __u32 dsack_dups; 8 __u32 delivered; 9 __u32 delivered_ce; 10 __u32 icsk_retransmits; 11 }; 12 13 static void send_byte(int fd) 14 { 15 char b = 0x55; 16 17 if (CHECK_FAIL(write(fd, &b, sizeof(b)) != 1)) 18 perror("Failed to send single byte"); 19 } 20 21 static int wait_for_ack(int fd, int retries) 22 { 23 struct tcp_info info; 24 socklen_t optlen; 25 int i, err; 26 27 for (i = 0; i < retries; i++) { 28 optlen = sizeof(info); 29 err = getsockopt(fd, SOL_TCP, TCP_INFO, &info, &optlen); 30 if (err < 0) { 31 log_err("Failed to lookup TCP stats"); 32 return err; 33 } 34 35 if (info.tcpi_unacked == 0) 36 return 0; 37 38 usleep(10); 39 } 40 41 log_err("Did not receive ACK"); 42 return -1; 43 } 44 45 static int verify_sk(int map_fd, int client_fd, const char *msg, __u32 invoked, 46 __u32 dsack_dups, __u32 delivered, __u32 delivered_ce, 47 __u32 icsk_retransmits) 48 { 49 int err = 0; 50 struct tcp_rtt_storage val; 51 52 if (CHECK_FAIL(bpf_map_lookup_elem(map_fd, &client_fd, &val) < 0)) { 53 perror("Failed to read socket storage"); 54 return -1; 55 } 56 57 if (val.invoked != invoked) { 58 log_err("%s: unexpected bpf_tcp_sock.invoked %d != %d", 59 msg, val.invoked, invoked); 60 err++; 61 } 62 63 if (val.dsack_dups != dsack_dups) { 64 log_err("%s: unexpected bpf_tcp_sock.dsack_dups %d != %d", 65 msg, val.dsack_dups, dsack_dups); 66 err++; 67 } 68 69 if (val.delivered != delivered) { 70 log_err("%s: unexpected bpf_tcp_sock.delivered %d != %d", 71 msg, val.delivered, delivered); 72 err++; 73 } 74 75 if (val.delivered_ce != delivered_ce) { 76 log_err("%s: unexpected bpf_tcp_sock.delivered_ce %d != %d", 77 msg, val.delivered_ce, delivered_ce); 78 err++; 79 } 80 81 if (val.icsk_retransmits != icsk_retransmits) { 82 log_err("%s: unexpected bpf_tcp_sock.icsk_retransmits %d != %d", 83 msg, val.icsk_retransmits, icsk_retransmits); 84 err++; 85 } 86 87 return err; 88 } 89 90 static int connect_to_server(int server_fd) 91 { 92 struct sockaddr_storage addr; 93 socklen_t len = sizeof(addr); 94 int fd; 95 96 fd = socket(AF_INET, SOCK_STREAM, 0); 97 if (fd < 0) { 98 log_err("Failed to create client socket"); 99 return -1; 100 } 101 102 if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) { 103 log_err("Failed to get server addr"); 104 goto out; 105 } 106 107 if (connect(fd, (const struct sockaddr *)&addr, len) < 0) { 108 log_err("Fail to connect to server"); 109 goto out; 110 } 111 112 return fd; 113 114 out: 115 close(fd); 116 return -1; 117 } 118 119 static int run_test(int cgroup_fd, int server_fd) 120 { 121 struct bpf_prog_load_attr attr = { 122 .prog_type = BPF_PROG_TYPE_SOCK_OPS, 123 .file = "./tcp_rtt.o", 124 .expected_attach_type = BPF_CGROUP_SOCK_OPS, 125 }; 126 struct bpf_object *obj; 127 struct bpf_map *map; 128 int client_fd; 129 int prog_fd; 130 int map_fd; 131 int err; 132 133 err = bpf_prog_load_xattr(&attr, &obj, &prog_fd); 134 if (err) { 135 log_err("Failed to load BPF object"); 136 return -1; 137 } 138 139 map = bpf_map__next(NULL, obj); 140 map_fd = bpf_map__fd(map); 141 142 err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0); 143 if (err) { 144 log_err("Failed to attach BPF program"); 145 goto close_bpf_object; 146 } 147 148 client_fd = connect_to_server(server_fd); 149 if (client_fd < 0) { 150 err = -1; 151 goto close_bpf_object; 152 } 153 154 err += verify_sk(map_fd, client_fd, "syn-ack", 155 /*invoked=*/1, 156 /*dsack_dups=*/0, 157 /*delivered=*/1, 158 /*delivered_ce=*/0, 159 /*icsk_retransmits=*/0); 160 161 send_byte(client_fd); 162 if (wait_for_ack(client_fd, 100) < 0) { 163 err = -1; 164 goto close_client_fd; 165 } 166 167 168 err += verify_sk(map_fd, client_fd, "first payload byte", 169 /*invoked=*/2, 170 /*dsack_dups=*/0, 171 /*delivered=*/2, 172 /*delivered_ce=*/0, 173 /*icsk_retransmits=*/0); 174 175 close_client_fd: 176 close(client_fd); 177 178 close_bpf_object: 179 bpf_object__close(obj); 180 return err; 181 } 182 183 static int start_server(void) 184 { 185 struct sockaddr_in addr = { 186 .sin_family = AF_INET, 187 .sin_addr.s_addr = htonl(INADDR_LOOPBACK), 188 }; 189 int fd; 190 191 fd = socket(AF_INET, SOCK_STREAM | SOCK_NONBLOCK, 0); 192 if (fd < 0) { 193 log_err("Failed to create server socket"); 194 return -1; 195 } 196 197 if (bind(fd, (const struct sockaddr *)&addr, sizeof(addr)) < 0) { 198 log_err("Failed to bind socket"); 199 close(fd); 200 return -1; 201 } 202 203 return fd; 204 } 205 206 static pthread_mutex_t server_started_mtx = PTHREAD_MUTEX_INITIALIZER; 207 static pthread_cond_t server_started = PTHREAD_COND_INITIALIZER; 208 static volatile bool server_done = false; 209 210 static void *server_thread(void *arg) 211 { 212 struct sockaddr_storage addr; 213 socklen_t len = sizeof(addr); 214 int fd = *(int *)arg; 215 int client_fd; 216 int err; 217 218 err = listen(fd, 1); 219 220 pthread_mutex_lock(&server_started_mtx); 221 pthread_cond_signal(&server_started); 222 pthread_mutex_unlock(&server_started_mtx); 223 224 if (CHECK_FAIL(err < 0)) { 225 perror("Failed to listed on socket"); 226 return ERR_PTR(err); 227 } 228 229 while (true) { 230 client_fd = accept(fd, (struct sockaddr *)&addr, &len); 231 if (client_fd == -1 && errno == EAGAIN) { 232 usleep(50); 233 continue; 234 } 235 break; 236 } 237 if (CHECK_FAIL(client_fd < 0)) { 238 perror("Failed to accept client"); 239 return ERR_PTR(err); 240 } 241 242 while (!server_done) 243 usleep(50); 244 245 close(client_fd); 246 247 return NULL; 248 } 249 250 void test_tcp_rtt(void) 251 { 252 int server_fd, cgroup_fd; 253 pthread_t tid; 254 void *server_res; 255 256 cgroup_fd = test__join_cgroup("/tcp_rtt"); 257 if (CHECK_FAIL(cgroup_fd < 0)) 258 return; 259 260 server_fd = start_server(); 261 if (CHECK_FAIL(server_fd < 0)) 262 goto close_cgroup_fd; 263 264 if (CHECK_FAIL(pthread_create(&tid, NULL, server_thread, 265 (void *)&server_fd))) 266 goto close_server_fd; 267 268 pthread_mutex_lock(&server_started_mtx); 269 pthread_cond_wait(&server_started, &server_started_mtx); 270 pthread_mutex_unlock(&server_started_mtx); 271 272 CHECK_FAIL(run_test(cgroup_fd, server_fd)); 273 274 server_done = true; 275 CHECK_FAIL(pthread_join(tid, &server_res)); 276 CHECK_FAIL(IS_ERR(server_res)); 277 278 close_server_fd: 279 close(server_fd); 280 close_cgroup_fd: 281 close(cgroup_fd); 282 } 283