1 // SPDX-License-Identifier: GPL-2.0 2 // Copyright (c) 2018 Facebook 3 4 #include <stdio.h> 5 #include <stdlib.h> 6 #include <unistd.h> 7 8 #include <arpa/inet.h> 9 #include <sys/types.h> 10 #include <sys/socket.h> 11 12 #include <linux/filter.h> 13 14 #include <bpf/bpf.h> 15 #include <bpf/libbpf.h> 16 17 #include "cgroup_helpers.h" 18 #include "bpf_rlimit.h" 19 20 #define CG_PATH "/foo" 21 #define CONNECT4_PROG_PATH "./connect4_prog.o" 22 #define CONNECT6_PROG_PATH "./connect6_prog.o" 23 24 #define SERV4_IP "192.168.1.254" 25 #define SERV4_REWRITE_IP "127.0.0.1" 26 #define SERV4_PORT 4040 27 #define SERV4_REWRITE_PORT 4444 28 29 #define SERV6_IP "face:b00c:1234:5678::abcd" 30 #define SERV6_REWRITE_IP "::1" 31 #define SERV6_PORT 6060 32 #define SERV6_REWRITE_PORT 6666 33 34 #define INET_NTOP_BUF 40 35 36 typedef int (*load_fn)(enum bpf_attach_type, const char *comment); 37 typedef int (*info_fn)(int, struct sockaddr *, socklen_t *); 38 39 struct program { 40 enum bpf_attach_type type; 41 load_fn loadfn; 42 int fd; 43 const char *name; 44 enum bpf_attach_type invalid_type; 45 }; 46 47 char bpf_log_buf[BPF_LOG_BUF_SIZE]; 48 49 static int mk_sockaddr(int domain, const char *ip, unsigned short port, 50 struct sockaddr *addr, socklen_t addr_len) 51 { 52 struct sockaddr_in6 *addr6; 53 struct sockaddr_in *addr4; 54 55 if (domain != AF_INET && domain != AF_INET6) { 56 log_err("Unsupported address family"); 57 return -1; 58 } 59 60 memset(addr, 0, addr_len); 61 62 if (domain == AF_INET) { 63 if (addr_len < sizeof(struct sockaddr_in)) 64 return -1; 65 addr4 = (struct sockaddr_in *)addr; 66 addr4->sin_family = domain; 67 addr4->sin_port = htons(port); 68 if (inet_pton(domain, ip, (void *)&addr4->sin_addr) != 1) { 69 log_err("Invalid IPv4: %s", ip); 70 return -1; 71 } 72 } else if (domain == AF_INET6) { 73 if (addr_len < sizeof(struct sockaddr_in6)) 74 return -1; 75 addr6 = (struct sockaddr_in6 *)addr; 76 addr6->sin6_family = domain; 77 addr6->sin6_port = htons(port); 78 if (inet_pton(domain, ip, (void *)&addr6->sin6_addr) != 1) { 79 log_err("Invalid IPv6: %s", ip); 80 return -1; 81 } 82 } 83 84 return 0; 85 } 86 87 static int load_insns(enum bpf_attach_type attach_type, 88 const struct bpf_insn *insns, size_t insns_cnt, 89 const char *comment) 90 { 91 struct bpf_load_program_attr load_attr; 92 int ret; 93 94 memset(&load_attr, 0, sizeof(struct bpf_load_program_attr)); 95 load_attr.prog_type = BPF_PROG_TYPE_CGROUP_SOCK_ADDR; 96 load_attr.expected_attach_type = attach_type; 97 load_attr.insns = insns; 98 load_attr.insns_cnt = insns_cnt; 99 load_attr.license = "GPL"; 100 101 ret = bpf_load_program_xattr(&load_attr, bpf_log_buf, BPF_LOG_BUF_SIZE); 102 if (ret < 0 && comment) { 103 log_err(">>> Loading %s program error.\n" 104 ">>> Output from verifier:\n%s\n-------\n", 105 comment, bpf_log_buf); 106 } 107 108 return ret; 109 } 110 111 /* [1] These testing programs try to read different context fields, including 112 * narrow loads of different sizes from user_ip4 and user_ip6, and write to 113 * those allowed to be overridden. 114 * 115 * [2] BPF_LD_IMM64 & BPF_JMP_REG are used below whenever there is a need to 116 * compare a register with unsigned 32bit integer. BPF_JMP_IMM can't be used 117 * in such cases since it accepts only _signed_ 32bit integer as IMM 118 * argument. Also note that BPF_LD_IMM64 contains 2 instructions what matters 119 * to count jumps properly. 120 */ 121 122 static int bind4_prog_load(enum bpf_attach_type attach_type, 123 const char *comment) 124 { 125 union { 126 uint8_t u4_addr8[4]; 127 uint16_t u4_addr16[2]; 128 uint32_t u4_addr32; 129 } ip4; 130 struct sockaddr_in addr4_rw; 131 132 if (inet_pton(AF_INET, SERV4_IP, (void *)&ip4) != 1) { 133 log_err("Invalid IPv4: %s", SERV4_IP); 134 return -1; 135 } 136 137 if (mk_sockaddr(AF_INET, SERV4_REWRITE_IP, SERV4_REWRITE_PORT, 138 (struct sockaddr *)&addr4_rw, sizeof(addr4_rw)) == -1) 139 return -1; 140 141 /* See [1]. */ 142 struct bpf_insn insns[] = { 143 BPF_MOV64_REG(BPF_REG_6, BPF_REG_1), 144 145 /* if (sk.family == AF_INET && */ 146 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6, 147 offsetof(struct bpf_sock_addr, family)), 148 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, AF_INET, 16), 149 150 /* (sk.type == SOCK_DGRAM || sk.type == SOCK_STREAM) && */ 151 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6, 152 offsetof(struct bpf_sock_addr, type)), 153 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, SOCK_DGRAM, 1), 154 BPF_JMP_A(1), 155 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, SOCK_STREAM, 12), 156 157 /* 1st_byte_of_user_ip4 == expected && */ 158 BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6, 159 offsetof(struct bpf_sock_addr, user_ip4)), 160 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr8[0], 10), 161 162 /* 1st_half_of_user_ip4 == expected && */ 163 BPF_LDX_MEM(BPF_H, BPF_REG_7, BPF_REG_6, 164 offsetof(struct bpf_sock_addr, user_ip4)), 165 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr16[0], 8), 166 167 /* whole_user_ip4 == expected) { */ 168 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6, 169 offsetof(struct bpf_sock_addr, user_ip4)), 170 BPF_LD_IMM64(BPF_REG_8, ip4.u4_addr32), /* See [2]. */ 171 BPF_JMP_REG(BPF_JNE, BPF_REG_7, BPF_REG_8, 4), 172 173 /* user_ip4 = addr4_rw.sin_addr */ 174 BPF_MOV32_IMM(BPF_REG_7, addr4_rw.sin_addr.s_addr), 175 BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7, 176 offsetof(struct bpf_sock_addr, user_ip4)), 177 178 /* user_port = addr4_rw.sin_port */ 179 BPF_MOV32_IMM(BPF_REG_7, addr4_rw.sin_port), 180 BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7, 181 offsetof(struct bpf_sock_addr, user_port)), 182 /* } */ 183 184 /* return 1 */ 185 BPF_MOV64_IMM(BPF_REG_0, 1), 186 BPF_EXIT_INSN(), 187 }; 188 189 return load_insns(attach_type, insns, 190 sizeof(insns) / sizeof(struct bpf_insn), comment); 191 } 192 193 static int bind6_prog_load(enum bpf_attach_type attach_type, 194 const char *comment) 195 { 196 struct sockaddr_in6 addr6_rw; 197 struct in6_addr ip6; 198 199 if (inet_pton(AF_INET6, SERV6_IP, (void *)&ip6) != 1) { 200 log_err("Invalid IPv6: %s", SERV6_IP); 201 return -1; 202 } 203 204 if (mk_sockaddr(AF_INET6, SERV6_REWRITE_IP, SERV6_REWRITE_PORT, 205 (struct sockaddr *)&addr6_rw, sizeof(addr6_rw)) == -1) 206 return -1; 207 208 /* See [1]. */ 209 struct bpf_insn insns[] = { 210 BPF_MOV64_REG(BPF_REG_6, BPF_REG_1), 211 212 /* if (sk.family == AF_INET6 && */ 213 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6, 214 offsetof(struct bpf_sock_addr, family)), 215 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, AF_INET6, 18), 216 217 /* 5th_byte_of_user_ip6 == expected && */ 218 BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6, 219 offsetof(struct bpf_sock_addr, user_ip6[1])), 220 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip6.s6_addr[4], 16), 221 222 /* 3rd_half_of_user_ip6 == expected && */ 223 BPF_LDX_MEM(BPF_H, BPF_REG_7, BPF_REG_6, 224 offsetof(struct bpf_sock_addr, user_ip6[1])), 225 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip6.s6_addr16[2], 14), 226 227 /* last_word_of_user_ip6 == expected) { */ 228 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6, 229 offsetof(struct bpf_sock_addr, user_ip6[3])), 230 BPF_LD_IMM64(BPF_REG_8, ip6.s6_addr32[3]), /* See [2]. */ 231 BPF_JMP_REG(BPF_JNE, BPF_REG_7, BPF_REG_8, 10), 232 233 234 #define STORE_IPV6_WORD(N) \ 235 BPF_MOV32_IMM(BPF_REG_7, addr6_rw.sin6_addr.s6_addr32[N]), \ 236 BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7, \ 237 offsetof(struct bpf_sock_addr, user_ip6[N])) 238 239 /* user_ip6 = addr6_rw.sin6_addr */ 240 STORE_IPV6_WORD(0), 241 STORE_IPV6_WORD(1), 242 STORE_IPV6_WORD(2), 243 STORE_IPV6_WORD(3), 244 245 /* user_port = addr6_rw.sin6_port */ 246 BPF_MOV32_IMM(BPF_REG_7, addr6_rw.sin6_port), 247 BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7, 248 offsetof(struct bpf_sock_addr, user_port)), 249 250 /* } */ 251 252 /* return 1 */ 253 BPF_MOV64_IMM(BPF_REG_0, 1), 254 BPF_EXIT_INSN(), 255 }; 256 257 return load_insns(attach_type, insns, 258 sizeof(insns) / sizeof(struct bpf_insn), comment); 259 } 260 261 static int connect_prog_load_path(const char *path, 262 enum bpf_attach_type attach_type, 263 const char *comment) 264 { 265 struct bpf_prog_load_attr attr; 266 struct bpf_object *obj; 267 int prog_fd; 268 269 memset(&attr, 0, sizeof(struct bpf_prog_load_attr)); 270 attr.file = path; 271 attr.prog_type = BPF_PROG_TYPE_CGROUP_SOCK_ADDR; 272 attr.expected_attach_type = attach_type; 273 274 if (bpf_prog_load_xattr(&attr, &obj, &prog_fd)) { 275 if (comment) 276 log_err(">>> Loading %s program at %s error.\n", 277 comment, path); 278 return -1; 279 } 280 281 return prog_fd; 282 } 283 284 static int connect4_prog_load(enum bpf_attach_type attach_type, 285 const char *comment) 286 { 287 return connect_prog_load_path(CONNECT4_PROG_PATH, attach_type, comment); 288 } 289 290 static int connect6_prog_load(enum bpf_attach_type attach_type, 291 const char *comment) 292 { 293 return connect_prog_load_path(CONNECT6_PROG_PATH, attach_type, comment); 294 } 295 296 static void print_ip_port(int sockfd, info_fn fn, const char *fmt) 297 { 298 char addr_buf[INET_NTOP_BUF]; 299 struct sockaddr_storage addr; 300 struct sockaddr_in6 *addr6; 301 struct sockaddr_in *addr4; 302 socklen_t addr_len; 303 unsigned short port; 304 void *nip; 305 306 addr_len = sizeof(struct sockaddr_storage); 307 memset(&addr, 0, addr_len); 308 309 if (fn(sockfd, (struct sockaddr *)&addr, (socklen_t *)&addr_len) == 0) { 310 if (addr.ss_family == AF_INET) { 311 addr4 = (struct sockaddr_in *)&addr; 312 nip = (void *)&addr4->sin_addr; 313 port = ntohs(addr4->sin_port); 314 } else if (addr.ss_family == AF_INET6) { 315 addr6 = (struct sockaddr_in6 *)&addr; 316 nip = (void *)&addr6->sin6_addr; 317 port = ntohs(addr6->sin6_port); 318 } else { 319 return; 320 } 321 const char *addr_str = 322 inet_ntop(addr.ss_family, nip, addr_buf, INET_NTOP_BUF); 323 printf(fmt, addr_str ? addr_str : "??", port); 324 } 325 } 326 327 static void print_local_ip_port(int sockfd, const char *fmt) 328 { 329 print_ip_port(sockfd, getsockname, fmt); 330 } 331 332 static void print_remote_ip_port(int sockfd, const char *fmt) 333 { 334 print_ip_port(sockfd, getpeername, fmt); 335 } 336 337 static int start_server(int type, const struct sockaddr_storage *addr, 338 socklen_t addr_len) 339 { 340 341 int fd; 342 343 fd = socket(addr->ss_family, type, 0); 344 if (fd == -1) { 345 log_err("Failed to create server socket"); 346 goto out; 347 } 348 349 if (bind(fd, (const struct sockaddr *)addr, addr_len) == -1) { 350 log_err("Failed to bind server socket"); 351 goto close_out; 352 } 353 354 if (type == SOCK_STREAM) { 355 if (listen(fd, 128) == -1) { 356 log_err("Failed to listen on server socket"); 357 goto close_out; 358 } 359 } 360 361 print_local_ip_port(fd, "\t Actual: bind(%s, %d)\n"); 362 363 goto out; 364 close_out: 365 close(fd); 366 fd = -1; 367 out: 368 return fd; 369 } 370 371 static int connect_to_server(int type, const struct sockaddr_storage *addr, 372 socklen_t addr_len) 373 { 374 int domain; 375 int fd; 376 377 domain = addr->ss_family; 378 379 if (domain != AF_INET && domain != AF_INET6) { 380 log_err("Unsupported address family"); 381 return -1; 382 } 383 384 fd = socket(domain, type, 0); 385 if (fd == -1) { 386 log_err("Failed to creating client socket"); 387 return -1; 388 } 389 390 if (connect(fd, (const struct sockaddr *)addr, addr_len) == -1) { 391 log_err("Fail to connect to server"); 392 goto err; 393 } 394 395 print_remote_ip_port(fd, "\t Actual: connect(%s, %d)"); 396 print_local_ip_port(fd, " from (%s, %d)\n"); 397 398 return 0; 399 err: 400 close(fd); 401 return -1; 402 } 403 404 static void print_test_case_num(int domain, int type) 405 { 406 static int test_num; 407 408 printf("Test case #%d (%s/%s):\n", ++test_num, 409 (domain == AF_INET ? "IPv4" : 410 domain == AF_INET6 ? "IPv6" : 411 "unknown_domain"), 412 (type == SOCK_STREAM ? "TCP" : 413 type == SOCK_DGRAM ? "UDP" : 414 "unknown_type")); 415 } 416 417 static int run_test_case(int domain, int type, const char *ip, 418 unsigned short port) 419 { 420 struct sockaddr_storage addr; 421 socklen_t addr_len = sizeof(addr); 422 int servfd = -1; 423 int err = 0; 424 425 print_test_case_num(domain, type); 426 427 if (mk_sockaddr(domain, ip, port, (struct sockaddr *)&addr, 428 addr_len) == -1) 429 return -1; 430 431 printf("\tRequested: bind(%s, %d) ..\n", ip, port); 432 servfd = start_server(type, &addr, addr_len); 433 if (servfd == -1) 434 goto err; 435 436 printf("\tRequested: connect(%s, %d) from (*, *) ..\n", ip, port); 437 if (connect_to_server(type, &addr, addr_len)) 438 goto err; 439 440 goto out; 441 err: 442 err = -1; 443 out: 444 close(servfd); 445 return err; 446 } 447 448 static void close_progs_fds(struct program *progs, size_t prog_cnt) 449 { 450 size_t i; 451 452 for (i = 0; i < prog_cnt; ++i) { 453 close(progs[i].fd); 454 progs[i].fd = -1; 455 } 456 } 457 458 static int load_and_attach_progs(int cgfd, struct program *progs, 459 size_t prog_cnt) 460 { 461 size_t i; 462 463 for (i = 0; i < prog_cnt; ++i) { 464 printf("Load %s with invalid type (can pollute stderr) ", 465 progs[i].name); 466 fflush(stdout); 467 progs[i].fd = progs[i].loadfn(progs[i].invalid_type, NULL); 468 if (progs[i].fd != -1) { 469 log_err("Load with invalid type accepted for %s", 470 progs[i].name); 471 goto err; 472 } 473 printf("... REJECTED\n"); 474 475 printf("Load %s with valid type", progs[i].name); 476 progs[i].fd = progs[i].loadfn(progs[i].type, progs[i].name); 477 if (progs[i].fd == -1) { 478 log_err("Failed to load program %s", progs[i].name); 479 goto err; 480 } 481 printf(" ... OK\n"); 482 483 printf("Attach %s with invalid type", progs[i].name); 484 if (bpf_prog_attach(progs[i].fd, cgfd, progs[i].invalid_type, 485 BPF_F_ALLOW_OVERRIDE) != -1) { 486 log_err("Attach with invalid type accepted for %s", 487 progs[i].name); 488 goto err; 489 } 490 printf(" ... REJECTED\n"); 491 492 printf("Attach %s with valid type", progs[i].name); 493 if (bpf_prog_attach(progs[i].fd, cgfd, progs[i].type, 494 BPF_F_ALLOW_OVERRIDE) == -1) { 495 log_err("Failed to attach program %s", progs[i].name); 496 goto err; 497 } 498 printf(" ... OK\n"); 499 } 500 501 return 0; 502 err: 503 close_progs_fds(progs, prog_cnt); 504 return -1; 505 } 506 507 static int run_domain_test(int domain, int cgfd, struct program *progs, 508 size_t prog_cnt, const char *ip, unsigned short port) 509 { 510 int err = 0; 511 512 if (load_and_attach_progs(cgfd, progs, prog_cnt) == -1) 513 goto err; 514 515 if (run_test_case(domain, SOCK_STREAM, ip, port) == -1) 516 goto err; 517 518 if (run_test_case(domain, SOCK_DGRAM, ip, port) == -1) 519 goto err; 520 521 goto out; 522 err: 523 err = -1; 524 out: 525 close_progs_fds(progs, prog_cnt); 526 return err; 527 } 528 529 static int run_test(void) 530 { 531 size_t inet6_prog_cnt; 532 size_t inet_prog_cnt; 533 int cgfd = -1; 534 int err = 0; 535 536 struct program inet6_progs[] = { 537 {BPF_CGROUP_INET6_BIND, bind6_prog_load, -1, "bind6", 538 BPF_CGROUP_INET4_BIND}, 539 {BPF_CGROUP_INET6_CONNECT, connect6_prog_load, -1, "connect6", 540 BPF_CGROUP_INET4_CONNECT}, 541 }; 542 inet6_prog_cnt = sizeof(inet6_progs) / sizeof(struct program); 543 544 struct program inet_progs[] = { 545 {BPF_CGROUP_INET4_BIND, bind4_prog_load, -1, "bind4", 546 BPF_CGROUP_INET6_BIND}, 547 {BPF_CGROUP_INET4_CONNECT, connect4_prog_load, -1, "connect4", 548 BPF_CGROUP_INET6_CONNECT}, 549 }; 550 inet_prog_cnt = sizeof(inet_progs) / sizeof(struct program); 551 552 if (setup_cgroup_environment()) 553 goto err; 554 555 cgfd = create_and_get_cgroup(CG_PATH); 556 if (!cgfd) 557 goto err; 558 559 if (join_cgroup(CG_PATH)) 560 goto err; 561 562 if (run_domain_test(AF_INET, cgfd, inet_progs, inet_prog_cnt, SERV4_IP, 563 SERV4_PORT) == -1) 564 goto err; 565 566 if (run_domain_test(AF_INET6, cgfd, inet6_progs, inet6_prog_cnt, 567 SERV6_IP, SERV6_PORT) == -1) 568 goto err; 569 570 goto out; 571 err: 572 err = -1; 573 out: 574 close(cgfd); 575 cleanup_cgroup_environment(); 576 printf(err ? "### FAIL\n" : "### SUCCESS\n"); 577 return err; 578 } 579 580 int main(int argc, char **argv) 581 { 582 if (argc < 2) { 583 fprintf(stderr, 584 "%s has to be run via %s.sh. Skip direct run.\n", 585 argv[0], argv[0]); 586 exit(0); 587 } 588 return run_test(); 589 } 590