1 // SPDX-License-Identifier: GPL-2.0 2 // Copyright (c) 2018 Facebook 3 4 #include <stdio.h> 5 #include <stdlib.h> 6 #include <unistd.h> 7 8 #include <arpa/inet.h> 9 #include <sys/types.h> 10 #include <sys/socket.h> 11 12 #include <linux/filter.h> 13 14 #include <bpf/bpf.h> 15 #include <bpf/libbpf.h> 16 17 #include "cgroup_helpers.h" 18 19 #define CG_PATH "/foo" 20 #define CONNECT4_PROG_PATH "./connect4_prog.o" 21 #define CONNECT6_PROG_PATH "./connect6_prog.o" 22 23 #define SERV4_IP "192.168.1.254" 24 #define SERV4_REWRITE_IP "127.0.0.1" 25 #define SERV4_PORT 4040 26 #define SERV4_REWRITE_PORT 4444 27 28 #define SERV6_IP "face:b00c:1234:5678::abcd" 29 #define SERV6_REWRITE_IP "::1" 30 #define SERV6_PORT 6060 31 #define SERV6_REWRITE_PORT 6666 32 33 #define INET_NTOP_BUF 40 34 35 typedef int (*load_fn)(enum bpf_attach_type, const char *comment); 36 typedef int (*info_fn)(int, struct sockaddr *, socklen_t *); 37 38 struct program { 39 enum bpf_attach_type type; 40 load_fn loadfn; 41 int fd; 42 const char *name; 43 enum bpf_attach_type invalid_type; 44 }; 45 46 char bpf_log_buf[BPF_LOG_BUF_SIZE]; 47 48 static int mk_sockaddr(int domain, const char *ip, unsigned short port, 49 struct sockaddr *addr, socklen_t addr_len) 50 { 51 struct sockaddr_in6 *addr6; 52 struct sockaddr_in *addr4; 53 54 if (domain != AF_INET && domain != AF_INET6) { 55 log_err("Unsupported address family"); 56 return -1; 57 } 58 59 memset(addr, 0, addr_len); 60 61 if (domain == AF_INET) { 62 if (addr_len < sizeof(struct sockaddr_in)) 63 return -1; 64 addr4 = (struct sockaddr_in *)addr; 65 addr4->sin_family = domain; 66 addr4->sin_port = htons(port); 67 if (inet_pton(domain, ip, (void *)&addr4->sin_addr) != 1) { 68 log_err("Invalid IPv4: %s", ip); 69 return -1; 70 } 71 } else if (domain == AF_INET6) { 72 if (addr_len < sizeof(struct sockaddr_in6)) 73 return -1; 74 addr6 = (struct sockaddr_in6 *)addr; 75 addr6->sin6_family = domain; 76 addr6->sin6_port = htons(port); 77 if (inet_pton(domain, ip, (void *)&addr6->sin6_addr) != 1) { 78 log_err("Invalid IPv6: %s", ip); 79 return -1; 80 } 81 } 82 83 return 0; 84 } 85 86 static int load_insns(enum bpf_attach_type attach_type, 87 const struct bpf_insn *insns, size_t insns_cnt, 88 const char *comment) 89 { 90 struct bpf_load_program_attr load_attr; 91 int ret; 92 93 memset(&load_attr, 0, sizeof(struct bpf_load_program_attr)); 94 load_attr.prog_type = BPF_PROG_TYPE_CGROUP_SOCK_ADDR; 95 load_attr.expected_attach_type = attach_type; 96 load_attr.insns = insns; 97 load_attr.insns_cnt = insns_cnt; 98 load_attr.license = "GPL"; 99 100 ret = bpf_load_program_xattr(&load_attr, bpf_log_buf, BPF_LOG_BUF_SIZE); 101 if (ret < 0 && comment) { 102 log_err(">>> Loading %s program error.\n" 103 ">>> Output from verifier:\n%s\n-------\n", 104 comment, bpf_log_buf); 105 } 106 107 return ret; 108 } 109 110 /* [1] These testing programs try to read different context fields, including 111 * narrow loads of different sizes from user_ip4 and user_ip6, and write to 112 * those allowed to be overridden. 113 * 114 * [2] BPF_LD_IMM64 & BPF_JMP_REG are used below whenever there is a need to 115 * compare a register with unsigned 32bit integer. BPF_JMP_IMM can't be used 116 * in such cases since it accepts only _signed_ 32bit integer as IMM 117 * argument. Also note that BPF_LD_IMM64 contains 2 instructions what matters 118 * to count jumps properly. 119 */ 120 121 static int bind4_prog_load(enum bpf_attach_type attach_type, 122 const char *comment) 123 { 124 union { 125 uint8_t u4_addr8[4]; 126 uint16_t u4_addr16[2]; 127 uint32_t u4_addr32; 128 } ip4; 129 struct sockaddr_in addr4_rw; 130 131 if (inet_pton(AF_INET, SERV4_IP, (void *)&ip4) != 1) { 132 log_err("Invalid IPv4: %s", SERV4_IP); 133 return -1; 134 } 135 136 if (mk_sockaddr(AF_INET, SERV4_REWRITE_IP, SERV4_REWRITE_PORT, 137 (struct sockaddr *)&addr4_rw, sizeof(addr4_rw)) == -1) 138 return -1; 139 140 /* See [1]. */ 141 struct bpf_insn insns[] = { 142 BPF_MOV64_REG(BPF_REG_6, BPF_REG_1), 143 144 /* if (sk.family == AF_INET && */ 145 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6, 146 offsetof(struct bpf_sock_addr, family)), 147 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, AF_INET, 16), 148 149 /* (sk.type == SOCK_DGRAM || sk.type == SOCK_STREAM) && */ 150 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6, 151 offsetof(struct bpf_sock_addr, type)), 152 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, SOCK_DGRAM, 1), 153 BPF_JMP_A(1), 154 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, SOCK_STREAM, 12), 155 156 /* 1st_byte_of_user_ip4 == expected && */ 157 BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6, 158 offsetof(struct bpf_sock_addr, user_ip4)), 159 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr8[0], 10), 160 161 /* 1st_half_of_user_ip4 == expected && */ 162 BPF_LDX_MEM(BPF_H, BPF_REG_7, BPF_REG_6, 163 offsetof(struct bpf_sock_addr, user_ip4)), 164 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip4.u4_addr16[0], 8), 165 166 /* whole_user_ip4 == expected) { */ 167 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6, 168 offsetof(struct bpf_sock_addr, user_ip4)), 169 BPF_LD_IMM64(BPF_REG_8, ip4.u4_addr32), /* See [2]. */ 170 BPF_JMP_REG(BPF_JNE, BPF_REG_7, BPF_REG_8, 4), 171 172 /* user_ip4 = addr4_rw.sin_addr */ 173 BPF_MOV32_IMM(BPF_REG_7, addr4_rw.sin_addr.s_addr), 174 BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7, 175 offsetof(struct bpf_sock_addr, user_ip4)), 176 177 /* user_port = addr4_rw.sin_port */ 178 BPF_MOV32_IMM(BPF_REG_7, addr4_rw.sin_port), 179 BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7, 180 offsetof(struct bpf_sock_addr, user_port)), 181 /* } */ 182 183 /* return 1 */ 184 BPF_MOV64_IMM(BPF_REG_0, 1), 185 BPF_EXIT_INSN(), 186 }; 187 188 return load_insns(attach_type, insns, 189 sizeof(insns) / sizeof(struct bpf_insn), comment); 190 } 191 192 static int bind6_prog_load(enum bpf_attach_type attach_type, 193 const char *comment) 194 { 195 struct sockaddr_in6 addr6_rw; 196 struct in6_addr ip6; 197 198 if (inet_pton(AF_INET6, SERV6_IP, (void *)&ip6) != 1) { 199 log_err("Invalid IPv6: %s", SERV6_IP); 200 return -1; 201 } 202 203 if (mk_sockaddr(AF_INET6, SERV6_REWRITE_IP, SERV6_REWRITE_PORT, 204 (struct sockaddr *)&addr6_rw, sizeof(addr6_rw)) == -1) 205 return -1; 206 207 /* See [1]. */ 208 struct bpf_insn insns[] = { 209 BPF_MOV64_REG(BPF_REG_6, BPF_REG_1), 210 211 /* if (sk.family == AF_INET6 && */ 212 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6, 213 offsetof(struct bpf_sock_addr, family)), 214 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, AF_INET6, 18), 215 216 /* 5th_byte_of_user_ip6 == expected && */ 217 BPF_LDX_MEM(BPF_B, BPF_REG_7, BPF_REG_6, 218 offsetof(struct bpf_sock_addr, user_ip6[1])), 219 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip6.s6_addr[4], 16), 220 221 /* 3rd_half_of_user_ip6 == expected && */ 222 BPF_LDX_MEM(BPF_H, BPF_REG_7, BPF_REG_6, 223 offsetof(struct bpf_sock_addr, user_ip6[1])), 224 BPF_JMP_IMM(BPF_JNE, BPF_REG_7, ip6.s6_addr16[2], 14), 225 226 /* last_word_of_user_ip6 == expected) { */ 227 BPF_LDX_MEM(BPF_W, BPF_REG_7, BPF_REG_6, 228 offsetof(struct bpf_sock_addr, user_ip6[3])), 229 BPF_LD_IMM64(BPF_REG_8, ip6.s6_addr32[3]), /* See [2]. */ 230 BPF_JMP_REG(BPF_JNE, BPF_REG_7, BPF_REG_8, 10), 231 232 233 #define STORE_IPV6_WORD(N) \ 234 BPF_MOV32_IMM(BPF_REG_7, addr6_rw.sin6_addr.s6_addr32[N]), \ 235 BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7, \ 236 offsetof(struct bpf_sock_addr, user_ip6[N])) 237 238 /* user_ip6 = addr6_rw.sin6_addr */ 239 STORE_IPV6_WORD(0), 240 STORE_IPV6_WORD(1), 241 STORE_IPV6_WORD(2), 242 STORE_IPV6_WORD(3), 243 244 /* user_port = addr6_rw.sin6_port */ 245 BPF_MOV32_IMM(BPF_REG_7, addr6_rw.sin6_port), 246 BPF_STX_MEM(BPF_W, BPF_REG_6, BPF_REG_7, 247 offsetof(struct bpf_sock_addr, user_port)), 248 249 /* } */ 250 251 /* return 1 */ 252 BPF_MOV64_IMM(BPF_REG_0, 1), 253 BPF_EXIT_INSN(), 254 }; 255 256 return load_insns(attach_type, insns, 257 sizeof(insns) / sizeof(struct bpf_insn), comment); 258 } 259 260 static int connect_prog_load_path(const char *path, 261 enum bpf_attach_type attach_type, 262 const char *comment) 263 { 264 struct bpf_prog_load_attr attr; 265 struct bpf_object *obj; 266 int prog_fd; 267 268 memset(&attr, 0, sizeof(struct bpf_prog_load_attr)); 269 attr.file = path; 270 attr.prog_type = BPF_PROG_TYPE_CGROUP_SOCK_ADDR; 271 attr.expected_attach_type = attach_type; 272 273 if (bpf_prog_load_xattr(&attr, &obj, &prog_fd)) { 274 if (comment) 275 log_err(">>> Loading %s program at %s error.\n", 276 comment, path); 277 return -1; 278 } 279 280 return prog_fd; 281 } 282 283 static int connect4_prog_load(enum bpf_attach_type attach_type, 284 const char *comment) 285 { 286 return connect_prog_load_path(CONNECT4_PROG_PATH, attach_type, comment); 287 } 288 289 static int connect6_prog_load(enum bpf_attach_type attach_type, 290 const char *comment) 291 { 292 return connect_prog_load_path(CONNECT6_PROG_PATH, attach_type, comment); 293 } 294 295 static void print_ip_port(int sockfd, info_fn fn, const char *fmt) 296 { 297 char addr_buf[INET_NTOP_BUF]; 298 struct sockaddr_storage addr; 299 struct sockaddr_in6 *addr6; 300 struct sockaddr_in *addr4; 301 socklen_t addr_len; 302 unsigned short port; 303 void *nip; 304 305 addr_len = sizeof(struct sockaddr_storage); 306 memset(&addr, 0, addr_len); 307 308 if (fn(sockfd, (struct sockaddr *)&addr, (socklen_t *)&addr_len) == 0) { 309 if (addr.ss_family == AF_INET) { 310 addr4 = (struct sockaddr_in *)&addr; 311 nip = (void *)&addr4->sin_addr; 312 port = ntohs(addr4->sin_port); 313 } else if (addr.ss_family == AF_INET6) { 314 addr6 = (struct sockaddr_in6 *)&addr; 315 nip = (void *)&addr6->sin6_addr; 316 port = ntohs(addr6->sin6_port); 317 } else { 318 return; 319 } 320 const char *addr_str = 321 inet_ntop(addr.ss_family, nip, addr_buf, INET_NTOP_BUF); 322 printf(fmt, addr_str ? addr_str : "??", port); 323 } 324 } 325 326 static void print_local_ip_port(int sockfd, const char *fmt) 327 { 328 print_ip_port(sockfd, getsockname, fmt); 329 } 330 331 static void print_remote_ip_port(int sockfd, const char *fmt) 332 { 333 print_ip_port(sockfd, getpeername, fmt); 334 } 335 336 static int start_server(int type, const struct sockaddr_storage *addr, 337 socklen_t addr_len) 338 { 339 340 int fd; 341 342 fd = socket(addr->ss_family, type, 0); 343 if (fd == -1) { 344 log_err("Failed to create server socket"); 345 goto out; 346 } 347 348 if (bind(fd, (const struct sockaddr *)addr, addr_len) == -1) { 349 log_err("Failed to bind server socket"); 350 goto close_out; 351 } 352 353 if (type == SOCK_STREAM) { 354 if (listen(fd, 128) == -1) { 355 log_err("Failed to listen on server socket"); 356 goto close_out; 357 } 358 } 359 360 print_local_ip_port(fd, "\t Actual: bind(%s, %d)\n"); 361 362 goto out; 363 close_out: 364 close(fd); 365 fd = -1; 366 out: 367 return fd; 368 } 369 370 static int connect_to_server(int type, const struct sockaddr_storage *addr, 371 socklen_t addr_len) 372 { 373 int domain; 374 int fd; 375 376 domain = addr->ss_family; 377 378 if (domain != AF_INET && domain != AF_INET6) { 379 log_err("Unsupported address family"); 380 return -1; 381 } 382 383 fd = socket(domain, type, 0); 384 if (fd == -1) { 385 log_err("Failed to creating client socket"); 386 return -1; 387 } 388 389 if (connect(fd, (const struct sockaddr *)addr, addr_len) == -1) { 390 log_err("Fail to connect to server"); 391 goto err; 392 } 393 394 print_remote_ip_port(fd, "\t Actual: connect(%s, %d)"); 395 print_local_ip_port(fd, " from (%s, %d)\n"); 396 397 return 0; 398 err: 399 close(fd); 400 return -1; 401 } 402 403 static void print_test_case_num(int domain, int type) 404 { 405 static int test_num; 406 407 printf("Test case #%d (%s/%s):\n", ++test_num, 408 (domain == AF_INET ? "IPv4" : 409 domain == AF_INET6 ? "IPv6" : 410 "unknown_domain"), 411 (type == SOCK_STREAM ? "TCP" : 412 type == SOCK_DGRAM ? "UDP" : 413 "unknown_type")); 414 } 415 416 static int run_test_case(int domain, int type, const char *ip, 417 unsigned short port) 418 { 419 struct sockaddr_storage addr; 420 socklen_t addr_len = sizeof(addr); 421 int servfd = -1; 422 int err = 0; 423 424 print_test_case_num(domain, type); 425 426 if (mk_sockaddr(domain, ip, port, (struct sockaddr *)&addr, 427 addr_len) == -1) 428 return -1; 429 430 printf("\tRequested: bind(%s, %d) ..\n", ip, port); 431 servfd = start_server(type, &addr, addr_len); 432 if (servfd == -1) 433 goto err; 434 435 printf("\tRequested: connect(%s, %d) from (*, *) ..\n", ip, port); 436 if (connect_to_server(type, &addr, addr_len)) 437 goto err; 438 439 goto out; 440 err: 441 err = -1; 442 out: 443 close(servfd); 444 return err; 445 } 446 447 static void close_progs_fds(struct program *progs, size_t prog_cnt) 448 { 449 size_t i; 450 451 for (i = 0; i < prog_cnt; ++i) { 452 close(progs[i].fd); 453 progs[i].fd = -1; 454 } 455 } 456 457 static int load_and_attach_progs(int cgfd, struct program *progs, 458 size_t prog_cnt) 459 { 460 size_t i; 461 462 for (i = 0; i < prog_cnt; ++i) { 463 printf("Load %s with invalid type (can pollute stderr) ", 464 progs[i].name); 465 fflush(stdout); 466 progs[i].fd = progs[i].loadfn(progs[i].invalid_type, NULL); 467 if (progs[i].fd != -1) { 468 log_err("Load with invalid type accepted for %s", 469 progs[i].name); 470 goto err; 471 } 472 printf("... REJECTED\n"); 473 474 printf("Load %s with valid type", progs[i].name); 475 progs[i].fd = progs[i].loadfn(progs[i].type, progs[i].name); 476 if (progs[i].fd == -1) { 477 log_err("Failed to load program %s", progs[i].name); 478 goto err; 479 } 480 printf(" ... OK\n"); 481 482 printf("Attach %s with invalid type", progs[i].name); 483 if (bpf_prog_attach(progs[i].fd, cgfd, progs[i].invalid_type, 484 BPF_F_ALLOW_OVERRIDE) != -1) { 485 log_err("Attach with invalid type accepted for %s", 486 progs[i].name); 487 goto err; 488 } 489 printf(" ... REJECTED\n"); 490 491 printf("Attach %s with valid type", progs[i].name); 492 if (bpf_prog_attach(progs[i].fd, cgfd, progs[i].type, 493 BPF_F_ALLOW_OVERRIDE) == -1) { 494 log_err("Failed to attach program %s", progs[i].name); 495 goto err; 496 } 497 printf(" ... OK\n"); 498 } 499 500 return 0; 501 err: 502 close_progs_fds(progs, prog_cnt); 503 return -1; 504 } 505 506 static int run_domain_test(int domain, int cgfd, struct program *progs, 507 size_t prog_cnt, const char *ip, unsigned short port) 508 { 509 int err = 0; 510 511 if (load_and_attach_progs(cgfd, progs, prog_cnt) == -1) 512 goto err; 513 514 if (run_test_case(domain, SOCK_STREAM, ip, port) == -1) 515 goto err; 516 517 if (run_test_case(domain, SOCK_DGRAM, ip, port) == -1) 518 goto err; 519 520 goto out; 521 err: 522 err = -1; 523 out: 524 close_progs_fds(progs, prog_cnt); 525 return err; 526 } 527 528 static int run_test(void) 529 { 530 size_t inet6_prog_cnt; 531 size_t inet_prog_cnt; 532 int cgfd = -1; 533 int err = 0; 534 535 struct program inet6_progs[] = { 536 {BPF_CGROUP_INET6_BIND, bind6_prog_load, -1, "bind6", 537 BPF_CGROUP_INET4_BIND}, 538 {BPF_CGROUP_INET6_CONNECT, connect6_prog_load, -1, "connect6", 539 BPF_CGROUP_INET4_CONNECT}, 540 }; 541 inet6_prog_cnt = sizeof(inet6_progs) / sizeof(struct program); 542 543 struct program inet_progs[] = { 544 {BPF_CGROUP_INET4_BIND, bind4_prog_load, -1, "bind4", 545 BPF_CGROUP_INET6_BIND}, 546 {BPF_CGROUP_INET4_CONNECT, connect4_prog_load, -1, "connect4", 547 BPF_CGROUP_INET6_CONNECT}, 548 }; 549 inet_prog_cnt = sizeof(inet_progs) / sizeof(struct program); 550 551 if (setup_cgroup_environment()) 552 goto err; 553 554 cgfd = create_and_get_cgroup(CG_PATH); 555 if (!cgfd) 556 goto err; 557 558 if (join_cgroup(CG_PATH)) 559 goto err; 560 561 if (run_domain_test(AF_INET, cgfd, inet_progs, inet_prog_cnt, SERV4_IP, 562 SERV4_PORT) == -1) 563 goto err; 564 565 if (run_domain_test(AF_INET6, cgfd, inet6_progs, inet6_prog_cnt, 566 SERV6_IP, SERV6_PORT) == -1) 567 goto err; 568 569 goto out; 570 err: 571 err = -1; 572 out: 573 close(cgfd); 574 cleanup_cgroup_environment(); 575 printf(err ? "### FAIL\n" : "### SUCCESS\n"); 576 return err; 577 } 578 579 int main(int argc, char **argv) 580 { 581 if (argc < 2) { 582 fprintf(stderr, 583 "%s has to be run via %s.sh. Skip direct run.\n", 584 argv[0], argv[0]); 585 exit(0); 586 } 587 return run_test(); 588 } 589