1 // SPDX-License-Identifier: GPL-2.0 2 // Copyright (c) 2018 Facebook 3 // Copyright (c) 2019 Cloudflare 4 5 #include <string.h> 6 #include <stdlib.h> 7 #include <unistd.h> 8 9 #include <arpa/inet.h> 10 #include <netinet/in.h> 11 #include <sys/types.h> 12 #include <sys/socket.h> 13 14 #include <bpf/bpf.h> 15 #include <bpf/libbpf.h> 16 17 #include "bpf_rlimit.h" 18 #include "cgroup_helpers.h" 19 20 static int start_server(const struct sockaddr *addr, socklen_t len) 21 { 22 int fd; 23 24 fd = socket(addr->sa_family, SOCK_STREAM, 0); 25 if (fd == -1) { 26 log_err("Failed to create server socket"); 27 goto out; 28 } 29 30 if (bind(fd, addr, len) == -1) { 31 log_err("Failed to bind server socket"); 32 goto close_out; 33 } 34 35 if (listen(fd, 128) == -1) { 36 log_err("Failed to listen on server socket"); 37 goto close_out; 38 } 39 40 goto out; 41 42 close_out: 43 close(fd); 44 fd = -1; 45 out: 46 return fd; 47 } 48 49 static int connect_to_server(int server_fd) 50 { 51 struct sockaddr_storage addr; 52 socklen_t len = sizeof(addr); 53 int fd = -1; 54 55 if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) { 56 log_err("Failed to get server addr"); 57 goto out; 58 } 59 60 fd = socket(addr.ss_family, SOCK_STREAM, 0); 61 if (fd == -1) { 62 log_err("Failed to create client socket"); 63 goto out; 64 } 65 66 if (connect(fd, (const struct sockaddr *)&addr, len) == -1) { 67 log_err("Fail to connect to server"); 68 goto close_out; 69 } 70 71 goto out; 72 73 close_out: 74 close(fd); 75 fd = -1; 76 out: 77 return fd; 78 } 79 80 static int get_map_fd_by_prog_id(int prog_id) 81 { 82 struct bpf_prog_info info = {}; 83 __u32 info_len = sizeof(info); 84 __u32 map_ids[1]; 85 int prog_fd = -1; 86 int map_fd = -1; 87 88 prog_fd = bpf_prog_get_fd_by_id(prog_id); 89 if (prog_fd < 0) { 90 log_err("Failed to get fd by prog id %d", prog_id); 91 goto err; 92 } 93 94 info.nr_map_ids = 1; 95 info.map_ids = (__u64)(unsigned long)map_ids; 96 97 if (bpf_obj_get_info_by_fd(prog_fd, &info, &info_len)) { 98 log_err("Failed to get info by prog fd %d", prog_fd); 99 goto err; 100 } 101 102 if (!info.nr_map_ids) { 103 log_err("No maps found for prog fd %d", prog_fd); 104 goto err; 105 } 106 107 map_fd = bpf_map_get_fd_by_id(map_ids[0]); 108 if (map_fd < 0) 109 log_err("Failed to get fd by map id %d", map_ids[0]); 110 err: 111 if (prog_fd >= 0) 112 close(prog_fd); 113 return map_fd; 114 } 115 116 static int run_test(int server_fd, int results_fd) 117 { 118 int client = -1, srv_client = -1; 119 int ret = 0; 120 __u32 key = 0; 121 __u64 value = 0; 122 123 if (bpf_map_update_elem(results_fd, &key, &value, 0) < 0) { 124 log_err("Can't clear results"); 125 goto err; 126 } 127 128 client = connect_to_server(server_fd); 129 if (client == -1) 130 goto err; 131 132 srv_client = accept(server_fd, NULL, 0); 133 if (srv_client == -1) { 134 log_err("Can't accept connection"); 135 goto err; 136 } 137 138 if (bpf_map_lookup_elem(results_fd, &key, &value) < 0) { 139 log_err("Can't lookup result"); 140 goto err; 141 } 142 143 if (value != 1) { 144 log_err("Didn't match syncookie: %llu", value); 145 goto err; 146 } 147 148 goto out; 149 150 err: 151 ret = 1; 152 out: 153 close(client); 154 close(srv_client); 155 return ret; 156 } 157 158 int main(int argc, char **argv) 159 { 160 struct sockaddr_in addr4; 161 struct sockaddr_in6 addr6; 162 int server = -1; 163 int server_v6 = -1; 164 int results = -1; 165 int err = 0; 166 167 if (argc < 2) { 168 fprintf(stderr, "Usage: %s prog_id\n", argv[0]); 169 exit(1); 170 } 171 172 results = get_map_fd_by_prog_id(atoi(argv[1])); 173 if (results < 0) { 174 log_err("Can't get map"); 175 goto err; 176 } 177 178 memset(&addr4, 0, sizeof(addr4)); 179 addr4.sin_family = AF_INET; 180 addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK); 181 addr4.sin_port = 0; 182 183 memset(&addr6, 0, sizeof(addr6)); 184 addr6.sin6_family = AF_INET6; 185 addr6.sin6_addr = in6addr_loopback; 186 addr6.sin6_port = 0; 187 188 server = start_server((const struct sockaddr *)&addr4, sizeof(addr4)); 189 if (server == -1) 190 goto err; 191 192 server_v6 = start_server((const struct sockaddr *)&addr6, 193 sizeof(addr6)); 194 if (server_v6 == -1) 195 goto err; 196 197 if (run_test(server, results)) 198 goto err; 199 200 if (run_test(server_v6, results)) 201 goto err; 202 203 printf("ok\n"); 204 goto out; 205 err: 206 err = 1; 207 out: 208 close(server); 209 close(server_v6); 210 close(results); 211 return err; 212 } 213