1 // SPDX-License-Identifier: GPL-2.0 2 /* Copyright (c) 2019 Facebook */ 3 4 #include <linux/err.h> 5 #include <test_progs.h> 6 #include "bpf_dctcp.skel.h" 7 #include "bpf_cubic.skel.h" 8 9 #define min(a, b) ((a) < (b) ? (a) : (b)) 10 11 static const unsigned int total_bytes = 10 * 1024 * 1024; 12 static const struct timeval timeo_sec = { .tv_sec = 10 }; 13 static const size_t timeo_optlen = sizeof(timeo_sec); 14 static int stop, duration; 15 16 static int settimeo(int fd) 17 { 18 int err; 19 20 err = setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &timeo_sec, 21 timeo_optlen); 22 if (CHECK(err == -1, "setsockopt(fd, SO_RCVTIMEO)", "errno:%d\n", 23 errno)) 24 return -1; 25 26 err = setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &timeo_sec, 27 timeo_optlen); 28 if (CHECK(err == -1, "setsockopt(fd, SO_SNDTIMEO)", "errno:%d\n", 29 errno)) 30 return -1; 31 32 return 0; 33 } 34 35 static int settcpca(int fd, const char *tcp_ca) 36 { 37 int err; 38 39 err = setsockopt(fd, IPPROTO_TCP, TCP_CONGESTION, tcp_ca, strlen(tcp_ca)); 40 if (CHECK(err == -1, "setsockopt(fd, TCP_CONGESTION)", "errno:%d\n", 41 errno)) 42 return -1; 43 44 return 0; 45 } 46 47 static void *server(void *arg) 48 { 49 int lfd = (int)(long)arg, err = 0, fd; 50 ssize_t nr_sent = 0, bytes = 0; 51 char batch[1500]; 52 53 fd = accept(lfd, NULL, NULL); 54 while (fd == -1) { 55 if (errno == EINTR) 56 continue; 57 err = -errno; 58 goto done; 59 } 60 61 if (settimeo(fd)) { 62 err = -errno; 63 goto done; 64 } 65 66 while (bytes < total_bytes && !READ_ONCE(stop)) { 67 nr_sent = send(fd, &batch, 68 min(total_bytes - bytes, sizeof(batch)), 0); 69 if (nr_sent == -1 && errno == EINTR) 70 continue; 71 if (nr_sent == -1) { 72 err = -errno; 73 break; 74 } 75 bytes += nr_sent; 76 } 77 78 CHECK(bytes != total_bytes, "send", "%zd != %u nr_sent:%zd errno:%d\n", 79 bytes, total_bytes, nr_sent, errno); 80 81 done: 82 if (fd != -1) 83 close(fd); 84 if (err) { 85 WRITE_ONCE(stop, 1); 86 return ERR_PTR(err); 87 } 88 return NULL; 89 } 90 91 static void do_test(const char *tcp_ca) 92 { 93 struct sockaddr_in6 sa6 = {}; 94 ssize_t nr_recv = 0, bytes = 0; 95 int lfd = -1, fd = -1; 96 pthread_t srv_thread; 97 socklen_t addrlen = sizeof(sa6); 98 void *thread_ret; 99 char batch[1500]; 100 int err; 101 102 WRITE_ONCE(stop, 0); 103 104 lfd = socket(AF_INET6, SOCK_STREAM, 0); 105 if (CHECK(lfd == -1, "socket", "errno:%d\n", errno)) 106 return; 107 fd = socket(AF_INET6, SOCK_STREAM, 0); 108 if (CHECK(fd == -1, "socket", "errno:%d\n", errno)) { 109 close(lfd); 110 return; 111 } 112 113 if (settcpca(lfd, tcp_ca) || settcpca(fd, tcp_ca) || 114 settimeo(lfd) || settimeo(fd)) 115 goto done; 116 117 /* bind, listen and start server thread to accept */ 118 sa6.sin6_family = AF_INET6; 119 sa6.sin6_addr = in6addr_loopback; 120 err = bind(lfd, (struct sockaddr *)&sa6, addrlen); 121 if (CHECK(err == -1, "bind", "errno:%d\n", errno)) 122 goto done; 123 err = getsockname(lfd, (struct sockaddr *)&sa6, &addrlen); 124 if (CHECK(err == -1, "getsockname", "errno:%d\n", errno)) 125 goto done; 126 err = listen(lfd, 1); 127 if (CHECK(err == -1, "listen", "errno:%d\n", errno)) 128 goto done; 129 err = pthread_create(&srv_thread, NULL, server, (void *)(long)lfd); 130 if (CHECK(err != 0, "pthread_create", "err:%d\n", err)) 131 goto done; 132 133 /* connect to server */ 134 err = connect(fd, (struct sockaddr *)&sa6, addrlen); 135 if (CHECK(err == -1, "connect", "errno:%d\n", errno)) 136 goto wait_thread; 137 138 /* recv total_bytes */ 139 while (bytes < total_bytes && !READ_ONCE(stop)) { 140 nr_recv = recv(fd, &batch, 141 min(total_bytes - bytes, sizeof(batch)), 0); 142 if (nr_recv == -1 && errno == EINTR) 143 continue; 144 if (nr_recv == -1) 145 break; 146 bytes += nr_recv; 147 } 148 149 CHECK(bytes != total_bytes, "recv", "%zd != %u nr_recv:%zd errno:%d\n", 150 bytes, total_bytes, nr_recv, errno); 151 152 wait_thread: 153 WRITE_ONCE(stop, 1); 154 pthread_join(srv_thread, &thread_ret); 155 CHECK(IS_ERR(thread_ret), "pthread_join", "thread_ret:%ld", 156 PTR_ERR(thread_ret)); 157 done: 158 close(lfd); 159 close(fd); 160 } 161 162 static void test_cubic(void) 163 { 164 struct bpf_cubic *cubic_skel; 165 struct bpf_link *link; 166 167 cubic_skel = bpf_cubic__open_and_load(); 168 if (CHECK(!cubic_skel, "bpf_cubic__open_and_load", "failed\n")) 169 return; 170 171 link = bpf_map__attach_struct_ops(cubic_skel->maps.cubic); 172 if (CHECK(IS_ERR(link), "bpf_map__attach_struct_ops", "err:%ld\n", 173 PTR_ERR(link))) { 174 bpf_cubic__destroy(cubic_skel); 175 return; 176 } 177 178 do_test("bpf_cubic"); 179 180 bpf_link__destroy(link); 181 bpf_cubic__destroy(cubic_skel); 182 } 183 184 static void test_dctcp(void) 185 { 186 struct bpf_dctcp *dctcp_skel; 187 struct bpf_link *link; 188 189 dctcp_skel = bpf_dctcp__open_and_load(); 190 if (CHECK(!dctcp_skel, "bpf_dctcp__open_and_load", "failed\n")) 191 return; 192 193 link = bpf_map__attach_struct_ops(dctcp_skel->maps.dctcp); 194 if (CHECK(IS_ERR(link), "bpf_map__attach_struct_ops", "err:%ld\n", 195 PTR_ERR(link))) { 196 bpf_dctcp__destroy(dctcp_skel); 197 return; 198 } 199 200 do_test("bpf_dctcp"); 201 202 bpf_link__destroy(link); 203 bpf_dctcp__destroy(dctcp_skel); 204 } 205 206 void test_bpf_tcp_ca(void) 207 { 208 if (test__start_subtest("dctcp")) 209 test_dctcp(); 210 if (test__start_subtest("cubic")) 211 test_cubic(); 212 } 213