1 // SPDX-License-Identifier: GPL-2.0 2 // Copyright (c) 2018 Facebook 3 // Copyright (c) 2019 Cloudflare 4 5 #include <limits.h> 6 #include <string.h> 7 #include <stdlib.h> 8 #include <unistd.h> 9 10 #include <arpa/inet.h> 11 #include <netinet/in.h> 12 #include <sys/types.h> 13 #include <sys/socket.h> 14 15 #include <bpf/bpf.h> 16 #include <bpf/libbpf.h> 17 18 #include "bpf_rlimit.h" 19 #include "cgroup_helpers.h" 20 21 static int start_server(const struct sockaddr *addr, socklen_t len) 22 { 23 int fd; 24 25 fd = socket(addr->sa_family, SOCK_STREAM, 0); 26 if (fd == -1) { 27 log_err("Failed to create server socket"); 28 goto out; 29 } 30 31 if (bind(fd, addr, len) == -1) { 32 log_err("Failed to bind server socket"); 33 goto close_out; 34 } 35 36 if (listen(fd, 128) == -1) { 37 log_err("Failed to listen on server socket"); 38 goto close_out; 39 } 40 41 goto out; 42 43 close_out: 44 close(fd); 45 fd = -1; 46 out: 47 return fd; 48 } 49 50 static int connect_to_server(int server_fd) 51 { 52 struct sockaddr_storage addr; 53 socklen_t len = sizeof(addr); 54 int fd = -1; 55 56 if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) { 57 log_err("Failed to get server addr"); 58 goto out; 59 } 60 61 fd = socket(addr.ss_family, SOCK_STREAM, 0); 62 if (fd == -1) { 63 log_err("Failed to create client socket"); 64 goto out; 65 } 66 67 if (connect(fd, (const struct sockaddr *)&addr, len) == -1) { 68 log_err("Fail to connect to server"); 69 goto close_out; 70 } 71 72 goto out; 73 74 close_out: 75 close(fd); 76 fd = -1; 77 out: 78 return fd; 79 } 80 81 static int get_map_fd_by_prog_id(int prog_id, bool *xdp) 82 { 83 struct bpf_prog_info info = {}; 84 __u32 info_len = sizeof(info); 85 __u32 map_ids[1]; 86 int prog_fd = -1; 87 int map_fd = -1; 88 89 prog_fd = bpf_prog_get_fd_by_id(prog_id); 90 if (prog_fd < 0) { 91 log_err("Failed to get fd by prog id %d", prog_id); 92 goto err; 93 } 94 95 info.nr_map_ids = 1; 96 info.map_ids = (__u64)(unsigned long)map_ids; 97 98 if (bpf_obj_get_info_by_fd(prog_fd, &info, &info_len)) { 99 log_err("Failed to get info by prog fd %d", prog_fd); 100 goto err; 101 } 102 103 if (!info.nr_map_ids) { 104 log_err("No maps found for prog fd %d", prog_fd); 105 goto err; 106 } 107 108 *xdp = info.type == BPF_PROG_TYPE_XDP; 109 110 map_fd = bpf_map_get_fd_by_id(map_ids[0]); 111 if (map_fd < 0) 112 log_err("Failed to get fd by map id %d", map_ids[0]); 113 err: 114 if (prog_fd >= 0) 115 close(prog_fd); 116 return map_fd; 117 } 118 119 static int run_test(int server_fd, int results_fd, bool xdp) 120 { 121 int client = -1, srv_client = -1; 122 int ret = 0; 123 __u32 key = 0; 124 __u32 key_gen = 1; 125 __u32 key_mss = 2; 126 __u32 value = 0; 127 __u32 value_gen = 0; 128 __u32 value_mss = 0; 129 130 if (bpf_map_update_elem(results_fd, &key, &value, 0) < 0) { 131 log_err("Can't clear results"); 132 goto err; 133 } 134 135 if (bpf_map_update_elem(results_fd, &key_gen, &value_gen, 0) < 0) { 136 log_err("Can't clear results"); 137 goto err; 138 } 139 140 if (bpf_map_update_elem(results_fd, &key_mss, &value_mss, 0) < 0) { 141 log_err("Can't clear results"); 142 goto err; 143 } 144 145 client = connect_to_server(server_fd); 146 if (client == -1) 147 goto err; 148 149 srv_client = accept(server_fd, NULL, 0); 150 if (srv_client == -1) { 151 log_err("Can't accept connection"); 152 goto err; 153 } 154 155 if (bpf_map_lookup_elem(results_fd, &key, &value) < 0) { 156 log_err("Can't lookup result"); 157 goto err; 158 } 159 160 if (value == 0) { 161 log_err("Didn't match syncookie: %u", value); 162 goto err; 163 } 164 165 if (bpf_map_lookup_elem(results_fd, &key_gen, &value_gen) < 0) { 166 log_err("Can't lookup result"); 167 goto err; 168 } 169 170 if (xdp && value_gen == 0) { 171 // SYN packets do not get passed through generic XDP, skip the 172 // rest of the test. 173 printf("Skipping XDP cookie check\n"); 174 goto out; 175 } 176 177 if (bpf_map_lookup_elem(results_fd, &key_mss, &value_mss) < 0) { 178 log_err("Can't lookup result"); 179 goto err; 180 } 181 182 if (value != value_gen) { 183 log_err("BPF generated cookie does not match kernel one"); 184 goto err; 185 } 186 187 if (value_mss < 536 || value_mss > USHRT_MAX) { 188 log_err("Unexpected MSS retrieved"); 189 goto err; 190 } 191 192 goto out; 193 194 err: 195 ret = 1; 196 out: 197 close(client); 198 close(srv_client); 199 return ret; 200 } 201 202 int main(int argc, char **argv) 203 { 204 struct sockaddr_in addr4; 205 struct sockaddr_in6 addr6; 206 int server = -1; 207 int server_v6 = -1; 208 int results = -1; 209 int err = 0; 210 bool xdp; 211 212 if (argc < 2) { 213 fprintf(stderr, "Usage: %s prog_id\n", argv[0]); 214 exit(1); 215 } 216 217 results = get_map_fd_by_prog_id(atoi(argv[1]), &xdp); 218 if (results < 0) { 219 log_err("Can't get map"); 220 goto err; 221 } 222 223 memset(&addr4, 0, sizeof(addr4)); 224 addr4.sin_family = AF_INET; 225 addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK); 226 addr4.sin_port = 0; 227 228 memset(&addr6, 0, sizeof(addr6)); 229 addr6.sin6_family = AF_INET6; 230 addr6.sin6_addr = in6addr_loopback; 231 addr6.sin6_port = 0; 232 233 server = start_server((const struct sockaddr *)&addr4, sizeof(addr4)); 234 if (server == -1) 235 goto err; 236 237 server_v6 = start_server((const struct sockaddr *)&addr6, 238 sizeof(addr6)); 239 if (server_v6 == -1) 240 goto err; 241 242 if (run_test(server, results, xdp)) 243 goto err; 244 245 if (run_test(server_v6, results, xdp)) 246 goto err; 247 248 printf("ok\n"); 249 goto out; 250 err: 251 err = 1; 252 out: 253 close(server); 254 close(server_v6); 255 close(results); 256 return err; 257 } 258