1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2018 Facebook
3 
#include <errno.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>

#include <arpa/inet.h>
#include <sys/types.h>
#include <sys/socket.h>

#include <linux/filter.h>

#include <bpf/bpf.h>

#include "cgroup_helpers.h"
#include "bpf_rlimit.h"
#include "bpf_util.h"
18 
/* Cgroup path passed to create_and_get_cgroup() and join_cgroup() in main(). */
#define CG_PATH		"/foo"
/* Capacity of sock_test.insns; probe_prog_length() scans back from the end. */
#define MAX_INSNS	512

/* Log buffer handed to bpf_load_program_xattr() on every program load. */
char bpf_log_buf[BPF_LOG_BUF_SIZE];
23 
/*
 * One test case: a cgroup/sock BPF program, how to load and attach it, an
 * optional socket to bind() while it is attached, and the stage that is
 * expected to fail (if any). Keep the member order stable: test tables may
 * initialize these fields positionally.
 */
struct sock_test {
	/* Human-readable name printed by run_test_case(). */
	const char *descr;
	/* BPF prog properties */
	/* Program text. Trailing slots whose code and imm are both zero are
	 * trimmed by probe_prog_length() before loading.
	 */
	struct bpf_insn	insns[MAX_INSNS];
	/* Passed to the program load; 0 is used for the compat-mode case. */
	enum bpf_attach_type expected_attach_type;
	/* Hook used for bpf_prog_attach() and for the final detach. */
	enum bpf_attach_type attach_type;
	/* Socket properties */
	/* socket() domain and type; only used once attach has succeeded. */
	int domain;
	int type;
	/* Endpoint to bind() to */
	/* Textual address parsed with inet_pton(); port in host byte order
	 * (converted with htons() before bind()).
	 */
	const char *ip;
	unsigned short port;
	/* Expected test result */
	enum {
		LOAD_REJECT,	/* program load must fail */
		ATTACH_REJECT,	/* attaching to the cgroup must fail */
		BIND_REJECT,	/* bind() must fail with EPERM */
		SUCCESS,	/* load, attach and bind() must all succeed */
	} result;
};
44 
45 static struct sock_test tests[] = {
46 	{
47 		"bind4 load with invalid access: src_ip6",
48 		.insns = {
49 			BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
50 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
51 				    offsetof(struct bpf_sock, src_ip6[0])),
52 			BPF_MOV64_IMM(BPF_REG_0, 1),
53 			BPF_EXIT_INSN(),
54 		},
55 		BPF_CGROUP_INET4_POST_BIND,
56 		BPF_CGROUP_INET4_POST_BIND,
57 		0,
58 		0,
59 		NULL,
60 		0,
61 		LOAD_REJECT,
62 	},
63 	{
64 		"bind4 load with invalid access: mark",
65 		.insns = {
66 			BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
67 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
68 				    offsetof(struct bpf_sock, mark)),
69 			BPF_MOV64_IMM(BPF_REG_0, 1),
70 			BPF_EXIT_INSN(),
71 		},
72 		BPF_CGROUP_INET4_POST_BIND,
73 		BPF_CGROUP_INET4_POST_BIND,
74 		0,
75 		0,
76 		NULL,
77 		0,
78 		LOAD_REJECT,
79 	},
80 	{
81 		"bind6 load with invalid access: src_ip4",
82 		.insns = {
83 			BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
84 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
85 				    offsetof(struct bpf_sock, src_ip4)),
86 			BPF_MOV64_IMM(BPF_REG_0, 1),
87 			BPF_EXIT_INSN(),
88 		},
89 		BPF_CGROUP_INET6_POST_BIND,
90 		BPF_CGROUP_INET6_POST_BIND,
91 		0,
92 		0,
93 		NULL,
94 		0,
95 		LOAD_REJECT,
96 	},
97 	{
98 		"sock_create load with invalid access: src_port",
99 		.insns = {
100 			BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
101 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
102 				    offsetof(struct bpf_sock, src_port)),
103 			BPF_MOV64_IMM(BPF_REG_0, 1),
104 			BPF_EXIT_INSN(),
105 		},
106 		BPF_CGROUP_INET_SOCK_CREATE,
107 		BPF_CGROUP_INET_SOCK_CREATE,
108 		0,
109 		0,
110 		NULL,
111 		0,
112 		LOAD_REJECT,
113 	},
114 	{
115 		"sock_create load w/o expected_attach_type (compat mode)",
116 		.insns = {
117 			BPF_MOV64_IMM(BPF_REG_0, 1),
118 			BPF_EXIT_INSN(),
119 		},
120 		0,
121 		BPF_CGROUP_INET_SOCK_CREATE,
122 		AF_INET,
123 		SOCK_STREAM,
124 		"127.0.0.1",
125 		8097,
126 		SUCCESS,
127 	},
128 	{
129 		"sock_create load w/ expected_attach_type",
130 		.insns = {
131 			BPF_MOV64_IMM(BPF_REG_0, 1),
132 			BPF_EXIT_INSN(),
133 		},
134 		BPF_CGROUP_INET_SOCK_CREATE,
135 		BPF_CGROUP_INET_SOCK_CREATE,
136 		AF_INET,
137 		SOCK_STREAM,
138 		"127.0.0.1",
139 		8097,
140 		SUCCESS,
141 	},
142 	{
143 		"attach type mismatch bind4 vs bind6",
144 		.insns = {
145 			BPF_MOV64_IMM(BPF_REG_0, 1),
146 			BPF_EXIT_INSN(),
147 		},
148 		BPF_CGROUP_INET4_POST_BIND,
149 		BPF_CGROUP_INET6_POST_BIND,
150 		0,
151 		0,
152 		NULL,
153 		0,
154 		ATTACH_REJECT,
155 	},
156 	{
157 		"attach type mismatch bind6 vs bind4",
158 		.insns = {
159 			BPF_MOV64_IMM(BPF_REG_0, 1),
160 			BPF_EXIT_INSN(),
161 		},
162 		BPF_CGROUP_INET6_POST_BIND,
163 		BPF_CGROUP_INET4_POST_BIND,
164 		0,
165 		0,
166 		NULL,
167 		0,
168 		ATTACH_REJECT,
169 	},
170 	{
171 		"attach type mismatch default vs bind4",
172 		.insns = {
173 			BPF_MOV64_IMM(BPF_REG_0, 1),
174 			BPF_EXIT_INSN(),
175 		},
176 		0,
177 		BPF_CGROUP_INET4_POST_BIND,
178 		0,
179 		0,
180 		NULL,
181 		0,
182 		ATTACH_REJECT,
183 	},
184 	{
185 		"attach type mismatch bind6 vs sock_create",
186 		.insns = {
187 			BPF_MOV64_IMM(BPF_REG_0, 1),
188 			BPF_EXIT_INSN(),
189 		},
190 		BPF_CGROUP_INET6_POST_BIND,
191 		BPF_CGROUP_INET_SOCK_CREATE,
192 		0,
193 		0,
194 		NULL,
195 		0,
196 		ATTACH_REJECT,
197 	},
198 	{
199 		"bind4 reject all",
200 		.insns = {
201 			BPF_MOV64_IMM(BPF_REG_0, 0),
202 			BPF_EXIT_INSN(),
203 		},
204 		BPF_CGROUP_INET4_POST_BIND,
205 		BPF_CGROUP_INET4_POST_BIND,
206 		AF_INET,
207 		SOCK_STREAM,
208 		"0.0.0.0",
209 		0,
210 		BIND_REJECT,
211 	},
212 	{
213 		"bind6 reject all",
214 		.insns = {
215 			BPF_MOV64_IMM(BPF_REG_0, 0),
216 			BPF_EXIT_INSN(),
217 		},
218 		BPF_CGROUP_INET6_POST_BIND,
219 		BPF_CGROUP_INET6_POST_BIND,
220 		AF_INET6,
221 		SOCK_STREAM,
222 		"::",
223 		0,
224 		BIND_REJECT,
225 	},
226 	{
227 		"bind6 deny specific IP & port",
228 		.insns = {
229 			BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
230 
231 			/* if (ip == expected && port == expected) */
232 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
233 				    offsetof(struct bpf_sock, src_ip6[3])),
234 			BPF_JMP_IMM(BPF_JNE, BPF_REG_7, 0x01000000, 4),
235 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
236 				    offsetof(struct bpf_sock, src_port)),
237 			BPF_JMP_IMM(BPF_JNE, BPF_REG_7, 0x2001, 2),
238 
239 			/* return DENY; */
240 			BPF_MOV64_IMM(BPF_REG_0, 0),
241 			BPF_JMP_A(1),
242 
243 			/* else return ALLOW; */
244 			BPF_MOV64_IMM(BPF_REG_0, 1),
245 			BPF_EXIT_INSN(),
246 		},
247 		BPF_CGROUP_INET6_POST_BIND,
248 		BPF_CGROUP_INET6_POST_BIND,
249 		AF_INET6,
250 		SOCK_STREAM,
251 		"::1",
252 		8193,
253 		BIND_REJECT,
254 	},
255 	{
256 		"bind4 allow specific IP & port",
257 		.insns = {
258 			BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
259 
260 			/* if (ip == expected && port == expected) */
261 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
262 				    offsetof(struct bpf_sock, src_ip4)),
263 			BPF_JMP_IMM(BPF_JNE, BPF_REG_7, 0x0100007F, 4),
264 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
265 				    offsetof(struct bpf_sock, src_port)),
266 			BPF_JMP_IMM(BPF_JNE, BPF_REG_7, 0x1002, 2),
267 
268 			/* return ALLOW; */
269 			BPF_MOV64_IMM(BPF_REG_0, 1),
270 			BPF_JMP_A(1),
271 
272 			/* else return DENY; */
273 			BPF_MOV64_IMM(BPF_REG_0, 0),
274 			BPF_EXIT_INSN(),
275 		},
276 		BPF_CGROUP_INET4_POST_BIND,
277 		BPF_CGROUP_INET4_POST_BIND,
278 		AF_INET,
279 		SOCK_STREAM,
280 		"127.0.0.1",
281 		4098,
282 		SUCCESS,
283 	},
284 	{
285 		"bind4 allow all",
286 		.insns = {
287 			BPF_MOV64_IMM(BPF_REG_0, 1),
288 			BPF_EXIT_INSN(),
289 		},
290 		BPF_CGROUP_INET4_POST_BIND,
291 		BPF_CGROUP_INET4_POST_BIND,
292 		AF_INET,
293 		SOCK_STREAM,
294 		"0.0.0.0",
295 		0,
296 		SUCCESS,
297 	},
298 	{
299 		"bind6 allow all",
300 		.insns = {
301 			BPF_MOV64_IMM(BPF_REG_0, 1),
302 			BPF_EXIT_INSN(),
303 		},
304 		BPF_CGROUP_INET6_POST_BIND,
305 		BPF_CGROUP_INET6_POST_BIND,
306 		AF_INET6,
307 		SOCK_STREAM,
308 		"::",
309 		0,
310 		SUCCESS,
311 	},
312 };
313 
314 static size_t probe_prog_length(const struct bpf_insn *fp)
315 {
316 	size_t len;
317 
318 	for (len = MAX_INSNS - 1; len > 0; --len)
319 		if (fp[len].code != 0 || fp[len].imm != 0)
320 			break;
321 	return len + 1;
322 }
323 
324 static int load_sock_prog(const struct bpf_insn *prog,
325 			  enum bpf_attach_type attach_type)
326 {
327 	struct bpf_load_program_attr attr;
328 
329 	memset(&attr, 0, sizeof(struct bpf_load_program_attr));
330 	attr.prog_type = BPF_PROG_TYPE_CGROUP_SOCK;
331 	attr.expected_attach_type = attach_type;
332 	attr.insns = prog;
333 	attr.insns_cnt = probe_prog_length(attr.insns);
334 	attr.license = "GPL";
335 
336 	return bpf_load_program_xattr(&attr, bpf_log_buf, BPF_LOG_BUF_SIZE);
337 }
338 
339 static int attach_sock_prog(int cgfd, int progfd,
340 			    enum bpf_attach_type attach_type)
341 {
342 	return bpf_prog_attach(progfd, cgfd, attach_type, BPF_F_ALLOW_OVERRIDE);
343 }
344 
/*
 * Create a socket of the given @domain/@type and bind() it to @ip:@port
 * (@port in host byte order). The socket is always closed before returning.
 *
 * Returns 0 if bind() succeeded and -1 otherwise. On failure, errno
 * identifies the failing step and is preserved across the final close(),
 * because callers inspect it: EPERM from bind() means a BPF program
 * rejected the bind. A NULL or unparsable @ip sets EINVAL, and a domain
 * other than AF_INET/AF_INET6 sets EAFNOSUPPORT. This way a stale errno
 * left by an earlier call can never be mistaken for a rejection.
 */
static int bind_sock(int domain, int type, const char *ip, unsigned short port)
{
	struct sockaddr_storage addr;
	struct sockaddr_in6 *addr6;
	struct sockaddr_in *addr4;
	int sockfd = -1;
	int saved_errno;
	socklen_t len;
	int err = 0;

	sockfd = socket(domain, type, 0);
	if (sockfd < 0)
		goto err;

	memset(&addr, 0, sizeof(addr));

	if (!ip) {
		errno = EINVAL;
		goto err;
	}

	if (domain == AF_INET) {
		len = sizeof(struct sockaddr_in);
		addr4 = (struct sockaddr_in *)&addr;
		addr4->sin_family = domain;
		addr4->sin_port = htons(port);
		/* inet_pton() returns 0 without setting errno on bad input. */
		if (inet_pton(domain, ip, (void *)&addr4->sin_addr) != 1) {
			errno = EINVAL;
			goto err;
		}
	} else if (domain == AF_INET6) {
		len = sizeof(struct sockaddr_in6);
		addr6 = (struct sockaddr_in6 *)&addr;
		addr6->sin6_family = domain;
		addr6->sin6_port = htons(port);
		if (inet_pton(domain, ip, (void *)&addr6->sin6_addr) != 1) {
			errno = EINVAL;
			goto err;
		}
	} else {
		errno = EAFNOSUPPORT;
		goto err;
	}

	if (bind(sockfd, (const struct sockaddr *)&addr, len) == -1)
		goto err;

	goto out;
err:
	err = -1;
out:
	/*
	 * Keep errno from the failing step for the caller. Also skip close()
	 * when socket() failed: close(-1) would overwrite errno with EBADF.
	 */
	saved_errno = errno;
	if (sockfd >= 0)
		close(sockfd);
	errno = saved_errno;
	return err;
}
388 
389 static int run_test_case(int cgfd, const struct sock_test *test)
390 {
391 	int progfd = -1;
392 	int err = 0;
393 
394 	printf("Test case: %s .. ", test->descr);
395 	progfd = load_sock_prog(test->insns, test->expected_attach_type);
396 	if (progfd < 0) {
397 		if (test->result == LOAD_REJECT)
398 			goto out;
399 		else
400 			goto err;
401 	}
402 
403 	if (attach_sock_prog(cgfd, progfd, test->attach_type) == -1) {
404 		if (test->result == ATTACH_REJECT)
405 			goto out;
406 		else
407 			goto err;
408 	}
409 
410 	if (bind_sock(test->domain, test->type, test->ip, test->port) == -1) {
411 		/* sys_bind() may fail for different reasons, errno has to be
412 		 * checked to confirm that BPF program rejected it.
413 		 */
414 		if (test->result == BIND_REJECT && errno == EPERM)
415 			goto out;
416 		else
417 			goto err;
418 	}
419 
420 
421 	if (test->result != SUCCESS)
422 		goto err;
423 
424 	goto out;
425 err:
426 	err = -1;
427 out:
428 	/* Detaching w/o checking return code: best effort attempt. */
429 	if (progfd != -1)
430 		bpf_prog_detach(cgfd, test->attach_type);
431 	close(progfd);
432 	printf("[%s]\n", err ? "FAIL" : "PASS");
433 	return err;
434 }
435 
436 static int run_tests(int cgfd)
437 {
438 	int passes = 0;
439 	int fails = 0;
440 	int i;
441 
442 	for (i = 0; i < ARRAY_SIZE(tests); ++i) {
443 		if (run_test_case(cgfd, &tests[i]))
444 			++fails;
445 		else
446 			++passes;
447 	}
448 	printf("Summary: %d PASSED, %d FAILED\n", passes, fails);
449 	return fails ? -1 : 0;
450 }
451 
452 int main(int argc, char **argv)
453 {
454 	int cgfd = -1;
455 	int err = 0;
456 
457 	if (setup_cgroup_environment())
458 		goto err;
459 
460 	cgfd = create_and_get_cgroup(CG_PATH);
461 	if (!cgfd)
462 		goto err;
463 
464 	if (join_cgroup(CG_PATH))
465 		goto err;
466 
467 	if (run_tests(cgfd))
468 		goto err;
469 
470 	goto out;
471 err:
472 	err = -1;
473 out:
474 	close(cgfd);
475 	cleanup_cgroup_environment();
476 	return err;
477 }
478