1 // SPDX-License-Identifier: GPL-2.0 2 // Copyright (c) 2018 Facebook 3 // Copyright (c) 2019 Cloudflare 4 5 #include <limits.h> 6 #include <string.h> 7 #include <stdlib.h> 8 #include <unistd.h> 9 10 #include <arpa/inet.h> 11 #include <netinet/in.h> 12 #include <sys/types.h> 13 #include <sys/socket.h> 14 15 #include <bpf/bpf.h> 16 #include <bpf/libbpf.h> 17 18 #include "cgroup_helpers.h" 19 20 static int start_server(const struct sockaddr *addr, socklen_t len, bool dual) 21 { 22 int mode = !dual; 23 int fd; 24 25 fd = socket(addr->sa_family, SOCK_STREAM, 0); 26 if (fd == -1) { 27 log_err("Failed to create server socket"); 28 goto out; 29 } 30 31 if (addr->sa_family == AF_INET6) { 32 if (setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, (char *)&mode, 33 sizeof(mode)) == -1) { 34 log_err("Failed to set the dual-stack mode"); 35 goto close_out; 36 } 37 } 38 39 if (bind(fd, addr, len) == -1) { 40 log_err("Failed to bind server socket"); 41 goto close_out; 42 } 43 44 if (listen(fd, 128) == -1) { 45 log_err("Failed to listen on server socket"); 46 goto close_out; 47 } 48 49 goto out; 50 51 close_out: 52 close(fd); 53 fd = -1; 54 out: 55 return fd; 56 } 57 58 static int connect_to_server(const struct sockaddr *addr, socklen_t len) 59 { 60 int fd = -1; 61 62 fd = socket(addr->sa_family, SOCK_STREAM, 0); 63 if (fd == -1) { 64 log_err("Failed to create client socket"); 65 goto out; 66 } 67 68 if (connect(fd, (const struct sockaddr *)addr, len) == -1) { 69 log_err("Fail to connect to server"); 70 goto close_out; 71 } 72 73 goto out; 74 75 close_out: 76 close(fd); 77 fd = -1; 78 out: 79 return fd; 80 } 81 82 static int get_map_fd_by_prog_id(int prog_id, bool *xdp) 83 { 84 struct bpf_prog_info info = {}; 85 __u32 info_len = sizeof(info); 86 __u32 map_ids[1]; 87 int prog_fd = -1; 88 int map_fd = -1; 89 90 prog_fd = bpf_prog_get_fd_by_id(prog_id); 91 if (prog_fd < 0) { 92 log_err("Failed to get fd by prog id %d", prog_id); 93 goto err; 94 } 95 96 info.nr_map_ids = 1; 97 info.map_ids = (__u64)(unsigned long)map_ids; 98 99 if (bpf_obj_get_info_by_fd(prog_fd, &info, &info_len)) { 100 log_err("Failed to get info by prog fd %d", prog_fd); 101 goto err; 102 } 103 104 if (!info.nr_map_ids) { 105 log_err("No maps found for prog fd %d", prog_fd); 106 goto err; 107 } 108 109 *xdp = info.type == BPF_PROG_TYPE_XDP; 110 111 map_fd = bpf_map_get_fd_by_id(map_ids[0]); 112 if (map_fd < 0) 113 log_err("Failed to get fd by map id %d", map_ids[0]); 114 err: 115 if (prog_fd >= 0) 116 close(prog_fd); 117 return map_fd; 118 } 119 120 static int run_test(int server_fd, int results_fd, bool xdp, 121 const struct sockaddr *addr, socklen_t len) 122 { 123 int client = -1, srv_client = -1; 124 int ret = 0; 125 __u32 key = 0; 126 __u32 key_gen = 1; 127 __u32 key_mss = 2; 128 __u32 value = 0; 129 __u32 value_gen = 0; 130 __u32 value_mss = 0; 131 132 if (bpf_map_update_elem(results_fd, &key, &value, 0) < 0) { 133 log_err("Can't clear results"); 134 goto err; 135 } 136 137 if (bpf_map_update_elem(results_fd, &key_gen, &value_gen, 0) < 0) { 138 log_err("Can't clear results"); 139 goto err; 140 } 141 142 if (bpf_map_update_elem(results_fd, &key_mss, &value_mss, 0) < 0) { 143 log_err("Can't clear results"); 144 goto err; 145 } 146 147 client = connect_to_server(addr, len); 148 if (client == -1) 149 goto err; 150 151 srv_client = accept(server_fd, NULL, 0); 152 if (srv_client == -1) { 153 log_err("Can't accept connection"); 154 goto err; 155 } 156 157 if (bpf_map_lookup_elem(results_fd, &key, &value) < 0) { 158 log_err("Can't lookup result"); 159 goto err; 160 } 161 162 if (value == 0) { 163 log_err("Didn't match syncookie: %u", value); 164 goto err; 165 } 166 167 if (bpf_map_lookup_elem(results_fd, &key_gen, &value_gen) < 0) { 168 log_err("Can't lookup result"); 169 goto err; 170 } 171 172 if (xdp && value_gen == 0) { 173 // SYN packets do not get passed through generic XDP, skip the 174 // rest of the test. 175 printf("Skipping XDP cookie check\n"); 176 goto out; 177 } 178 179 if (bpf_map_lookup_elem(results_fd, &key_mss, &value_mss) < 0) { 180 log_err("Can't lookup result"); 181 goto err; 182 } 183 184 if (value != value_gen) { 185 log_err("BPF generated cookie does not match kernel one"); 186 goto err; 187 } 188 189 if (value_mss < 536 || value_mss > USHRT_MAX) { 190 log_err("Unexpected MSS retrieved"); 191 goto err; 192 } 193 194 goto out; 195 196 err: 197 ret = 1; 198 out: 199 close(client); 200 close(srv_client); 201 return ret; 202 } 203 204 static bool get_port(int server_fd, in_port_t *port) 205 { 206 struct sockaddr_in addr; 207 socklen_t len = sizeof(addr); 208 209 if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) { 210 log_err("Failed to get server addr"); 211 return false; 212 } 213 214 /* sin_port and sin6_port are located at the same offset. */ 215 *port = addr.sin_port; 216 return true; 217 } 218 219 int main(int argc, char **argv) 220 { 221 struct sockaddr_in addr4; 222 struct sockaddr_in6 addr6; 223 struct sockaddr_in addr4dual; 224 struct sockaddr_in6 addr6dual; 225 int server = -1; 226 int server_v6 = -1; 227 int server_dual = -1; 228 int results = -1; 229 int err = 0; 230 bool xdp; 231 232 if (argc < 2) { 233 fprintf(stderr, "Usage: %s prog_id\n", argv[0]); 234 exit(1); 235 } 236 237 /* Use libbpf 1.0 API mode */ 238 libbpf_set_strict_mode(LIBBPF_STRICT_ALL); 239 240 results = get_map_fd_by_prog_id(atoi(argv[1]), &xdp); 241 if (results < 0) { 242 log_err("Can't get map"); 243 goto err; 244 } 245 246 memset(&addr4, 0, sizeof(addr4)); 247 addr4.sin_family = AF_INET; 248 addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK); 249 addr4.sin_port = 0; 250 memcpy(&addr4dual, &addr4, sizeof(addr4dual)); 251 252 memset(&addr6, 0, sizeof(addr6)); 253 addr6.sin6_family = AF_INET6; 254 addr6.sin6_addr = in6addr_loopback; 255 addr6.sin6_port = 0; 256 257 memset(&addr6dual, 0, sizeof(addr6dual)); 258 addr6dual.sin6_family = AF_INET6; 259 addr6dual.sin6_addr = in6addr_any; 260 addr6dual.sin6_port = 0; 261 262 server = start_server((const struct sockaddr *)&addr4, sizeof(addr4), 263 false); 264 if (server == -1 || !get_port(server, &addr4.sin_port)) 265 goto err; 266 267 server_v6 = start_server((const struct sockaddr *)&addr6, 268 sizeof(addr6), false); 269 if (server_v6 == -1 || !get_port(server_v6, &addr6.sin6_port)) 270 goto err; 271 272 server_dual = start_server((const struct sockaddr *)&addr6dual, 273 sizeof(addr6dual), true); 274 if (server_dual == -1 || !get_port(server_dual, &addr4dual.sin_port)) 275 goto err; 276 277 if (run_test(server, results, xdp, 278 (const struct sockaddr *)&addr4, sizeof(addr4))) 279 goto err; 280 281 if (run_test(server_v6, results, xdp, 282 (const struct sockaddr *)&addr6, sizeof(addr6))) 283 goto err; 284 285 if (run_test(server_dual, results, xdp, 286 (const struct sockaddr *)&addr4dual, sizeof(addr4dual))) 287 goto err; 288 289 printf("ok\n"); 290 goto out; 291 err: 292 err = 1; 293 out: 294 close(server); 295 close(server_v6); 296 close(server_dual); 297 close(results); 298 return err; 299 } 300