1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2019 Facebook */ 3 4 #include <linux/err.h> 5 #include <netinet/tcp.h> 6 #include <test_progs.h> 7 #include "network_helpers.h" 8 #include "bpf_dctcp.skel.h" 9 #include "bpf_cubic.skel.h" 10 #include "bpf_tcp_nogpl.skel.h" 11 #include "bpf_dctcp_release.skel.h" 12 #include "tcp_ca_write_sk_pacing.skel.h" 13 #include "tcp_ca_incompl_cong_ops.skel.h" 14 #include "tcp_ca_unsupp_cong_op.skel.h" 15 16 #ifndef ENOTSUPP 17 #define ENOTSUPP 524 18 #endif 19 20 static const unsigned int total_bytes = 10 * 1024 * 1024; 21 static int expected_stg = 0xeB9F; 22 static int stop, duration; 23 24 static int settcpca(int fd, const char *tcp_ca) 25 { 26 int err; 27 28 err = setsockopt(fd, IPPROTO_TCP, TCP_CONGESTION, tcp_ca, strlen(tcp_ca)); 29 if (CHECK(err == -1, "setsockopt(fd, TCP_CONGESTION)", "errno:%d\n", 30 errno)) 31 return -1; 32 33 return 0; 34 } 35 36 static void *server(void *arg) 37 { 38 int lfd = (int)(long)arg, err = 0, fd; 39 ssize_t nr_sent = 0, bytes = 0; 40 char batch[1500]; 41 42 fd = accept(lfd, NULL, NULL); 43 while (fd == -1) { 44 if (errno == EINTR) 45 continue; 46 err = -errno; 47 goto done; 48 } 49 50 if (settimeo(fd, 0)) { 51 err = -errno; 52 goto done; 53 } 54 55 while (bytes < total_bytes && !READ_ONCE(stop)) { 56 nr_sent = send(fd, &batch, 57 MIN(total_bytes - bytes, sizeof(batch)), 0); 58 if (nr_sent == -1 && errno == EINTR) 59 continue; 60 if (nr_sent == -1) { 61 err = -errno; 62 break; 63 } 64 bytes += nr_sent; 65 } 66 67 CHECK(bytes != total_bytes, "send", "%zd != %u nr_sent:%zd errno:%d\n", 68 bytes, total_bytes, nr_sent, errno); 69 70 done: 71 if (fd >= 0) 72 close(fd); 73 if (err) { 74 WRITE_ONCE(stop, 1); 75 return ERR_PTR(err); 76 } 77 return NULL; 78 } 79 80 static void do_test(const char *tcp_ca, const struct bpf_map *sk_stg_map) 81 { 82 struct sockaddr_in6 sa6 = {}; 83 ssize_t nr_recv = 0, bytes = 0; 84 int lfd = -1, fd = -1; 85 pthread_t srv_thread; 86 socklen_t addrlen = sizeof(sa6); 87 void *thread_ret; 88 char batch[1500]; 89 int err; 90 91 WRITE_ONCE(stop, 0); 92 93 lfd = socket(AF_INET6, SOCK_STREAM, 0); 94 if (CHECK(lfd == -1, "socket", "errno:%d\n", errno)) 95 return; 96 fd = socket(AF_INET6, SOCK_STREAM, 0); 97 if (CHECK(fd == -1, "socket", "errno:%d\n", errno)) { 98 close(lfd); 99 return; 100 } 101 102 if (settcpca(lfd, tcp_ca) || settcpca(fd, tcp_ca) || 103 settimeo(lfd, 0) || settimeo(fd, 0)) 104 goto done; 105 106 /* bind, listen and start server thread to accept */ 107 sa6.sin6_family = AF_INET6; 108 sa6.sin6_addr = in6addr_loopback; 109 err = bind(lfd, (struct sockaddr *)&sa6, addrlen); 110 if (CHECK(err == -1, "bind", "errno:%d\n", errno)) 111 goto done; 112 err = getsockname(lfd, (struct sockaddr *)&sa6, &addrlen); 113 if (CHECK(err == -1, "getsockname", "errno:%d\n", errno)) 114 goto done; 115 err = listen(lfd, 1); 116 if (CHECK(err == -1, "listen", "errno:%d\n", errno)) 117 goto done; 118 119 if (sk_stg_map) { 120 err = bpf_map_update_elem(bpf_map__fd(sk_stg_map), &fd, 121 &expected_stg, BPF_NOEXIST); 122 if (CHECK(err, "bpf_map_update_elem(sk_stg_map)", 123 "err:%d errno:%d\n", err, errno)) 124 goto done; 125 } 126 127 /* connect to server */ 128 err = connect(fd, (struct sockaddr *)&sa6, addrlen); 129 if (CHECK(err == -1, "connect", "errno:%d\n", errno)) 130 goto done; 131 132 if (sk_stg_map) { 133 int tmp_stg; 134 135 err = bpf_map_lookup_elem(bpf_map__fd(sk_stg_map), &fd, 136 &tmp_stg); 137 if (CHECK(!err || errno != ENOENT, 138 "bpf_map_lookup_elem(sk_stg_map)", 139 "err:%d errno:%d\n", err, errno)) 140 goto done; 141 } 142 143 err = pthread_create(&srv_thread, NULL, server, (void *)(long)lfd); 144 if (CHECK(err != 0, "pthread_create", "err:%d errno:%d\n", err, errno)) 145 goto done; 146 147 /* recv total_bytes */ 148 while (bytes < total_bytes && !READ_ONCE(stop)) { 149 nr_recv = recv(fd, &batch, 150 MIN(total_bytes - bytes, sizeof(batch)), 0); 151 if (nr_recv == -1 && errno == EINTR) 152 continue; 153 if (nr_recv == -1) 154 break; 155 bytes += nr_recv; 156 } 157 158 CHECK(bytes != total_bytes, "recv", "%zd != %u nr_recv:%zd errno:%d\n", 159 bytes, total_bytes, nr_recv, errno); 160 161 WRITE_ONCE(stop, 1); 162 pthread_join(srv_thread, &thread_ret); 163 CHECK(IS_ERR(thread_ret), "pthread_join", "thread_ret:%ld", 164 PTR_ERR(thread_ret)); 165 done: 166 close(lfd); 167 close(fd); 168 } 169 170 static void test_cubic(void) 171 { 172 struct bpf_cubic *cubic_skel; 173 struct bpf_link *link; 174 175 cubic_skel = bpf_cubic__open_and_load(); 176 if (CHECK(!cubic_skel, "bpf_cubic__open_and_load", "failed\n")) 177 return; 178 179 link = bpf_map__attach_struct_ops(cubic_skel->maps.cubic); 180 if (!ASSERT_OK_PTR(link, "bpf_map__attach_struct_ops")) { 181 bpf_cubic__destroy(cubic_skel); 182 return; 183 } 184 185 do_test("bpf_cubic", NULL); 186 187 bpf_link__destroy(link); 188 bpf_cubic__destroy(cubic_skel); 189 } 190 191 static void test_dctcp(void) 192 { 193 struct bpf_dctcp *dctcp_skel; 194 struct bpf_link *link; 195 196 dctcp_skel = bpf_dctcp__open_and_load(); 197 if (CHECK(!dctcp_skel, "bpf_dctcp__open_and_load", "failed\n")) 198 return; 199 200 link = bpf_map__attach_struct_ops(dctcp_skel->maps.dctcp); 201 if (!ASSERT_OK_PTR(link, "bpf_map__attach_struct_ops")) { 202 bpf_dctcp__destroy(dctcp_skel); 203 return; 204 } 205 206 do_test("bpf_dctcp", dctcp_skel->maps.sk_stg_map); 207 CHECK(dctcp_skel->bss->stg_result != expected_stg, 208 "Unexpected stg_result", "stg_result (%x) != expected_stg (%x)\n", 209 dctcp_skel->bss->stg_result, expected_stg); 210 211 bpf_link__destroy(link); 212 bpf_dctcp__destroy(dctcp_skel); 213 } 214 215 static char *err_str; 216 static bool found; 217 218 static int libbpf_debug_print(enum libbpf_print_level level, 219 const char *format, va_list args) 220 { 221 const char *prog_name, *log_buf; 222 223 if (level != LIBBPF_WARN || 224 !strstr(format, "-- BEGIN PROG LOAD LOG --")) { 225 vprintf(format, args); 226 return 0; 227 } 228 229 prog_name = va_arg(args, char *); 230 log_buf = va_arg(args, char *); 231 if (!log_buf) 232 goto out; 233 if (err_str && strstr(log_buf, err_str) != NULL) 234 found = true; 235 out: 236 printf(format, prog_name, log_buf); 237 return 0; 238 } 239 240 static void test_invalid_license(void) 241 { 242 libbpf_print_fn_t old_print_fn; 243 struct bpf_tcp_nogpl *skel; 244 245 err_str = "struct ops programs must have a GPL compatible license"; 246 found = false; 247 old_print_fn = libbpf_set_print(libbpf_debug_print); 248 249 skel = bpf_tcp_nogpl__open_and_load(); 250 ASSERT_NULL(skel, "bpf_tcp_nogpl"); 251 ASSERT_EQ(found, true, "expected_err_msg"); 252 253 bpf_tcp_nogpl__destroy(skel); 254 libbpf_set_print(old_print_fn); 255 } 256 257 static void test_dctcp_fallback(void) 258 { 259 int err, lfd = -1, cli_fd = -1, srv_fd = -1; 260 struct network_helper_opts opts = { 261 .cc = "cubic", 262 }; 263 struct bpf_dctcp *dctcp_skel; 264 struct bpf_link *link = NULL; 265 char srv_cc[16]; 266 socklen_t cc_len = sizeof(srv_cc); 267 268 dctcp_skel = bpf_dctcp__open(); 269 if (!ASSERT_OK_PTR(dctcp_skel, "dctcp_skel")) 270 return; 271 strcpy(dctcp_skel->rodata->fallback, "cubic"); 272 if (!ASSERT_OK(bpf_dctcp__load(dctcp_skel), "bpf_dctcp__load")) 273 goto done; 274 275 link = bpf_map__attach_struct_ops(dctcp_skel->maps.dctcp); 276 if (!ASSERT_OK_PTR(link, "dctcp link")) 277 goto done; 278 279 lfd = start_server(AF_INET6, SOCK_STREAM, "::1", 0, 0); 280 if (!ASSERT_GE(lfd, 0, "lfd") || 281 !ASSERT_OK(settcpca(lfd, "bpf_dctcp"), "lfd=>bpf_dctcp")) 282 goto done; 283 284 cli_fd = connect_to_fd_opts(lfd, &opts); 285 if (!ASSERT_GE(cli_fd, 0, "cli_fd")) 286 goto done; 287 288 srv_fd = accept(lfd, NULL, 0); 289 if (!ASSERT_GE(srv_fd, 0, "srv_fd")) 290 goto done; 291 ASSERT_STREQ(dctcp_skel->bss->cc_res, "cubic", "cc_res"); 292 ASSERT_EQ(dctcp_skel->bss->tcp_cdg_res, -ENOTSUPP, "tcp_cdg_res"); 293 294 err = getsockopt(srv_fd, SOL_TCP, TCP_CONGESTION, srv_cc, &cc_len); 295 if (!ASSERT_OK(err, "getsockopt(srv_fd, TCP_CONGESTION)")) 296 goto done; 297 ASSERT_STREQ(srv_cc, "cubic", "srv_fd cc"); 298 299 done: 300 bpf_link__destroy(link); 301 bpf_dctcp__destroy(dctcp_skel); 302 if (lfd != -1) 303 close(lfd); 304 if (srv_fd != -1) 305 close(srv_fd); 306 if (cli_fd != -1) 307 close(cli_fd); 308 } 309 310 static void test_rel_setsockopt(void) 311 { 312 struct bpf_dctcp_release *rel_skel; 313 libbpf_print_fn_t old_print_fn; 314 315 err_str = "unknown func bpf_setsockopt"; 316 found = false; 317 318 old_print_fn = libbpf_set_print(libbpf_debug_print); 319 rel_skel = bpf_dctcp_release__open_and_load(); 320 libbpf_set_print(old_print_fn); 321 322 ASSERT_ERR_PTR(rel_skel, "rel_skel"); 323 ASSERT_TRUE(found, "expected_err_msg"); 324 325 bpf_dctcp_release__destroy(rel_skel); 326 } 327 328 static void test_write_sk_pacing(void) 329 { 330 struct tcp_ca_write_sk_pacing *skel; 331 struct bpf_link *link; 332 333 skel = tcp_ca_write_sk_pacing__open_and_load(); 334 if (!ASSERT_OK_PTR(skel, "open_and_load")) 335 return; 336 337 link = bpf_map__attach_struct_ops(skel->maps.write_sk_pacing); 338 ASSERT_OK_PTR(link, "attach_struct_ops"); 339 340 bpf_link__destroy(link); 341 tcp_ca_write_sk_pacing__destroy(skel); 342 } 343 344 static void test_incompl_cong_ops(void) 345 { 346 struct tcp_ca_incompl_cong_ops *skel; 347 struct bpf_link *link; 348 349 skel = tcp_ca_incompl_cong_ops__open_and_load(); 350 if (!ASSERT_OK_PTR(skel, "open_and_load")) 351 return; 352 353 /* That cong_avoid() and cong_control() are missing is only reported at 354 * this point: 355 */ 356 link = bpf_map__attach_struct_ops(skel->maps.incompl_cong_ops); 357 ASSERT_ERR_PTR(link, "attach_struct_ops"); 358 359 bpf_link__destroy(link); 360 tcp_ca_incompl_cong_ops__destroy(skel); 361 } 362 363 static void test_unsupp_cong_op(void) 364 { 365 libbpf_print_fn_t old_print_fn; 366 struct tcp_ca_unsupp_cong_op *skel; 367 368 err_str = "attach to unsupported member get_info"; 369 found = false; 370 old_print_fn = libbpf_set_print(libbpf_debug_print); 371 372 skel = tcp_ca_unsupp_cong_op__open_and_load(); 373 ASSERT_NULL(skel, "open_and_load"); 374 ASSERT_EQ(found, true, "expected_err_msg"); 375 376 tcp_ca_unsupp_cong_op__destroy(skel); 377 libbpf_set_print(old_print_fn); 378 } 379 380 void test_bpf_tcp_ca(void) 381 { 382 if (test__start_subtest("dctcp")) 383 test_dctcp(); 384 if (test__start_subtest("cubic")) 385 test_cubic(); 386 if (test__start_subtest("invalid_license")) 387 test_invalid_license(); 388 if (test__start_subtest("dctcp_fallback")) 389 test_dctcp_fallback(); 390 if (test__start_subtest("rel_setsockopt")) 391 test_rel_setsockopt(); 392 if (test__start_subtest("write_sk_pacing")) 393 test_write_sk_pacing(); 394 if (test__start_subtest("incompl_cong_ops")) 395 test_incompl_cong_ops(); 396 if (test__start_subtest("unsupp_cong_op")) 397 test_unsupp_cong_op(); 398 } 399