// SPDX-License-Identifier: GPL-2.0

#define _GNU_SOURCE

#include <errno.h>
#include <limits.h>
#include <fcntl.h>
#include <string.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <strings.h>
#include <unistd.h>

#include <sys/poll.h>
#include <sys/sendfile.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/mman.h>

#include <netdb.h>
#include <netinet/in.h>

#include <linux/tcp.h>

extern int optind;

#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif
#ifndef TCP_ULP
#define TCP_ULP 31
#endif

/* -l: act as the listening side instead of connecting. */
static bool listen_mode;

/*
 * poll() timeout in milliseconds, set with "-t <seconds>".  -1 waits
 * forever, which is also what "-t 0" (or a negative value) selects.
 * A default of 0 would make every poll() below return "timed out"
 * immediately unless data happened to be ready already.
 */
static int poll_timeout = -1;

/* Data transfer strategy, selected with -m. */
enum cfg_mode {
	CFG_MODE_POLL,		/* interleaved read/write driven by poll() */
	CFG_MODE_MMAP,		/* mmap() the input file, then write() it */
	CFG_MODE_SENDFILE,	/* sendfile() the input file */
};

static enum cfg_mode cfg_mode = CFG_MODE_POLL;
static const char *cfg_host;
static const char *cfg_port = "12000";
static int cfg_sock_proto = IPPROTO_MPTCP;
static bool tcpulp_audit;
static int pf = AF_INET;
static int cfg_sndbuf;

/* Print the command line synopsis to stderr and exit with status 1. */
static void die_usage(void)
{
	fprintf(stderr, "Usage: mptcp_connect [-6] [-u] [-s MPTCP|TCP] [-p port] [-m mode] "
		"[ -l ] [ -t timeout ] [ -b sndbuf ] connect_address\n");
	exit(1);
}

/*
 * Map a getaddrinfo()/getnameinfo() error code to a message; for
 * EAI_SYSTEM the real cause is in errno.
 */
static const char *getxinfo_strerr(int err)
{
	if (err == EAI_SYSTEM)
		return strerror(errno);

	return gai_strerror(err);
}

/* getnameinfo() in numeric host/service form; exits the process on failure. */
static void xgetnameinfo(const struct sockaddr *addr, socklen_t addrlen,
			 char *host, socklen_t hostlen,
			 char *serv, socklen_t servlen)
{
	int flags = NI_NUMERICHOST | NI_NUMERICSERV;
	int err = getnameinfo(addr, addrlen, host, hostlen, serv, servlen,
			      flags);

	if (err) {
		const char *errstr = getxinfo_strerr(err);

		fprintf(stderr, "Fatal: getnameinfo: %s\n", errstr);
		exit(1);
	}
}

/*
 * getaddrinfo() wrapper that exits the process on failure, so callers
 * always get a valid *res that they must release with freeaddrinfo().
 */
static void xgetaddrinfo(const char *node, const char *service,
			 const struct addrinfo *hints,
			 struct addrinfo **res)
{
	int err = getaddrinfo(node, service, hints, res);

	if (err) {
		const char *errstr = getxinfo_strerr(err);

		fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
			node ? node : "", service ? service : "", errstr);
		exit(1);
	}
}

/* Set SO_SNDBUF on @fd to @size bytes; exits the process on failure. */
static void set_sndbuf(int fd, unsigned int size)
{
	int err;

	err = setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &size, sizeof(size));
	if (err) {
		perror("set SO_SNDBUF");
		exit(1);
	}
}

/*
 * Create a socket of protocol cfg_sock_proto bound to the numeric
 * address @listenaddr:@port (family taken from the global pf) and put
 * it in listening state.
 *
 * Returns the listening socket, or -1 on failure.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
	};
	struct addrinfo *a, *addr;
	int sock = -1;
	int one = 1;

	hints.ai_family = pf;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, cfg_sock_proto);
		if (sock < 0)
			continue;

		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
				     sizeof(one)))
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0) {
		fprintf(stderr, "Could not create listen socket\n");
		return sock;
	}

	if (listen(sock, 20)) {
		perror("listen");
		close(sock);
		return -1;
	}

	return sock;
}

/*
 * Audit for -u: try to attach the "mptcp" ULP to a plain TCP socket
 * via TCP_ULP.  The test passes only if setsockopt() fails with
 * EOPNOTSUPP; any other outcome is reported on stderr.
 *
 * Returns true if the test passed for one of the resolved addresses.
 */
static bool sock_test_tcpulp(const char * const remoteaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *a, *addr;
	int sock = -1, ret = 0;
	bool test_pass = false;

	hints.ai_family = AF_INET;

	xgetaddrinfo(remoteaddr, port, &hints, &addr);
	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, IPPROTO_TCP);
		if (sock < 0) {
			perror("socket");
			continue;
		}
		ret = setsockopt(sock, IPPROTO_TCP, TCP_ULP, "mptcp",
				 sizeof("mptcp"));
		if (ret == -1 && errno == EOPNOTSUPP)
			test_pass = true;
		close(sock);

		if (test_pass)
			break;
		if (!ret)
			fprintf(stderr,
				"setsockopt(TCP_ULP) returned 0\n");
		else
			perror("setsockopt(TCP_ULP)");
	}

	/* The address list was previously leaked. */
	freeaddrinfo(addr);
	return test_pass;
}

/*
 * Connect a @proto socket to @remoteaddr:@port (family taken from the
 * global pf), trying each resolved address in turn.
 *
 * Returns the connected socket, or -1 if every attempt failed.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *a, *addr;
	int sock = -1;

	hints.ai_family = pf;

	xgetaddrinfo(remoteaddr, port, &hints, &addr);
	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto);
		if (sock < 0) {
			perror("socket");
			continue;
		}

		if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("connect()");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);
	return sock;
}

/*
 * Issue a single write() of a random length (at most 0xffff bytes,
 * capped at @len) from @buf to @fd.
 *
 * Returns write()'s result: the number of bytes written, or -1 (after
 * printing perror) on error.  The return type is signed so that -1
 * reaches the caller without a round trip through size_t.
 */
static ssize_t do_rnd_write(const int fd, char *buf, const size_t len)
{
	size_t do_w;
	ssize_t bw;

	do_w = rand() & 0xffff;
	if (do_w == 0 || do_w > len)
		do_w = len;

	bw = write(fd, buf, do_w);
	if (bw < 0)
		perror("write");

	return bw;
}

/*
 * Write all @len bytes of @buf to @fd, retrying on partial writes.
 *
 * Returns @len on success, or 0 (after printing perror) on error.
 */
static size_t do_write(const int fd, char *buf, const size_t len)
{
	size_t offset = 0;

	while (offset < len) {
		size_t written;
		ssize_t bw;

		bw = write(fd, buf + offset, len - offset);
		if (bw < 0) {
			perror("write");
			return 0;
		}

		written = (size_t)bw;
		offset += written;
	}

	return offset;
}

/*
 * Issue a single read() of a random length (1..0xffff bytes, capped at
 * @len) from @fd into @buf.  Returns read()'s result unchanged.
 */
static ssize_t do_rnd_read(const int fd, char *buf, const size_t len)
{
	size_t cap = rand();

	cap &= 0xffff;

	if (cap == 0)
		cap = 1;
	else if (cap > len)
		cap = len;

	return read(fd, buf, cap);
}

/* Best effort: add O_NONBLOCK to @fd's file status flags. */
static void set_nonblock(int fd)
{
	int flags = fcntl(fd, F_GETFL);

	if (flags == -1)
		return;

	fcntl(fd, F_SETFL, flags | O_NONBLOCK);
}

/*
 * Full-duplex copy driven by poll() on @peerfd: data received from
 * @peerfd is written to @outfd, and data read from @infd is sent to
 * @peerfd, both in random-sized chunks.  When @infd hits EOF our write
 * side of @peerfd is shut down; the loop ends once both directions are
 * done, and @peerfd is then closed.
 *
 * Returns 0 on success, 1 on poll() error, 2 on poll() timeout, 3 on a
 * read error from @peerfd, 4 on a read error from @infd, 5 on
 * POLLERR/POLLNVAL, and 111 on a write error to @peerfd.
 */
static int copyfd_io_poll(int infd, int peerfd, int outfd)
{
	struct pollfd fds = {
		.fd = peerfd,
		.events = POLLIN | POLLOUT,
	};
	unsigned int woff = 0, wlen = 0;
	char wbuf[8192];

	set_nonblock(peerfd);

	for (;;) {
		char rbuf[8192];
		ssize_t len;

		if (fds.events == 0)
			break;

		switch (poll(&fds, 1, poll_timeout)) {
		case -1:
			if (errno == EINTR)
				continue;
			perror("poll");
			return 1;
		case 0:
			fprintf(stderr, "%s: poll timed out (events: "
				"POLLIN %u, POLLOUT %u)\n", __func__,
				fds.events & POLLIN, fds.events & POLLOUT);
			return 2;
		}

		if (fds.revents & POLLIN) {
			len = do_rnd_read(peerfd, rbuf, sizeof(rbuf));
			if (len == 0) {
				/* no more data to receive:
				 * peer has closed its write side
				 */
				fds.events &= ~POLLIN;

				if ((fds.events & POLLOUT) == 0)
					/* and nothing more to send */
					break;

				/* Else, still have data to transmit */
			} else if (len < 0) {
				perror("read");
				return 3;
			}

			do_write(outfd, rbuf, len);
		}

		if (fds.revents & POLLOUT) {
			if (wlen == 0) {
				/*
				 * Keep read()'s result signed: storing -1
				 * straight into the unsigned wlen turned it
				 * into UINT_MAX and sent bytes from beyond
				 * the end of wbuf.
				 */
				ssize_t rlen;

				woff = 0;
				rlen = read(infd, wbuf, sizeof(wbuf));
				if (rlen < 0) {
					if (errno == EINTR)
						continue;
					perror("read");
					return 4;
				}
				wlen = (unsigned int)rlen;
			}

			if (wlen > 0) {
				ssize_t bw;

				bw = do_rnd_write(peerfd, wbuf + woff, wlen);
				if (bw < 0)
					return 111;

				woff += bw;
				wlen -= bw;
			} else {
				/* We have no more data to send. */
				fds.events &= ~POLLOUT;

				if ((fds.events & POLLIN) == 0)
					/* ... and peer also closed already */
					break;

				/* ... but we still receive.
				 * Close our write side.
				 */
				shutdown(peerfd, SHUT_WR);
			}
		}

		if (fds.revents & (POLLERR | POLLNVAL)) {
			fprintf(stderr, "Unexpected revents: "
				"POLLERR/POLLNVAL(%x)\n", fds.revents);
			return 5;
		}
	}

	close(peerfd);
	return 0;
}

/*
 * Copy everything readable from @infd to @outfd (random-sized reads)
 * until EOF.
 *
 * Returns 0 at EOF, a negative value on read error, or 1 if writing to
 * @outfd failed.  A fixed 1 is used instead of the chunk length so the
 * status cannot wrap to 0 when truncated to an 8-bit exit code.
 */
static int do_recvfile(int infd, int outfd)
{
	ssize_t r;

	do {
		char buf[16384];

		r = do_rnd_read(infd, buf, sizeof(buf));
		if (r > 0) {
			/* do_write() retries partial writes (e.g. to a pipe). */
			if (do_write(outfd, buf, r) != (size_t)r)
				return 1;
		} else if (r < 0) {
			perror("read");
		}
	} while (r > 0);

	return (int)r;
}

/*
 * Send the first @size bytes of @infd to @outfd by mmap()ing the file
 * and write()ing the mapping.
 *
 * Returns 0 on success and 1 on mmap() or write() failure.
 */
static int do_mmap(int infd, int outfd, unsigned int size)
{
	char *inbuf;
	size_t off = 0, rem = size;

	/* mmap() rejects zero-length mappings; an empty file has nothing to send. */
	if (size == 0)
		return 0;

	inbuf = mmap(NULL, size, PROT_READ, MAP_SHARED, infd, 0);
	if (inbuf == MAP_FAILED) {
		perror("mmap");
		return 1;
	}

	while (rem > 0) {
		ssize_t ret = write(outfd, inbuf + off, rem);

		if (ret < 0) {
			perror("write");
			break;
		}

		off += ret;
		rem -= ret;
	}

	munmap(inbuf, size);
	/*
	 * Report failure as 1 rather than the leftover byte count, which
	 * could wrap to 0 once truncated to an exit status.
	 */
	return rem ? 1 : 0;
}

/*
 * Return the size of regular file @fd as an int.  Returns -1 if
 * fstat() fails, -2 if @fd is not a regular file, and -3 if the size
 * exceeds INT_MAX.
 */
static int get_infd_size(int fd)
{
	struct stat sb;
	int err;

	err = fstat(fd, &sb);
	if (err < 0) {
		perror("fstat");
		return -1;
	}

	if ((sb.st_mode & S_IFMT) != S_IFREG) {
		fprintf(stderr, "%s: stdin is not a regular file\n", __func__);
		return -2;
	}

	/*
	 * Compare in off_t, so the size is not truncated through a 32-bit
	 * ssize_t first, and print it with a signed conversion.
	 */
	if (sb.st_size > INT_MAX) {
		fprintf(stderr, "File too large: %jd\n", (intmax_t)sb.st_size);
		return -3;
	}

	return (int)sb.st_size;
}

/*
 * sendfile() @count bytes from @infd's current offset to @outfd.
 * Returns 0 on success, or 3 if sendfile() fails or the input ends
 * early.
 */
static int do_sendfile(int infd, int outfd, unsigned int count)
{
	while (count > 0) {
		ssize_t r;

		r = sendfile(outfd, infd, NULL, count);
		if (r < 0) {
			perror("sendfile");
			return 3;
		}

		/* A 0 return means the input ended early; don't spin forever. */
		if (r == 0) {
			fprintf(stderr, "sendfile: unexpected end of input\n");
			return 3;
		}

		count -= r;
	}

	return 0;
}

/*
 * mmap mode.  The listener first drains @peerfd into @outfd, then sends
 * @size bytes of @infd.  The connecting side sends first, shuts down its
 * write side so the listener sees EOF, then drains the reply.
 * Returns 0 on success or the first non-zero error from a step.
 */
static int copyfd_io_mmap(int infd, int peerfd, int outfd,
			  unsigned int size)
{
	int err;

	if (listen_mode) {
		err = do_recvfile(peerfd, outfd);
		if (err)
			return err;

		err = do_mmap(infd, peerfd, size);
	} else {
		err = do_mmap(infd, peerfd, size);
		if (err)
			return err;

		shutdown(peerfd, SHUT_WR);

		err = do_recvfile(peerfd, outfd);
	}

	return err;
}

/*
 * sendfile mode: same ordering as copyfd_io_mmap(), but the data is
 * sent with sendfile().  Returns 0 on success or the first non-zero
 * error from a step.
 */
static int copyfd_io_sendfile(int infd, int peerfd, int outfd,
			      unsigned int size)
{
	int err;

	if (listen_mode) {
		err = do_recvfile(peerfd, outfd);
		if (err)
			return err;

		err = do_sendfile(infd, peerfd, size);
	} else {
		err = do_sendfile(infd, peerfd, size);
		if (err)
			return err;

		/*
		 * The listener only replies after reading EOF; without this
		 * shutdown both sides would block in do_recvfile() forever.
		 */
		shutdown(peerfd, SHUT_WR);

		err = do_recvfile(peerfd, outfd);
	}

	return err;
}

/*
 * Run the transfer strategy selected by cfg_mode.  The mmap and
 * sendfile modes require @infd to be a regular file; the negative
 * get_infd_size() code is returned otherwise.
 */
static int copyfd_io(int infd, int peerfd, int outfd)
{
	int file_size;

	switch (cfg_mode) {
	case CFG_MODE_POLL:
		return copyfd_io_poll(infd, peerfd, outfd);
	case CFG_MODE_MMAP:
		file_size = get_infd_size(infd);
		if (file_size < 0)
			return file_size;
		return copyfd_io_mmap(infd, peerfd, outfd, file_size);
	case CFG_MODE_SENDFILE:
		file_size = get_infd_size(infd);
		if (file_size < 0)
			return file_size;
		return copyfd_io_sendfile(infd, peerfd, outfd, file_size);
	}

	fprintf(stderr, "Invalid mode %d\n", cfg_mode);

	die_usage();
	return 1;
}

/*
 * Diagnostics only: print a warning to stderr if the address returned
 * by accept() has port 0, an unexpected length, or a family different
 * from @family.
 */
static void check_sockaddr(int family, struct sockaddr_storage *ss,
			   socklen_t salen)
{
	struct sockaddr_in6 *sin6;
	struct sockaddr_in *sin;
	socklen_t wanted_size = 0;

	switch (family) {
	case AF_INET:
		wanted_size = sizeof(*sin);
		sin = (void *)ss;
		if (!sin->sin_port)
			fprintf(stderr, "accept: something wrong: ip connection from port 0\n");
		break;
	case AF_INET6:
		wanted_size = sizeof(*sin6);
		sin6 = (void *)ss;
		if (!sin6->sin6_port)
			fprintf(stderr, "accept: something wrong: ipv6 connection from port 0\n");
		break;
	default:
		fprintf(stderr, "accept: Unknown pf %d, salen %u\n", family, salen);
		return;
	}

	if (salen != wanted_size)
		fprintf(stderr, "accept: size mismatch, got %d expected %d\n",
			(int)salen, (int)wanted_size);

	/* Arguments match the wording: expected family first, then actual. */
	if (ss->ss_family != family)
		fprintf(stderr, "accept: pf mismatch, expect %d, ss_family is %d\n",
			family, (int)ss->ss_family);
}

/*
 * Diagnostics only: check that getpeername() on @fd returns the same
 * address (@ss, @salen) that accept() reported, and print any
 * difference to stderr.
 */
static void check_getpeername(int fd, struct sockaddr_storage *ss, socklen_t salen)
{
	struct sockaddr_storage peerss;
	socklen_t peersalen = sizeof(peerss);

	if (getpeername(fd, (struct sockaddr *)&peerss, &peersalen) < 0) {
		perror("getpeername");
		return;
	}

	if (peersalen != salen) {
		fprintf(stderr, "%s: %d vs %d\n", __func__,
			(int)peersalen, (int)salen);
		return;
	}

	if (memcmp(ss, &peerss, peersalen)) {
		char a[INET6_ADDRSTRLEN];
		char b[INET6_ADDRSTRLEN];
		char c[INET6_ADDRSTRLEN];
		char d[INET6_ADDRSTRLEN];

		xgetnameinfo((struct sockaddr *)ss, salen,
			     a, sizeof(a), b, sizeof(b));

		xgetnameinfo((struct sockaddr *)&peerss, peersalen,
			     c, sizeof(c), d, sizeof(d));

		fprintf(stderr, "%s: memcmp failure: accept %s vs peername %s, %s vs %s salen %d vs %d\n",
			__func__, a, c, b, d, (int)peersalen, (int)salen);
	}
}

/*
 * Diagnostics only: check that the connected peer's numeric address
 * and port match cfg_host/cfg_port, and print any mismatch to stderr.
 */
static void check_getpeername_connect(int fd)
{
	struct sockaddr_storage ss;
	socklen_t salen = sizeof(ss);
	char a[INET6_ADDRSTRLEN];
	char b[INET6_ADDRSTRLEN];

	if (getpeername(fd, (struct sockaddr *)&ss, &salen) < 0) {
		perror("getpeername");
		return;
	}

	xgetnameinfo((struct sockaddr *)&ss, salen,
		     a, sizeof(a), b, sizeof(b));

	if (strcmp(cfg_host, a) || strcmp(cfg_port, b))
		fprintf(stderr, "%s: %s vs %s, %s vs %s\n", __func__,
			cfg_host, a, cfg_port, b);
}

/* Close @fd with probability 1/2, depending on the low bit of rand(). */
static void maybe_close(int fd)
{
	unsigned int r = rand();

	if (r & 1)
		close(fd);
}

/*
 * Listener main loop: wait up to poll_timeout for one connection on
 * @listensock, accept it, run the address diagnostics, then copy
 * between stdin/stdout and the accepted socket.
 *
 * Returns the copyfd_io() result, 1 on poll()/accept() failure, or 2
 * on timeout.
 */
int main_loop_s(int listensock)
{
	struct sockaddr_storage ss;
	struct pollfd polls;
	socklen_t salen;
	int remotesock;

	polls.fd = listensock;
	polls.events = POLLIN;

	switch (poll(&polls, 1, poll_timeout)) {
	case -1:
		perror("poll");
		close(listensock);
		return 1;
	case 0:
		fprintf(stderr, "%s: timed out\n", __func__);
		close(listensock);
		return 2;
	}

	salen = sizeof(ss);
	remotesock = accept(listensock, (struct sockaddr *)&ss, &salen);
	if (remotesock >= 0) {
		maybe_close(listensock);
		check_sockaddr(pf, &ss, salen);
		check_getpeername(remotesock, &ss, salen);

		return copyfd_io(0, remotesock, 1);
	}

	perror("accept");

	return 1;
}

/*
 * Seed rand() from /dev/urandom.  If that is unavailable or the read is
 * short, fall back to the process ID.  The previous version could call
 * srand() on an uninitialized variable, and it treated fd 0 as an
 * open() failure.
 */
static void init_rng(void)
{
	unsigned int seed = (unsigned int)getpid();
	int fd = open("/dev/urandom", O_RDONLY);

	if (fd >= 0) {
		unsigned int rnd;

		if (read(fd, &rnd, sizeof(rnd)) == (ssize_t)sizeof(rnd))
			seed = rnd;
		close(fd);
	}

	srand(seed);
}

/*
 * Connecting main loop: connect to cfg_host:cfg_port, run the peer
 * address diagnostic, optionally apply -b, then copy between
 * stdin/stdout and the socket.  Returns 2 if the connection failed,
 * otherwise the copyfd_io() result.
 */
int main_loop(void)
{
	int fd;

	/* listener is ready. */
	fd = sock_connect_mptcp(cfg_host, cfg_port, cfg_sock_proto);
	if (fd < 0)
		return 2;

	check_getpeername_connect(fd);

	if (cfg_sndbuf)
		set_sndbuf(fd, cfg_sndbuf);

	return copyfd_io(0, fd, 1);
}

/* Map the -s argument ("MPTCP" or "TCP", case-insensitive) to a protocol number. */
int parse_proto(const char *proto)
{
	if (!strcasecmp(proto, "MPTCP"))
		return IPPROTO_MPTCP;
	if (!strcasecmp(proto, "TCP"))
		return IPPROTO_TCP;

	fprintf(stderr, "Unknown protocol: %s\n.", proto);
	die_usage();

	/* silence compiler warning */
	return 0;
}

/* Map the -m argument (case-insensitive) to a cfg_mode value; exits on an unknown mode. */
int parse_mode(const char *mode)
{
	if (!strcasecmp(mode, "poll"))
		return CFG_MODE_POLL;
	if (!strcasecmp(mode, "mmap"))
		return CFG_MODE_MMAP;
	if (!strcasecmp(mode, "sendfile"))
		return CFG_MODE_SENDFILE;

	fprintf(stderr, "Unknown test mode: %s\n", mode);
	fprintf(stderr, "Supported modes are:\n");
	fprintf(stderr, "\t\t\"poll\" - interleaved read/write using poll()\n");
	fprintf(stderr, "\t\t\"mmap\" - send entire input file (mmap+write), then read response (-l will read input first)\n");
	fprintf(stderr, "\t\t\"sendfile\" - send entire input file (sendfile), then read response (-l will read input first)\n");

	die_usage();

	/* silence compiler warning */
	return 0;
}

/*
 * Parse the -b argument (decimal, octal or hex, as accepted by
 * strtoul() with base 0).  Input that is empty, has trailing garbage,
 * or is out of range makes the program exit.
 *
 * Returns the parsed size and also stores it in cfg_sndbuf.  Previously
 * this returned 0, so the caller's assignment wiped out the value and
 * -b had no effect.
 */
int parse_sndbuf(const char *size)
{
	unsigned long s;
	char *end;

	errno = 0;

	s = strtoul(size, &end, 0);

	if (errno) {
		fprintf(stderr, "Invalid sndbuf size %s (%s)\n",
			size, strerror(errno));
		die_usage();
	}

	if (end == size || *end != '\0') {
		fprintf(stderr, "Invalid sndbuf size %s (%s)\n",
			size, strerror(EINVAL));
		die_usage();
	}

	if (s > INT_MAX) {
		fprintf(stderr, "Invalid sndbuf size %s (%s)\n",
			size, strerror(ERANGE));
		die_usage();
	}

	cfg_sndbuf = s;

	return cfg_sndbuf;
}

/*
 * Parse the command line into the cfg_* globals.  Exactly one positional
 * argument (the address) is required; an address containing ':' forces
 * AF_INET6.
 */
static void parse_opts(int argc, char **argv)
{
	int c;

	while ((c = getopt(argc, argv, "6lp:s:hut:m:b:")) != -1) {
		switch (c) {
		case 'l':
			listen_mode = true;
			break;
		case 'p':
			cfg_port = optarg;
			break;
		case 's':
			cfg_sock_proto = parse_proto(optarg);
			break;
		case 'h':
			die_usage();
			break;
		case 'u':
			tcpulp_audit = true;
			break;
		case '6':
			pf = AF_INET6;
			break;
		case 't':
			poll_timeout = atoi(optarg) * 1000;
			if (poll_timeout <= 0)
				poll_timeout = -1;
			break;
		case 'm':
			cfg_mode = parse_mode(optarg);
			break;
		case 'b':
			cfg_sndbuf = parse_sndbuf(optarg);
			break;
		default:
			/* getopt() returns '?' for unknown options or missing arguments. */
			die_usage();
			break;
		}
	}

	if (optind + 1 != argc)
		die_usage();
	cfg_host = argv[optind];

	if (strchr(cfg_host, ':'))
		pf = AF_INET6;
}

int main(int argc, char *argv[])
{
	init_rng();

	parse_opts(argc, argv);

	if (tcpulp_audit)
		return sock_test_tcpulp(cfg_host, cfg_port) ? 0 : 1;

	if (listen_mode) {
		int fd = sock_listen_mptcp(cfg_host, cfg_port);

		if (fd < 0)
			return 1;

		if (cfg_sndbuf)
			set_sndbuf(fd, cfg_sndbuf);

		return main_loop_s(fd);
	}

	return main_loop();
}