1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2018 Facebook
3 
#include <errno.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>

#include <arpa/inet.h>
#include <sys/types.h>
#include <sys/socket.h>

#include <linux/filter.h>

#include <bpf/bpf.h>

#include "cgroup_helpers.h"
16 
17 #ifndef ARRAY_SIZE
18 # define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0]))
19 #endif
20 
21 #define CG_PATH		"/foo"
22 #define MAX_INSNS	512
23 
24 char bpf_log_buf[BPF_LOG_BUF_SIZE];
25 
/*
 * Description of one test case.
 *
 * tests[] initializes these members positionally (after the .insns
 * designator), so the member order below is significant.
 */
struct sock_test {
	/* Human-readable name, printed by run_test_case(). */
	const char *descr;
	/* BPF prog properties */
	/* Program text; slots after the last instruction are left zero and
	 * trimmed off by probe_prog_length().
	 */
	struct bpf_insn	insns[MAX_INSNS];
	/* Passed to the program load as expected_attach_type. */
	enum bpf_attach_type expected_attach_type;
	/* Hook the loaded program is attached to on the test cgroup. */
	enum bpf_attach_type attach_type;
	/* Socket properties: socket() address family and socket type */
	int domain;
	int type;
	/* Endpoint to bind() to: text IP for inet_pton(), host-order port */
	const char *ip;
	unsigned short port;
	/* Expected test result: the first step expected to fail, or SUCCESS */
	enum {
		LOAD_REJECT,	/* program load fails */
		ATTACH_REJECT,	/* attaching to the cgroup fails */
		BIND_REJECT,	/* bind() fails with EPERM */
		SUCCESS,	/* load, attach and bind() all succeed */
	} result;
};
46 
47 static struct sock_test tests[] = {
48 	{
49 		"bind4 load with invalid access: src_ip6",
50 		.insns = {
51 			BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
52 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
53 				    offsetof(struct bpf_sock, src_ip6[0])),
54 			BPF_MOV64_IMM(BPF_REG_0, 1),
55 			BPF_EXIT_INSN(),
56 		},
57 		BPF_CGROUP_INET4_POST_BIND,
58 		BPF_CGROUP_INET4_POST_BIND,
59 		0,
60 		0,
61 		NULL,
62 		0,
63 		LOAD_REJECT,
64 	},
65 	{
66 		"bind4 load with invalid access: mark",
67 		.insns = {
68 			BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
69 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
70 				    offsetof(struct bpf_sock, mark)),
71 			BPF_MOV64_IMM(BPF_REG_0, 1),
72 			BPF_EXIT_INSN(),
73 		},
74 		BPF_CGROUP_INET4_POST_BIND,
75 		BPF_CGROUP_INET4_POST_BIND,
76 		0,
77 		0,
78 		NULL,
79 		0,
80 		LOAD_REJECT,
81 	},
82 	{
83 		"bind6 load with invalid access: src_ip4",
84 		.insns = {
85 			BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
86 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
87 				    offsetof(struct bpf_sock, src_ip4)),
88 			BPF_MOV64_IMM(BPF_REG_0, 1),
89 			BPF_EXIT_INSN(),
90 		},
91 		BPF_CGROUP_INET6_POST_BIND,
92 		BPF_CGROUP_INET6_POST_BIND,
93 		0,
94 		0,
95 		NULL,
96 		0,
97 		LOAD_REJECT,
98 	},
99 	{
100 		"sock_create load with invalid access: src_port",
101 		.insns = {
102 			BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
103 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
104 				    offsetof(struct bpf_sock, src_port)),
105 			BPF_MOV64_IMM(BPF_REG_0, 1),
106 			BPF_EXIT_INSN(),
107 		},
108 		BPF_CGROUP_INET_SOCK_CREATE,
109 		BPF_CGROUP_INET_SOCK_CREATE,
110 		0,
111 		0,
112 		NULL,
113 		0,
114 		LOAD_REJECT,
115 	},
116 	{
117 		"sock_create load w/o expected_attach_type (compat mode)",
118 		.insns = {
119 			BPF_MOV64_IMM(BPF_REG_0, 1),
120 			BPF_EXIT_INSN(),
121 		},
122 		0,
123 		BPF_CGROUP_INET_SOCK_CREATE,
124 		AF_INET,
125 		SOCK_STREAM,
126 		"127.0.0.1",
127 		8097,
128 		SUCCESS,
129 	},
130 	{
131 		"sock_create load w/ expected_attach_type",
132 		.insns = {
133 			BPF_MOV64_IMM(BPF_REG_0, 1),
134 			BPF_EXIT_INSN(),
135 		},
136 		BPF_CGROUP_INET_SOCK_CREATE,
137 		BPF_CGROUP_INET_SOCK_CREATE,
138 		AF_INET,
139 		SOCK_STREAM,
140 		"127.0.0.1",
141 		8097,
142 		SUCCESS,
143 	},
144 	{
145 		"attach type mismatch bind4 vs bind6",
146 		.insns = {
147 			BPF_MOV64_IMM(BPF_REG_0, 1),
148 			BPF_EXIT_INSN(),
149 		},
150 		BPF_CGROUP_INET4_POST_BIND,
151 		BPF_CGROUP_INET6_POST_BIND,
152 		0,
153 		0,
154 		NULL,
155 		0,
156 		ATTACH_REJECT,
157 	},
158 	{
159 		"attach type mismatch bind6 vs bind4",
160 		.insns = {
161 			BPF_MOV64_IMM(BPF_REG_0, 1),
162 			BPF_EXIT_INSN(),
163 		},
164 		BPF_CGROUP_INET6_POST_BIND,
165 		BPF_CGROUP_INET4_POST_BIND,
166 		0,
167 		0,
168 		NULL,
169 		0,
170 		ATTACH_REJECT,
171 	},
172 	{
173 		"attach type mismatch default vs bind4",
174 		.insns = {
175 			BPF_MOV64_IMM(BPF_REG_0, 1),
176 			BPF_EXIT_INSN(),
177 		},
178 		0,
179 		BPF_CGROUP_INET4_POST_BIND,
180 		0,
181 		0,
182 		NULL,
183 		0,
184 		ATTACH_REJECT,
185 	},
186 	{
187 		"attach type mismatch bind6 vs sock_create",
188 		.insns = {
189 			BPF_MOV64_IMM(BPF_REG_0, 1),
190 			BPF_EXIT_INSN(),
191 		},
192 		BPF_CGROUP_INET6_POST_BIND,
193 		BPF_CGROUP_INET_SOCK_CREATE,
194 		0,
195 		0,
196 		NULL,
197 		0,
198 		ATTACH_REJECT,
199 	},
200 	{
201 		"bind4 reject all",
202 		.insns = {
203 			BPF_MOV64_IMM(BPF_REG_0, 0),
204 			BPF_EXIT_INSN(),
205 		},
206 		BPF_CGROUP_INET4_POST_BIND,
207 		BPF_CGROUP_INET4_POST_BIND,
208 		AF_INET,
209 		SOCK_STREAM,
210 		"0.0.0.0",
211 		0,
212 		BIND_REJECT,
213 	},
214 	{
215 		"bind6 reject all",
216 		.insns = {
217 			BPF_MOV64_IMM(BPF_REG_0, 0),
218 			BPF_EXIT_INSN(),
219 		},
220 		BPF_CGROUP_INET6_POST_BIND,
221 		BPF_CGROUP_INET6_POST_BIND,
222 		AF_INET6,
223 		SOCK_STREAM,
224 		"::",
225 		0,
226 		BIND_REJECT,
227 	},
228 	{
229 		"bind6 deny specific IP & port",
230 		.insns = {
231 			BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
232 
233 			/* if (ip == expected && port == expected) */
234 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
235 				    offsetof(struct bpf_sock, src_ip6[3])),
236 			BPF_JMP_IMM(BPF_JNE, BPF_REG_7, 0x01000000, 4),
237 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
238 				    offsetof(struct bpf_sock, src_port)),
239 			BPF_JMP_IMM(BPF_JNE, BPF_REG_7, 0x2001, 2),
240 
241 			/* return DENY; */
242 			BPF_MOV64_IMM(BPF_REG_0, 0),
243 			BPF_JMP_A(1),
244 
245 			/* else return ALLOW; */
246 			BPF_MOV64_IMM(BPF_REG_0, 1),
247 			BPF_EXIT_INSN(),
248 		},
249 		BPF_CGROUP_INET6_POST_BIND,
250 		BPF_CGROUP_INET6_POST_BIND,
251 		AF_INET6,
252 		SOCK_STREAM,
253 		"::1",
254 		8193,
255 		BIND_REJECT,
256 	},
257 	{
258 		"bind4 allow specific IP & port",
259 		.insns = {
260 			BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),
261 
262 			/* if (ip == expected && port == expected) */
263 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
264 				    offsetof(struct bpf_sock, src_ip4)),
265 			BPF_JMP_IMM(BPF_JNE, BPF_REG_7, 0x0100007F, 4),
266 			BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6,
267 				    offsetof(struct bpf_sock, src_port)),
268 			BPF_JMP_IMM(BPF_JNE, BPF_REG_7, 0x1002, 2),
269 
270 			/* return ALLOW; */
271 			BPF_MOV64_IMM(BPF_REG_0, 1),
272 			BPF_JMP_A(1),
273 
274 			/* else return DENY; */
275 			BPF_MOV64_IMM(BPF_REG_0, 0),
276 			BPF_EXIT_INSN(),
277 		},
278 		BPF_CGROUP_INET4_POST_BIND,
279 		BPF_CGROUP_INET4_POST_BIND,
280 		AF_INET,
281 		SOCK_STREAM,
282 		"127.0.0.1",
283 		4098,
284 		SUCCESS,
285 	},
286 	{
287 		"bind4 allow all",
288 		.insns = {
289 			BPF_MOV64_IMM(BPF_REG_0, 1),
290 			BPF_EXIT_INSN(),
291 		},
292 		BPF_CGROUP_INET4_POST_BIND,
293 		BPF_CGROUP_INET4_POST_BIND,
294 		AF_INET,
295 		SOCK_STREAM,
296 		"0.0.0.0",
297 		0,
298 		SUCCESS,
299 	},
300 	{
301 		"bind6 allow all",
302 		.insns = {
303 			BPF_MOV64_IMM(BPF_REG_0, 1),
304 			BPF_EXIT_INSN(),
305 		},
306 		BPF_CGROUP_INET6_POST_BIND,
307 		BPF_CGROUP_INET6_POST_BIND,
308 		AF_INET6,
309 		SOCK_STREAM,
310 		"::",
311 		0,
312 		SUCCESS,
313 	},
314 };
315 
316 static size_t probe_prog_length(const struct bpf_insn *fp)
317 {
318 	size_t len;
319 
320 	for (len = MAX_INSNS - 1; len > 0; --len)
321 		if (fp[len].code != 0 || fp[len].imm != 0)
322 			break;
323 	return len + 1;
324 }
325 
326 static int load_sock_prog(const struct bpf_insn *prog,
327 			  enum bpf_attach_type attach_type)
328 {
329 	struct bpf_load_program_attr attr;
330 
331 	memset(&attr, 0, sizeof(struct bpf_load_program_attr));
332 	attr.prog_type = BPF_PROG_TYPE_CGROUP_SOCK;
333 	attr.expected_attach_type = attach_type;
334 	attr.insns = prog;
335 	attr.insns_cnt = probe_prog_length(attr.insns);
336 	attr.license = "GPL";
337 
338 	return bpf_load_program_xattr(&attr, bpf_log_buf, BPF_LOG_BUF_SIZE);
339 }
340 
341 static int attach_sock_prog(int cgfd, int progfd,
342 			    enum bpf_attach_type attach_type)
343 {
344 	return bpf_prog_attach(progfd, cgfd, attach_type, BPF_F_ALLOW_OVERRIDE);
345 }
346 
/*
 * Create a @domain/@type socket, bind it to @ip:@port and close it again.
 *
 * @domain must be AF_INET or AF_INET6. Any other domain, a NULL @ip, or an
 * @ip that inet_pton() cannot parse is reported as failure. @port is in
 * host byte order.
 *
 * Returns 0 if bind() succeeded and -1 otherwise. When socket() or bind()
 * is the call that failed, errno still holds the value it set: socket()
 * failures return before any close(), and errno is saved and restored
 * around close(). Callers can therefore check for e.g. EPERM. In the other
 * failure cases errno is not meaningful.
 */
static int bind_sock(int domain, int type, const char *ip, unsigned short port)
{
	struct sockaddr_storage addr;
	struct sockaddr_in6 *addr6;
	struct sockaddr_in *addr4;
	int saved_errno;
	socklen_t len;
	int sockfd;
	int err = -1;

	sockfd = socket(domain, type, 0);
	if (sockfd < 0)
		return -1;	/* nothing to close; keep socket()'s errno */

	memset(&addr, 0, sizeof(addr));

	if (domain == AF_INET) {
		len = sizeof(struct sockaddr_in);
		addr4 = (struct sockaddr_in *)&addr;
		addr4->sin_family = domain;
		addr4->sin_port = htons(port);
		/* inet_pton() must not be given a NULL string. */
		if (!ip || inet_pton(domain, ip, &addr4->sin_addr) != 1)
			goto out;
	} else if (domain == AF_INET6) {
		len = sizeof(struct sockaddr_in6);
		addr6 = (struct sockaddr_in6 *)&addr;
		addr6->sin6_family = domain;
		addr6->sin6_port = htons(port);
		if (!ip || inet_pton(domain, ip, &addr6->sin6_addr) != 1)
			goto out;
	} else {
		goto out;
	}

	if (bind(sockfd, (const struct sockaddr *)&addr, len) == 0)
		err = 0;

out:
	/* Keep bind()'s errno intact across close(). */
	saved_errno = errno;
	close(sockfd);
	errno = saved_errno;
	return err;
}
390 
391 static int run_test_case(int cgfd, const struct sock_test *test)
392 {
393 	int progfd = -1;
394 	int err = 0;
395 
396 	printf("Test case: %s .. ", test->descr);
397 	progfd = load_sock_prog(test->insns, test->expected_attach_type);
398 	if (progfd < 0) {
399 		if (test->result == LOAD_REJECT)
400 			goto out;
401 		else
402 			goto err;
403 	}
404 
405 	if (attach_sock_prog(cgfd, progfd, test->attach_type) == -1) {
406 		if (test->result == ATTACH_REJECT)
407 			goto out;
408 		else
409 			goto err;
410 	}
411 
412 	if (bind_sock(test->domain, test->type, test->ip, test->port) == -1) {
413 		/* sys_bind() may fail for different reasons, errno has to be
414 		 * checked to confirm that BPF program rejected it.
415 		 */
416 		if (test->result == BIND_REJECT && errno == EPERM)
417 			goto out;
418 		else
419 			goto err;
420 	}
421 
422 
423 	if (test->result != SUCCESS)
424 		goto err;
425 
426 	goto out;
427 err:
428 	err = -1;
429 out:
430 	/* Detaching w/o checking return code: best effort attempt. */
431 	if (progfd != -1)
432 		bpf_prog_detach(cgfd, test->attach_type);
433 	close(progfd);
434 	printf("[%s]\n", err ? "FAIL" : "PASS");
435 	return err;
436 }
437 
438 static int run_tests(int cgfd)
439 {
440 	int passes = 0;
441 	int fails = 0;
442 	int i;
443 
444 	for (i = 0; i < ARRAY_SIZE(tests); ++i) {
445 		if (run_test_case(cgfd, &tests[i]))
446 			++fails;
447 		else
448 			++passes;
449 	}
450 	printf("Summary: %d PASSED, %d FAILED\n", passes, fails);
451 	return fails ? -1 : 0;
452 }
453 
454 int main(int argc, char **argv)
455 {
456 	int cgfd = -1;
457 	int err = 0;
458 
459 	if (setup_cgroup_environment())
460 		goto err;
461 
462 	cgfd = create_and_get_cgroup(CG_PATH);
463 	if (!cgfd)
464 		goto err;
465 
466 	if (join_cgroup(CG_PATH))
467 		goto err;
468 
469 	if (run_tests(cgfd))
470 		goto err;
471 
472 	goto out;
473 err:
474 	err = -1;
475 out:
476 	close(cgfd);
477 	cleanup_cgroup_environment();
478 	return err;
479 }
480