1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2018 Facebook
3 
#include <errno.h>
#include <stdbool.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>

#include <arpa/inet.h>
#include <sys/types.h>
#include <sys/socket.h>

#include <linux/filter.h>

#include <bpf/bpf.h>

#include "cgroup_helpers.h"
#include <bpf/bpf_endian.h>
#include "bpf_rlimit.h"
#include "bpf_util.h"
19 
20 #define CG_PATH		"/foo"
21 #define MAX_INSNS	512
22 
23 char bpf_log_buf[BPF_LOG_BUF_SIZE];
24 static bool verbose = false;
25 
/*
 * One test case: a BPF_PROG_TYPE_CGROUP_SOCK program, how to load and
 * attach it, the socket to bind() while it is attached, and the stage at
 * which the sequence is expected to stop.
 */
struct sock_test {
	/* Human-readable name printed by run_test_case(). */
	const char *descr;
	/* BPF prog properties */
	/* Program text; trailing all-zero slots are trimmed by
	 * probe_prog_length() before loading.
	 */
	struct bpf_insn	insns[MAX_INSNS];
	/* Passed as expected_attach_type at load time (0 = not set). */
	enum bpf_attach_type expected_attach_type;
	/* Type used for bpf_prog_attach()/bpf_prog_detach() on the cgroup. */
	enum bpf_attach_type attach_type;
	/* Socket properties */
	/* Arguments to socket(2); only AF_INET/AF_INET6 can be bound. */
	int domain;
	int type;
	/* Endpoint to bind() to */
	/* Textual address parsed with inet_pton(). */
	const char *ip;
	/* Host byte order; bind_sock() applies htons(). */
	unsigned short port;
	/* Expected test result */
	enum {
		LOAD_REJECT,	/* program load must fail */
		ATTACH_REJECT,	/* load succeeds, attach must fail */
		BIND_REJECT,	/* load+attach succeed, bind() fails w/ EPERM */
		SUCCESS,	/* load, attach and bind() all succeed */
	} result;
};
46 
47 static struct sock_test tests[] = {
48 	{
49 		"bind4 load with invalid access: src_ip6",
50 		.insns = {
51 			BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
52 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
53 				    offsetof(struct bpf_sock, src_ip6[0])),
54 			BPF_MOV64_IMM(BPF_REG_0, 1),
55 			BPF_EXIT_INSN(),
56 		},
57 		BPF_CGROUP_INET4_POST_BIND,
58 		BPF_CGROUP_INET4_POST_BIND,
59 		0,
60 		0,
61 		NULL,
62 		0,
63 		LOAD_REJECT,
64 	},
65 	{
66 		"bind4 load with invalid access: mark",
67 		.insns = {
68 			BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
69 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
70 				    offsetof(struct bpf_sock, mark)),
71 			BPF_MOV64_IMM(BPF_REG_0, 1),
72 			BPF_EXIT_INSN(),
73 		},
74 		BPF_CGROUP_INET4_POST_BIND,
75 		BPF_CGROUP_INET4_POST_BIND,
76 		0,
77 		0,
78 		NULL,
79 		0,
80 		LOAD_REJECT,
81 	},
82 	{
83 		"bind6 load with invalid access: src_ip4",
84 		.insns = {
85 			BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
86 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
87 				    offsetof(struct bpf_sock, src_ip4)),
88 			BPF_MOV64_IMM(BPF_REG_0, 1),
89 			BPF_EXIT_INSN(),
90 		},
91 		BPF_CGROUP_INET6_POST_BIND,
92 		BPF_CGROUP_INET6_POST_BIND,
93 		0,
94 		0,
95 		NULL,
96 		0,
97 		LOAD_REJECT,
98 	},
99 	{
100 		"sock_create load with invalid access: src_port",
101 		.insns = {
102 			BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
103 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
104 				    offsetof(struct bpf_sock, src_port)),
105 			BPF_MOV64_IMM(BPF_REG_0, 1),
106 			BPF_EXIT_INSN(),
107 		},
108 		BPF_CGROUP_INET_SOCK_CREATE,
109 		BPF_CGROUP_INET_SOCK_CREATE,
110 		0,
111 		0,
112 		NULL,
113 		0,
114 		LOAD_REJECT,
115 	},
116 	{
117 		"sock_create load w/o expected_attach_type (compat mode)",
118 		.insns = {
119 			BPF_MOV64_IMM(BPF_REG_0, 1),
120 			BPF_EXIT_INSN(),
121 		},
122 		0,
123 		BPF_CGROUP_INET_SOCK_CREATE,
124 		AF_INET,
125 		SOCK_STREAM,
126 		"127.0.0.1",
127 		8097,
128 		SUCCESS,
129 	},
130 	{
131 		"sock_create load w/ expected_attach_type",
132 		.insns = {
133 			BPF_MOV64_IMM(BPF_REG_0, 1),
134 			BPF_EXIT_INSN(),
135 		},
136 		BPF_CGROUP_INET_SOCK_CREATE,
137 		BPF_CGROUP_INET_SOCK_CREATE,
138 		AF_INET,
139 		SOCK_STREAM,
140 		"127.0.0.1",
141 		8097,
142 		SUCCESS,
143 	},
144 	{
145 		"attach type mismatch bind4 vs bind6",
146 		.insns = {
147 			BPF_MOV64_IMM(BPF_REG_0, 1),
148 			BPF_EXIT_INSN(),
149 		},
150 		BPF_CGROUP_INET4_POST_BIND,
151 		BPF_CGROUP_INET6_POST_BIND,
152 		0,
153 		0,
154 		NULL,
155 		0,
156 		ATTACH_REJECT,
157 	},
158 	{
159 		"attach type mismatch bind6 vs bind4",
160 		.insns = {
161 			BPF_MOV64_IMM(BPF_REG_0, 1),
162 			BPF_EXIT_INSN(),
163 		},
164 		BPF_CGROUP_INET6_POST_BIND,
165 		BPF_CGROUP_INET4_POST_BIND,
166 		0,
167 		0,
168 		NULL,
169 		0,
170 		ATTACH_REJECT,
171 	},
172 	{
173 		"attach type mismatch default vs bind4",
174 		.insns = {
175 			BPF_MOV64_IMM(BPF_REG_0, 1),
176 			BPF_EXIT_INSN(),
177 		},
178 		0,
179 		BPF_CGROUP_INET4_POST_BIND,
180 		0,
181 		0,
182 		NULL,
183 		0,
184 		ATTACH_REJECT,
185 	},
186 	{
187 		"attach type mismatch bind6 vs sock_create",
188 		.insns = {
189 			BPF_MOV64_IMM(BPF_REG_0, 1),
190 			BPF_EXIT_INSN(),
191 		},
192 		BPF_CGROUP_INET6_POST_BIND,
193 		BPF_CGROUP_INET_SOCK_CREATE,
194 		0,
195 		0,
196 		NULL,
197 		0,
198 		ATTACH_REJECT,
199 	},
200 	{
201 		"bind4 reject all",
202 		.insns = {
203 			BPF_MOV64_IMM(BPF_REG_0, 0),
204 			BPF_EXIT_INSN(),
205 		},
206 		BPF_CGROUP_INET4_POST_BIND,
207 		BPF_CGROUP_INET4_POST_BIND,
208 		AF_INET,
209 		SOCK_STREAM,
210 		"0.0.0.0",
211 		0,
212 		BIND_REJECT,
213 	},
214 	{
215 		"bind6 reject all",
216 		.insns = {
217 			BPF_MOV64_IMM(BPF_REG_0, 0),
218 			BPF_EXIT_INSN(),
219 		},
220 		BPF_CGROUP_INET6_POST_BIND,
221 		BPF_CGROUP_INET6_POST_BIND,
222 		AF_INET6,
223 		SOCK_STREAM,
224 		"::",
225 		0,
226 		BIND_REJECT,
227 	},
228 	{
229 		"bind6 deny specific IP & port",
230 		.insns = {
231 			BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
232 
233 			/* if (ip == expected && port == expected) */
234 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
235 				    offsetof(struct bpf_sock, src_ip6[3])),
236 			BPF_JMP_IMM(BPF_JNE, BPF_REG_7,
237 				    __bpf_constant_ntohl(0x00000001), 4),
238 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
239 				    offsetof(struct bpf_sock, src_port)),
240 			BPF_JMP_IMM(BPF_JNE, BPF_REG_7, 0x2001, 2),
241 
242 			/* return DENY; */
243 			BPF_MOV64_IMM(BPF_REG_0, 0),
244 			BPF_JMP_A(1),
245 
246 			/* else return ALLOW; */
247 			BPF_MOV64_IMM(BPF_REG_0, 1),
248 			BPF_EXIT_INSN(),
249 		},
250 		BPF_CGROUP_INET6_POST_BIND,
251 		BPF_CGROUP_INET6_POST_BIND,
252 		AF_INET6,
253 		SOCK_STREAM,
254 		"::1",
255 		8193,
256 		BIND_REJECT,
257 	},
258 	{
259 		"bind4 allow specific IP & port",
260 		.insns = {
261 			BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
262 
263 			/* if (ip == expected && port == expected) */
264 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
265 				    offsetof(struct bpf_sock, src_ip4)),
266 			BPF_JMP_IMM(BPF_JNE, BPF_REG_7,
267 				    __bpf_constant_ntohl(0x7F000001), 4),
268 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
269 				    offsetof(struct bpf_sock, src_port)),
270 			BPF_JMP_IMM(BPF_JNE, BPF_REG_7, 0x1002, 2),
271 
272 			/* return ALLOW; */
273 			BPF_MOV64_IMM(BPF_REG_0, 1),
274 			BPF_JMP_A(1),
275 
276 			/* else return DENY; */
277 			BPF_MOV64_IMM(BPF_REG_0, 0),
278 			BPF_EXIT_INSN(),
279 		},
280 		BPF_CGROUP_INET4_POST_BIND,
281 		BPF_CGROUP_INET4_POST_BIND,
282 		AF_INET,
283 		SOCK_STREAM,
284 		"127.0.0.1",
285 		4098,
286 		SUCCESS,
287 	},
288 	{
289 		"bind4 allow all",
290 		.insns = {
291 			BPF_MOV64_IMM(BPF_REG_0, 1),
292 			BPF_EXIT_INSN(),
293 		},
294 		BPF_CGROUP_INET4_POST_BIND,
295 		BPF_CGROUP_INET4_POST_BIND,
296 		AF_INET,
297 		SOCK_STREAM,
298 		"0.0.0.0",
299 		0,
300 		SUCCESS,
301 	},
302 	{
303 		"bind6 allow all",
304 		.insns = {
305 			BPF_MOV64_IMM(BPF_REG_0, 1),
306 			BPF_EXIT_INSN(),
307 		},
308 		BPF_CGROUP_INET6_POST_BIND,
309 		BPF_CGROUP_INET6_POST_BIND,
310 		AF_INET6,
311 		SOCK_STREAM,
312 		"::",
313 		0,
314 		SUCCESS,
315 	},
316 };
317 
318 static size_t probe_prog_length(const struct bpf_insn *fp)
319 {
320 	size_t len;
321 
322 	for (len = MAX_INSNS - 1; len > 0; --len)
323 		if (fp[len].code != 0 || fp[len].imm != 0)
324 			break;
325 	return len + 1;
326 }
327 
328 static int load_sock_prog(const struct bpf_insn *prog,
329 			  enum bpf_attach_type attach_type)
330 {
331 	struct bpf_load_program_attr attr;
332 	int ret;
333 
334 	memset(&attr, 0, sizeof(struct bpf_load_program_attr));
335 	attr.prog_type = BPF_PROG_TYPE_CGROUP_SOCK;
336 	attr.expected_attach_type = attach_type;
337 	attr.insns = prog;
338 	attr.insns_cnt = probe_prog_length(attr.insns);
339 	attr.license = "GPL";
340 	attr.log_level = 2;
341 
342 	ret = bpf_load_program_xattr(&attr, bpf_log_buf, BPF_LOG_BUF_SIZE);
343 	if (verbose && ret < 0)
344 		fprintf(stderr, "%s\n", bpf_log_buf);
345 
346 	return ret;
347 }
348 
349 static int attach_sock_prog(int cgfd, int progfd,
350 			    enum bpf_attach_type attach_type)
351 {
352 	return bpf_prog_attach(progfd, cgfd, attach_type, BPF_F_ALLOW_OVERRIDE);
353 }
354 
/*
 * Create a socket of @domain/@type and bind() it to @ip:@port (@port in host
 * byte order). The socket is closed again before returning. Only AF_INET and
 * AF_INET6 are supported.
 *
 * Returns 0 if bind() succeeded and -1 otherwise. On failure errno describes
 * the step that failed: EINVAL for an unsupported domain or an unparsable or
 * missing address, otherwise the errno of socket()/bind(). errno is preserved
 * across the final close(), because callers inspect it to tell a BPF reject
 * (EPERM) from other failures.
 */
static int bind_sock(int domain, int type, const char *ip, unsigned short port)
{
	struct sockaddr_storage addr;
	struct sockaddr_in6 *addr6;
	struct sockaddr_in *addr4;
	int sockfd = -1;
	int saved_errno;
	socklen_t len;
	int err = 0;

	sockfd = socket(domain, type, 0);
	if (sockfd < 0)
		goto err;

	memset(&addr, 0, sizeof(addr));

	if (domain == AF_INET) {
		len = sizeof(struct sockaddr_in);
		addr4 = (struct sockaddr_in *)&addr;
		addr4->sin_family = domain;
		addr4->sin_port = htons(port);
		if (!ip || inet_pton(domain, ip, &addr4->sin_addr) != 1)
			goto err_inval;
	} else if (domain == AF_INET6) {
		len = sizeof(struct sockaddr_in6);
		addr6 = (struct sockaddr_in6 *)&addr;
		addr6->sin6_family = domain;
		addr6->sin6_port = htons(port);
		if (!ip || inet_pton(domain, ip, &addr6->sin6_addr) != 1)
			goto err_inval;
	} else {
		goto err_inval;
	}

	if (bind(sockfd, (const struct sockaddr *)&addr, len) == -1)
		goto err;

	goto out;
err_inval:
	/* inet_pton() returning 0 does not set errno; don't let a stale value
	 * (e.g. EPERM left over from an earlier test) look like a BPF reject.
	 */
	errno = EINVAL;
err:
	err = -1;
out:
	/* close() must not clobber the errno the caller is about to check,
	 * and there is nothing to close if socket() itself failed.
	 */
	saved_errno = errno;
	if (sockfd >= 0)
		close(sockfd);
	errno = saved_errno;
	return err;
}
398 
399 static int run_test_case(int cgfd, const struct sock_test *test)
400 {
401 	int progfd = -1;
402 	int err = 0;
403 
404 	printf("Test case: %s .. ", test->descr);
405 	progfd = load_sock_prog(test->insns, test->expected_attach_type);
406 	if (progfd < 0) {
407 		if (test->result == LOAD_REJECT)
408 			goto out;
409 		else
410 			goto err;
411 	}
412 
413 	if (attach_sock_prog(cgfd, progfd, test->attach_type) == -1) {
414 		if (test->result == ATTACH_REJECT)
415 			goto out;
416 		else
417 			goto err;
418 	}
419 
420 	if (bind_sock(test->domain, test->type, test->ip, test->port) == -1) {
421 		/* sys_bind() may fail for different reasons, errno has to be
422 		 * checked to confirm that BPF program rejected it.
423 		 */
424 		if (test->result == BIND_REJECT && errno == EPERM)
425 			goto out;
426 		else
427 			goto err;
428 	}
429 
430 
431 	if (test->result != SUCCESS)
432 		goto err;
433 
434 	goto out;
435 err:
436 	err = -1;
437 out:
438 	/* Detaching w/o checking return code: best effort attempt. */
439 	if (progfd != -1)
440 		bpf_prog_detach(cgfd, test->attach_type);
441 	close(progfd);
442 	printf("[%s]\n", err ? "FAIL" : "PASS");
443 	return err;
444 }
445 
446 static int run_tests(int cgfd)
447 {
448 	int passes = 0;
449 	int fails = 0;
450 	int i;
451 
452 	for (i = 0; i < ARRAY_SIZE(tests); ++i) {
453 		if (run_test_case(cgfd, &tests[i]))
454 			++fails;
455 		else
456 			++passes;
457 	}
458 	printf("Summary: %d PASSED, %d FAILED\n", passes, fails);
459 	return fails ? -1 : 0;
460 }
461 
462 int main(int argc, char **argv)
463 {
464 	int cgfd = -1;
465 	int err = 0;
466 
467 	if (setup_cgroup_environment())
468 		goto err;
469 
470 	cgfd = create_and_get_cgroup(CG_PATH);
471 	if (cgfd < 0)
472 		goto err;
473 
474 	if (join_cgroup(CG_PATH))
475 		goto err;
476 
477 	if (run_tests(cgfd))
478 		goto err;
479 
480 	goto out;
481 err:
482 	err = -1;
483 out:
484 	close(cgfd);
485 	cleanup_cgroup_environment();
486 	return err;
487 }
488