1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2018 Facebook
3 // Copyright (c) 2019 Cloudflare
4 
5 #include <string.h>
6 #include <stdlib.h>
7 #include <unistd.h>
8 
9 #include <arpa/inet.h>
10 #include <netinet/in.h>
11 #include <sys/types.h>
12 #include <sys/socket.h>
13 
14 #include <bpf/bpf.h>
15 #include <bpf/libbpf.h>
16 
17 #include "bpf_rlimit.h"
18 #include "cgroup_helpers.h"
19 
/*
 * Open a listening TCP socket bound to @addr.
 *
 * The socket family is taken from addr->sa_family and the listen backlog
 * is fixed at 128.
 *
 * Returns the listening socket fd on success, -1 on failure (the reason is
 * logged and any partially set up socket is closed).
 */
static int start_server(const struct sockaddr *addr, socklen_t len)
{
	int sock = socket(addr->sa_family, SOCK_STREAM, 0);

	if (sock == -1) {
		log_err("Failed to create server socket");
		return -1;
	}

	if (bind(sock, addr, len) == -1) {
		/* Log before close() so the reported error is bind's. */
		log_err("Failed to bind server socket");
		close(sock);
		return -1;
	}

	if (listen(sock, 128) == -1) {
		log_err("Failed to listen on server socket");
		close(sock);
		return -1;
	}

	return sock;
}
48 
/*
 * Connect a new TCP client socket to the address @server_fd is bound to.
 *
 * The server address is discovered with getsockname(), so the caller does
 * not need to know which port the server socket ended up on.
 *
 * Returns the connected client fd on success, -1 on failure (the reason is
 * logged and the client socket, if created, is closed).
 */
static int connect_to_server(int server_fd)
{
	struct sockaddr_storage srv_addr;
	socklen_t srv_len = sizeof(srv_addr);
	int sock;

	if (getsockname(server_fd, (struct sockaddr *)&srv_addr,
			&srv_len) != 0) {
		log_err("Failed to get server addr");
		return -1;
	}

	sock = socket(srv_addr.ss_family, SOCK_STREAM, 0);
	if (sock == -1) {
		log_err("Failed to create client socket");
		return -1;
	}

	if (connect(sock, (const struct sockaddr *)&srv_addr, srv_len) == -1) {
		/* Log before close() so the reported error is connect's. */
		log_err("Fail to connect to server");
		close(sock);
		return -1;
	}

	return sock;
}
79 
80 static int get_map_fd_by_prog_id(int prog_id)
81 {
82 	struct bpf_prog_info info = {};
83 	__u32 info_len = sizeof(info);
84 	__u32 map_ids[1];
85 	int prog_fd = -1;
86 	int map_fd = -1;
87 
88 	prog_fd = bpf_prog_get_fd_by_id(prog_id);
89 	if (prog_fd < 0) {
90 		log_err("Failed to get fd by prog id %d", prog_id);
91 		goto err;
92 	}
93 
94 	info.nr_map_ids = 1;
95 	info.map_ids = (__u64)(unsigned long)map_ids;
96 
97 	if (bpf_obj_get_info_by_fd(prog_fd, &info, &info_len)) {
98 		log_err("Failed to get info by prog fd %d", prog_fd);
99 		goto err;
100 	}
101 
102 	if (!info.nr_map_ids) {
103 		log_err("No maps found for prog fd %d", prog_fd);
104 		goto err;
105 	}
106 
107 	map_fd = bpf_map_get_fd_by_id(map_ids[0]);
108 	if (map_fd < 0)
109 		log_err("Failed to get fd by map id %d", map_ids[0]);
110 err:
111 	if (prog_fd >= 0)
112 		close(prog_fd);
113 	return map_fd;
114 }
115 
116 static int run_test(int server_fd, int results_fd)
117 {
118 	int client = -1, srv_client = -1;
119 	int ret = 0;
120 	__u32 key = 0;
121 	__u64 value = 0;
122 
123 	if (bpf_map_update_elem(results_fd, &key, &value, 0) < 0) {
124 		log_err("Can't clear results");
125 		goto err;
126 	}
127 
128 	client = connect_to_server(server_fd);
129 	if (client == -1)
130 		goto err;
131 
132 	srv_client = accept(server_fd, NULL, 0);
133 	if (srv_client == -1) {
134 		log_err("Can't accept connection");
135 		goto err;
136 	}
137 
138 	if (bpf_map_lookup_elem(results_fd, &key, &value) < 0) {
139 		log_err("Can't lookup result");
140 		goto err;
141 	}
142 
143 	if (value != 1) {
144 		log_err("Didn't match syncookie: %llu", value);
145 		goto err;
146 	}
147 
148 	goto out;
149 
150 err:
151 	ret = 1;
152 out:
153 	close(client);
154 	close(srv_client);
155 	return ret;
156 }
157 
/*
 * Entry point: "<prog> prog_id".
 *
 * Looks up the results map through the BPF program with the given id,
 * starts one IPv4 and one IPv6 loopback server on kernel-chosen ports, and
 * runs run_test() against each in turn. Prints "ok" when both pass.
 *
 * Exit status is 0 on success and 1 on a usage error, an unparsable
 * prog_id, or any test failure.
 */
int main(int argc, char **argv)
{
	struct sockaddr_in addr4;
	struct sockaddr_in6 addr6;
	int server = -1;
	int server_v6 = -1;
	int results = -1;
	int err = 0;
	char *end;
	long prog_id;

	if (argc < 2) {
		fprintf(stderr, "Usage: %s prog_id\n", argv[0]);
		exit(1);
	}

	/*
	 * Parse strictly: reject empty input, trailing garbage and values
	 * that do not fit in an int, rather than letting atoi() silently
	 * turn them into some other id.
	 */
	prog_id = strtol(argv[1], &end, 10);
	if (end == argv[1] || *end != '\0' || prog_id != (int)prog_id) {
		fprintf(stderr, "Invalid prog_id: %s\n", argv[1]);
		exit(1);
	}

	results = get_map_fd_by_prog_id((int)prog_id);
	if (results < 0) {
		log_err("Can't get map");
		goto err;
	}

	/* Port 0 lets the kernel pick a free port for each server. */
	memset(&addr4, 0, sizeof(addr4));
	addr4.sin_family = AF_INET;
	addr4.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
	addr4.sin_port = 0;

	memset(&addr6, 0, sizeof(addr6));
	addr6.sin6_family = AF_INET6;
	addr6.sin6_addr = in6addr_loopback;
	addr6.sin6_port = 0;

	server = start_server((const struct sockaddr *)&addr4, sizeof(addr4));
	if (server == -1)
		goto err;

	server_v6 = start_server((const struct sockaddr *)&addr6,
				 sizeof(addr6));
	if (server_v6 == -1)
		goto err;

	if (run_test(server, results))
		goto err;

	if (run_test(server_v6, results))
		goto err;

	printf("ok\n");
	goto out;
err:
	err = 1;
out:
	/* Only close descriptors that were actually opened. */
	if (server != -1)
		close(server);
	if (server_v6 != -1)
		close(server_v6);
	if (results >= 0)
		close(results);
	return err;
}
213