1 // SPDX-License-Identifier: GPL-2.0 2 // Copyright (c) 2018 Facebook 3 4 #include <stdio.h> 5 #include <unistd.h> 6 7 #include <arpa/inet.h> 8 #include <sys/types.h> 9 #include <sys/socket.h> 10 11 #include <linux/filter.h> 12 13 #include <bpf/bpf.h> 14 15 #include "cgroup_helpers.h" 16 #include <bpf/bpf_endian.h> 17 #include "bpf_rlimit.h" 18 #include "bpf_util.h" 19 20 #define CG_PATH "/foo" 21 #define MAX_INSNS 512 22 23 char bpf_log_buf[BPF_LOG_BUF_SIZE]; 24 static bool verbose = false; 25 26 struct sock_test { 27 const char *descr; 28 /* BPF prog properties */ 29 struct bpf_insn insns[MAX_INSNS]; 30 enum bpf_attach_type expected_attach_type; 31 enum bpf_attach_type attach_type; 32 /* Socket properties */ 33 int domain; 34 int type; 35 /* Endpoint to bind() to */ 36 const char *ip; 37 unsigned short port; 38 /* Expected test result */ 39 enum { 40 LOAD_REJECT, 41 ATTACH_REJECT, 42 BIND_REJECT, 43 SUCCESS, 44 } result; 45 }; 46 47 static struct sock_test tests[] = { 48 { 49 "bind4 load with invalid access: src_ip6", 50 .insns = { 51 BPF_MOV64_REG(BPF_REG_6, BPF_REG_1), 52 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6, 53 offsetof(struct bpf_sock, src_ip6[0])), 54 BPF_MOV64_IMM(BPF_REG_0, 1), 55 BPF_EXIT_INSN(), 56 }, 57 BPF_CGROUP_INET4_POST_BIND, 58 BPF_CGROUP_INET4_POST_BIND, 59 0, 60 0, 61 NULL, 62 0, 63 LOAD_REJECT, 64 }, 65 { 66 "bind4 load with invalid access: mark", 67 .insns = { 68 BPF_MOV64_REG(BPF_REG_6, BPF_REG_1), 69 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6, 70 offsetof(struct bpf_sock, mark)), 71 BPF_MOV64_IMM(BPF_REG_0, 1), 72 BPF_EXIT_INSN(), 73 }, 74 BPF_CGROUP_INET4_POST_BIND, 75 BPF_CGROUP_INET4_POST_BIND, 76 0, 77 0, 78 NULL, 79 0, 80 LOAD_REJECT, 81 }, 82 { 83 "bind6 load with invalid access: src_ip4", 84 .insns = { 85 BPF_MOV64_REG(BPF_REG_6, BPF_REG_1), 86 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6, 87 offsetof(struct bpf_sock, src_ip4)), 88 BPF_MOV64_IMM(BPF_REG_0, 1), 89 BPF_EXIT_INSN(), 90 }, 91 BPF_CGROUP_INET6_POST_BIND, 92 BPF_CGROUP_INET6_POST_BIND, 93 0, 94 0, 95 NULL, 96 0, 97 LOAD_REJECT, 98 }, 99 { 100 "sock_create load with invalid access: src_port", 101 .insns = { 102 BPF_MOV64_REG(BPF_REG_6, BPF_REG_1), 103 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6, 104 offsetof(struct bpf_sock, src_port)), 105 BPF_MOV64_IMM(BPF_REG_0, 1), 106 BPF_EXIT_INSN(), 107 }, 108 BPF_CGROUP_INET_SOCK_CREATE, 109 BPF_CGROUP_INET_SOCK_CREATE, 110 0, 111 0, 112 NULL, 113 0, 114 LOAD_REJECT, 115 }, 116 { 117 "sock_create load w/o expected_attach_type (compat mode)", 118 .insns = { 119 BPF_MOV64_IMM(BPF_REG_0, 1), 120 BPF_EXIT_INSN(), 121 }, 122 0, 123 BPF_CGROUP_INET_SOCK_CREATE, 124 AF_INET, 125 SOCK_STREAM, 126 "127.0.0.1", 127 8097, 128 SUCCESS, 129 }, 130 { 131 "sock_create load w/ expected_attach_type", 132 .insns = { 133 BPF_MOV64_IMM(BPF_REG_0, 1), 134 BPF_EXIT_INSN(), 135 }, 136 BPF_CGROUP_INET_SOCK_CREATE, 137 BPF_CGROUP_INET_SOCK_CREATE, 138 AF_INET, 139 SOCK_STREAM, 140 "127.0.0.1", 141 8097, 142 SUCCESS, 143 }, 144 { 145 "attach type mismatch bind4 vs bind6", 146 .insns = { 147 BPF_MOV64_IMM(BPF_REG_0, 1), 148 BPF_EXIT_INSN(), 149 }, 150 BPF_CGROUP_INET4_POST_BIND, 151 BPF_CGROUP_INET6_POST_BIND, 152 0, 153 0, 154 NULL, 155 0, 156 ATTACH_REJECT, 157 }, 158 { 159 "attach type mismatch bind6 vs bind4", 160 .insns = { 161 BPF_MOV64_IMM(BPF_REG_0, 1), 162 BPF_EXIT_INSN(), 163 }, 164 BPF_CGROUP_INET6_POST_BIND, 165 BPF_CGROUP_INET4_POST_BIND, 166 0, 167 0, 168 NULL, 169 0, 170 ATTACH_REJECT, 171 }, 172 { 173 "attach type mismatch default vs bind4", 174 .insns = { 175 BPF_MOV64_IMM(BPF_REG_0, 1), 176 BPF_EXIT_INSN(), 177 }, 178 0, 179 BPF_CGROUP_INET4_POST_BIND, 180 0, 181 0, 182 NULL, 183 0, 184 ATTACH_REJECT, 185 }, 186 { 187 "attach type mismatch bind6 vs sock_create", 188 .insns = { 189 BPF_MOV64_IMM(BPF_REG_0, 1), 190 BPF_EXIT_INSN(), 191 }, 192 BPF_CGROUP_INET6_POST_BIND, 193 BPF_CGROUP_INET_SOCK_CREATE, 194 0, 195 0, 196 NULL, 197 0, 198 ATTACH_REJECT, 199 }, 200 { 201 "bind4 reject all", 202 .insns = { 203 BPF_MOV64_IMM(BPF_REG_0, 0), 204 BPF_EXIT_INSN(), 205 }, 206 BPF_CGROUP_INET4_POST_BIND, 207 BPF_CGROUP_INET4_POST_BIND, 208 AF_INET, 209 SOCK_STREAM, 210 "0.0.0.0", 211 0, 212 BIND_REJECT, 213 }, 214 { 215 "bind6 reject all", 216 .insns = { 217 BPF_MOV64_IMM(BPF_REG_0, 0), 218 BPF_EXIT_INSN(), 219 }, 220 BPF_CGROUP_INET6_POST_BIND, 221 BPF_CGROUP_INET6_POST_BIND, 222 AF_INET6, 223 SOCK_STREAM, 224 "::", 225 0, 226 BIND_REJECT, 227 }, 228 { 229 "bind6 deny specific IP & port", 230 .insns = { 231 BPF_MOV64_REG(BPF_REG_6, BPF_REG_1), 232 233 /* if (ip == expected && port == expected) */ 234 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6, 235 offsetof(struct bpf_sock, src_ip6[3])), 236 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, 237 __bpf_constant_ntohl(0x00000001), 4), 238 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6, 239 offsetof(struct bpf_sock, src_port)), 240 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, 0x2001, 2), 241 242 /* return DENY; */ 243 BPF_MOV64_IMM(BPF_REG_0, 0), 244 BPF_JMP_A(1), 245 246 /* else return ALLOW; */ 247 BPF_MOV64_IMM(BPF_REG_0, 1), 248 BPF_EXIT_INSN(), 249 }, 250 BPF_CGROUP_INET6_POST_BIND, 251 BPF_CGROUP_INET6_POST_BIND, 252 AF_INET6, 253 SOCK_STREAM, 254 "::1", 255 8193, 256 BIND_REJECT, 257 }, 258 { 259 "bind4 allow specific IP & port", 260 .insns = { 261 BPF_MOV64_REG(BPF_REG_6, BPF_REG_1), 262 263 /* if (ip == expected && port == expected) */ 264 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6, 265 offsetof(struct bpf_sock, src_ip4)), 266 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, 267 __bpf_constant_ntohl(0x7F000001), 4), 268 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6, 269 offsetof(struct bpf_sock, src_port)), 270 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, 0x1002, 2), 271 272 /* return ALLOW; */ 273 BPF_MOV64_IMM(BPF_REG_0, 1), 274 BPF_JMP_A(1), 275 276 /* else return DENY; */ 277 BPF_MOV64_IMM(BPF_REG_0, 0), 278 BPF_EXIT_INSN(), 279 }, 280 BPF_CGROUP_INET4_POST_BIND, 281 BPF_CGROUP_INET4_POST_BIND, 282 AF_INET, 283 SOCK_STREAM, 284 "127.0.0.1", 285 4098, 286 SUCCESS, 287 }, 288 { 289 "bind4 allow all", 290 .insns = { 291 BPF_MOV64_IMM(BPF_REG_0, 1), 292 BPF_EXIT_INSN(), 293 }, 294 BPF_CGROUP_INET4_POST_BIND, 295 BPF_CGROUP_INET4_POST_BIND, 296 AF_INET, 297 SOCK_STREAM, 298 "0.0.0.0", 299 0, 300 SUCCESS, 301 }, 302 { 303 "bind6 allow all", 304 .insns = { 305 BPF_MOV64_IMM(BPF_REG_0, 1), 306 BPF_EXIT_INSN(), 307 }, 308 BPF_CGROUP_INET6_POST_BIND, 309 BPF_CGROUP_INET6_POST_BIND, 310 AF_INET6, 311 SOCK_STREAM, 312 "::", 313 0, 314 SUCCESS, 315 }, 316 }; 317 318 static size_t probe_prog_length(const struct bpf_insn *fp) 319 { 320 size_t len; 321 322 for (len = MAX_INSNS - 1; len > 0; --len) 323 if (fp[len].code != 0 || fp[len].imm != 0) 324 break; 325 return len + 1; 326 } 327 328 static int load_sock_prog(const struct bpf_insn *prog, 329 enum bpf_attach_type attach_type) 330 { 331 LIBBPF_OPTS(bpf_prog_load_opts, opts); 332 int ret, insn_cnt; 333 334 insn_cnt = probe_prog_length(prog); 335 336 opts.expected_attach_type = attach_type; 337 opts.log_buf = bpf_log_buf; 338 opts.log_size = BPF_LOG_BUF_SIZE; 339 opts.log_level = 2; 340 341 ret = bpf_prog_load(BPF_PROG_TYPE_CGROUP_SOCK, NULL, "GPL", prog, insn_cnt, &opts); 342 if (verbose && ret < 0) 343 fprintf(stderr, "%s\n", bpf_log_buf); 344 345 return ret; 346 } 347 348 static int attach_sock_prog(int cgfd, int progfd, 349 enum bpf_attach_type attach_type) 350 { 351 return bpf_prog_attach(progfd, cgfd, attach_type, BPF_F_ALLOW_OVERRIDE); 352 } 353 354 static int bind_sock(int domain, int type, const char *ip, unsigned short port) 355 { 356 struct sockaddr_storage addr; 357 struct sockaddr_in6 *addr6; 358 struct sockaddr_in *addr4; 359 int sockfd = -1; 360 socklen_t len; 361 int err = 0; 362 363 sockfd = socket(domain, type, 0); 364 if (sockfd < 0) 365 goto err; 366 367 memset(&addr, 0, sizeof(addr)); 368 369 if (domain == AF_INET) { 370 len = sizeof(struct sockaddr_in); 371 addr4 = (struct sockaddr_in *)&addr; 372 addr4->sin_family = domain; 373 addr4->sin_port = htons(port); 374 if (inet_pton(domain, ip, (void *)&addr4->sin_addr) != 1) 375 goto err; 376 } else if (domain == AF_INET6) { 377 len = sizeof(struct sockaddr_in6); 378 addr6 = (struct sockaddr_in6 *)&addr; 379 addr6->sin6_family = domain; 380 addr6->sin6_port = htons(port); 381 if (inet_pton(domain, ip, (void *)&addr6->sin6_addr) != 1) 382 goto err; 383 } else { 384 goto err; 385 } 386 387 if (bind(sockfd, (const struct sockaddr *)&addr, len) == -1) 388 goto err; 389 390 goto out; 391 err: 392 err = -1; 393 out: 394 close(sockfd); 395 return err; 396 } 397 398 static int run_test_case(int cgfd, const struct sock_test *test) 399 { 400 int progfd = -1; 401 int err = 0; 402 403 printf("Test case: %s .. ", test->descr); 404 progfd = load_sock_prog(test->insns, test->expected_attach_type); 405 if (progfd < 0) { 406 if (test->result == LOAD_REJECT) 407 goto out; 408 else 409 goto err; 410 } 411 412 if (attach_sock_prog(cgfd, progfd, test->attach_type) == -1) { 413 if (test->result == ATTACH_REJECT) 414 goto out; 415 else 416 goto err; 417 } 418 419 if (bind_sock(test->domain, test->type, test->ip, test->port) == -1) { 420 /* sys_bind() may fail for different reasons, errno has to be 421 * checked to confirm that BPF program rejected it. 422 */ 423 if (test->result == BIND_REJECT && errno == EPERM) 424 goto out; 425 else 426 goto err; 427 } 428 429 430 if (test->result != SUCCESS) 431 goto err; 432 433 goto out; 434 err: 435 err = -1; 436 out: 437 /* Detaching w/o checking return code: best effort attempt. */ 438 if (progfd != -1) 439 bpf_prog_detach(cgfd, test->attach_type); 440 close(progfd); 441 printf("[%s]\n", err ? "FAIL" : "PASS"); 442 return err; 443 } 444 445 static int run_tests(int cgfd) 446 { 447 int passes = 0; 448 int fails = 0; 449 int i; 450 451 for (i = 0; i < ARRAY_SIZE(tests); ++i) { 452 if (run_test_case(cgfd, &tests[i])) 453 ++fails; 454 else 455 ++passes; 456 } 457 printf("Summary: %d PASSED, %d FAILED\n", passes, fails); 458 return fails ? -1 : 0; 459 } 460 461 int main(int argc, char **argv) 462 { 463 int cgfd = -1; 464 int err = 0; 465 466 cgfd = cgroup_setup_and_join(CG_PATH); 467 if (cgfd < 0) 468 goto err; 469 470 if (run_tests(cgfd)) 471 goto err; 472 473 goto out; 474 err: 475 err = -1; 476 out: 477 close(cgfd); 478 cleanup_cgroup_environment(); 479 return err; 480 } 481