1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2019 Facebook */ 3 4 #include <linux/err.h> 5 #include <netinet/tcp.h> 6 #include <test_progs.h> 7 #include "bpf_dctcp.skel.h" 8 #include "bpf_cubic.skel.h" 9 10 #define min(a, b) ((a) < (b) ? (a) : (b)) 11 12 static const unsigned int total_bytes = 10 * 1024 * 1024; 13 static const struct timeval timeo_sec = { .tv_sec = 10 }; 14 static const size_t timeo_optlen = sizeof(timeo_sec); 15 static int expected_stg = 0xeB9F; 16 static int stop, duration; 17 18 static int settimeo(int fd) 19 { 20 int err; 21 22 err = setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeo_sec, 23 timeo_optlen); 24 if (CHECK(err == -1, "setsockopt(fd, SO_RCVTIMEO)", "errno:%d\n", 25 errno)) 26 return -1; 27 28 err = setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeo_sec, 29 timeo_optlen); 30 if (CHECK(err == -1, "setsockopt(fd, SO_SNDTIMEO)", "errno:%d\n", 31 errno)) 32 return -1; 33 34 return 0; 35 } 36 37 static int settcpca(int fd, const char *tcp_ca) 38 { 39 int err; 40 41 err = setsockopt(fd, IPPROTO_TCP, TCP_CONGESTION, tcp_ca, strlen(tcp_ca)); 42 if (CHECK(err == -1, "setsockopt(fd, TCP_CONGESTION)", "errno:%d\n", 43 errno)) 44 return -1; 45 46 return 0; 47 } 48 49 static void *server(void *arg) 50 { 51 int lfd = (int)(long)arg, err = 0, fd; 52 ssize_t nr_sent = 0, bytes = 0; 53 char batch[1500]; 54 55 fd = accept(lfd, NULL, NULL); 56 while (fd == -1) { 57 if (errno == EINTR) 58 continue; 59 err = -errno; 60 goto done; 61 } 62 63 if (settimeo(fd)) { 64 err = -errno; 65 goto done; 66 } 67 68 while (bytes < total_bytes && !READ_ONCE(stop)) { 69 nr_sent = send(fd, &batch, 70 min(total_bytes - bytes, sizeof(batch)), 0); 71 if (nr_sent == -1 && errno == EINTR) 72 continue; 73 if (nr_sent == -1) { 74 err = -errno; 75 break; 76 } 77 bytes += nr_sent; 78 } 79 80 CHECK(bytes != total_bytes, "send", "%zd != %u nr_sent:%zd errno:%d\n", 81 bytes, total_bytes, nr_sent, errno); 82 83 done: 84 if (fd != -1) 85 close(fd); 86 if (err) { 87 WRITE_ONCE(stop, 1); 88 return ERR_PTR(err); 89 } 90 return NULL; 91 } 92 93 static void do_test(const char *tcp_ca, const struct bpf_map *sk_stg_map) 94 { 95 struct sockaddr_in6 sa6 = {}; 96 ssize_t nr_recv = 0, bytes = 0; 97 int lfd = -1, fd = -1; 98 pthread_t srv_thread; 99 socklen_t addrlen = sizeof(sa6); 100 void *thread_ret; 101 char batch[1500]; 102 int err; 103 104 WRITE_ONCE(stop, 0); 105 106 lfd = socket(AF_INET6, SOCK_STREAM, 0); 107 if (CHECK(lfd == -1, "socket", "errno:%d\n", errno)) 108 return; 109 fd = socket(AF_INET6, SOCK_STREAM, 0); 110 if (CHECK(fd == -1, "socket", "errno:%d\n", errno)) { 111 close(lfd); 112 return; 113 } 114 115 if (settcpca(lfd, tcp_ca) || settcpca(fd, tcp_ca) || 116 settimeo(lfd) || settimeo(fd)) 117 goto done; 118 119 /* bind, listen and start server thread to accept */ 120 sa6.sin6_family = AF_INET6; 121 sa6.sin6_addr = in6addr_loopback; 122 err = bind(lfd, (struct sockaddr *)&sa6, addrlen); 123 if (CHECK(err == -1, "bind", "errno:%d\n", errno)) 124 goto done; 125 err = getsockname(lfd, (struct sockaddr *)&sa6, &addrlen); 126 if (CHECK(err == -1, "getsockname", "errno:%d\n", errno)) 127 goto done; 128 err = listen(lfd, 1); 129 if (CHECK(err == -1, "listen", "errno:%d\n", errno)) 130 goto done; 131 132 if (sk_stg_map) { 133 err = bpf_map_update_elem(bpf_map__fd(sk_stg_map), &fd, 134 &expected_stg, BPF_NOEXIST); 135 if (CHECK(err, "bpf_map_update_elem(sk_stg_map)", 136 "err:%d errno:%d\n", err, errno)) 137 goto done; 138 } 139 140 /* connect to server */ 141 err = connect(fd, (struct sockaddr *)&sa6, addrlen); 142 if (CHECK(err == -1, "connect", "errno:%d\n", errno)) 143 goto done; 144 145 if (sk_stg_map) { 146 int tmp_stg; 147 148 err = bpf_map_lookup_elem(bpf_map__fd(sk_stg_map), &fd, 149 &tmp_stg); 150 if (CHECK(!err || errno != ENOENT, 151 "bpf_map_lookup_elem(sk_stg_map)", 152 "err:%d errno:%d\n", err, errno)) 153 goto done; 154 } 155 156 err = pthread_create(&srv_thread, NULL, server, (void *)(long)lfd); 157 if (CHECK(err != 0, "pthread_create", "err:%d errno:%d\n", err, errno)) 158 goto done; 159 160 /* recv total_bytes */ 161 while (bytes < total_bytes && !READ_ONCE(stop)) { 162 nr_recv = recv(fd, &batch, 163 min(total_bytes - bytes, sizeof(batch)), 0); 164 if (nr_recv == -1 && errno == EINTR) 165 continue; 166 if (nr_recv == -1) 167 break; 168 bytes += nr_recv; 169 } 170 171 CHECK(bytes != total_bytes, "recv", "%zd != %u nr_recv:%zd errno:%d\n", 172 bytes, total_bytes, nr_recv, errno); 173 174 WRITE_ONCE(stop, 1); 175 pthread_join(srv_thread, &thread_ret); 176 CHECK(IS_ERR(thread_ret), "pthread_join", "thread_ret:%ld", 177 PTR_ERR(thread_ret)); 178 done: 179 close(lfd); 180 close(fd); 181 } 182 183 static void test_cubic(void) 184 { 185 struct bpf_cubic *cubic_skel; 186 struct bpf_link *link; 187 188 cubic_skel = bpf_cubic__open_and_load(); 189 if (CHECK(!cubic_skel, "bpf_cubic__open_and_load", "failed\n")) 190 return; 191 192 link = bpf_map__attach_struct_ops(cubic_skel->maps.cubic); 193 if (CHECK(IS_ERR(link), "bpf_map__attach_struct_ops", "err:%ld\n", 194 PTR_ERR(link))) { 195 bpf_cubic__destroy(cubic_skel); 196 return; 197 } 198 199 do_test("bpf_cubic", NULL); 200 201 bpf_link__destroy(link); 202 bpf_cubic__destroy(cubic_skel); 203 } 204 205 static void test_dctcp(void) 206 { 207 struct bpf_dctcp *dctcp_skel; 208 struct bpf_link *link; 209 210 dctcp_skel = bpf_dctcp__open_and_load(); 211 if (CHECK(!dctcp_skel, "bpf_dctcp__open_and_load", "failed\n")) 212 return; 213 214 link = bpf_map__attach_struct_ops(dctcp_skel->maps.dctcp); 215 if (CHECK(IS_ERR(link), "bpf_map__attach_struct_ops", "err:%ld\n", 216 PTR_ERR(link))) { 217 bpf_dctcp__destroy(dctcp_skel); 218 return; 219 } 220 221 do_test("bpf_dctcp", dctcp_skel->maps.sk_stg_map); 222 CHECK(dctcp_skel->bss->stg_result != expected_stg, 223 "Unexpected stg_result", "stg_result (%x) != expected_stg (%x)\n", 224 dctcp_skel->bss->stg_result, expected_stg); 225 226 bpf_link__destroy(link); 227 bpf_dctcp__destroy(dctcp_skel); 228 } 229 230 void test_bpf_tcp_ca(void) 231 { 232 if (test__start_subtest("dctcp")) 233 test_dctcp(); 234 if (test__start_subtest("cubic")) 235 test_cubic(); 236 } 237