// SPDX-License-Identifier: GPL-2.0
/*
 * mptcp_connect: copy data over an MPTCP (or plain TCP) connection.
 *
 * Client mode connects to connect_address:port, sends stdin to the peer and
 * writes whatever the peer sends back to stdout. Listen mode (-l) accepts a
 * single connection and does the same from the server side.
 */

#define _GNU_SOURCE

#include <errno.h>
#include <limits.h>
#include <fcntl.h>
#include <string.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <strings.h>
#include <unistd.h>

#include <sys/poll.h>
#include <sys/sendfile.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/mman.h>

#include <netdb.h>
#include <netinet/in.h>

#include <linux/tcp.h>

extern int optind;

#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif
#ifndef TCP_ULP
#define TCP_ULP 31
#endif

/* poll() timeout in milliseconds; -1 waits forever (see -t). */
static int poll_timeout = 10 * 1000;
static bool listen_mode;

enum cfg_mode {
	CFG_MODE_POLL,
	CFG_MODE_MMAP,
	CFG_MODE_SENDFILE,
};

static enum cfg_mode cfg_mode = CFG_MODE_POLL;
static const char *cfg_host;
static const char *cfg_port = "12000";
static int cfg_sock_proto = IPPROTO_MPTCP;
static bool tcpulp_audit;
static int pf = AF_INET;
static int cfg_sndbuf;
static int cfg_rcvbuf;

/* Print command line help to stderr and exit with status 1. */
static void die_usage(void)
{
	fprintf(stderr, "Usage: mptcp_connect [-6] [-u] [-s MPTCP|TCP] [-p port] [-m mode] "
		"[-l] [-t num] [-S num] [-R num] connect_address\n");
	fprintf(stderr, "\t-6 use ipv6\n");
	fprintf(stderr, "\t-l -- listen mode, accept a single incoming connection\n");
	fprintf(stderr, "\t-t num -- set poll timeout to num seconds (<= 0: no timeout)\n");
	fprintf(stderr, "\t-S num -- set SO_SNDBUF to num\n");
	fprintf(stderr, "\t-R num -- set SO_RCVBUF to num\n");
	fprintf(stderr, "\t-p num -- use port num\n");
	fprintf(stderr, "\t-s [MPTCP|TCP] -- use mptcp (default) or tcp sockets\n");
	fprintf(stderr, "\t-m [poll|mmap|sendfile] -- use poll (default), mmap or sendfile\n");
	fprintf(stderr, "\t-u -- check mptcp ulp\n");
	exit(1);
}

/* Map a getaddrinfo()/getnameinfo() error code to a message. */
static const char *getxinfo_strerr(int err)
{
	if (err == EAI_SYSTEM)
		return strerror(errno);

	return gai_strerror(err);
}

/* getnameinfo() with numeric host/service output; exits on failure. */
static void xgetnameinfo(const struct sockaddr *addr, socklen_t addrlen,
			 char *host, socklen_t hostlen,
			 char *serv, socklen_t servlen)
{
	int flags = NI_NUMERICHOST | NI_NUMERICSERV;
	int err = getnameinfo(addr, addrlen, host, hostlen, serv, servlen,
			      flags);

	if (err) {
		const char *errstr = getxinfo_strerr(err);

		fprintf(stderr, "Fatal: getnameinfo: %s\n", errstr);
		exit(1);
	}
}

/* getaddrinfo() wrapper; exits on failure. Caller frees *res. */
static void xgetaddrinfo(const char *node, const char *service,
			 const struct addrinfo *hints,
			 struct addrinfo **res)
{
	int err = getaddrinfo(node, service, hints, res);

	if (err) {
		const char *errstr = getxinfo_strerr(err);

		fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
			node ? node : "", service ? service : "", errstr);
		exit(1);
	}
}

static void set_rcvbuf(int fd, unsigned int size)
{
	int err;

	err = setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &size, sizeof(size));
	if (err) {
		perror("set SO_RCVBUF");
		exit(1);
	}
}

static void set_sndbuf(int fd, unsigned int size)
{
	int err;

	err = setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &size, sizeof(size));
	if (err) {
		perror("set SO_SNDBUF");
		exit(1);
	}
}

/*
 * Create a socket of protocol cfg_sock_proto bound to listenaddr:port
 * (numeric host only) and put it in listen state.
 * Returns the socket fd, or a negative value on failure.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
	};
	struct addrinfo *a, *addr;
	int sock = -1;	/* stays -1 if no candidate address works */
	int one = 1;

	hints.ai_family = pf;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		/* hints use IPPROTO_TCP for the lookup; the socket itself
		 * is created with the configured protocol (MPTCP or TCP).
		 */
		sock = socket(a->ai_family, a->ai_socktype, cfg_sock_proto);
		if (sock < 0)
			continue;

		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
				     sizeof(one)))
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0) {
		fprintf(stderr, "Could not create listen socket\n");
		return sock;
	}

	if (listen(sock, 20)) {
		perror("listen");
		close(sock);
		return -1;
	}

	return sock;
}

/*
 * Check that setting the "mptcp" ULP on a plain TCP socket is refused with
 * EOPNOTSUPP. Returns true if the check passes for some resolved address.
 */
static bool sock_test_tcpulp(const char * const remoteaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *a, *addr;
	int sock = -1, ret = 0;
	bool test_pass = false;

	/* Same address family as the rest of the tool (-6 / IPv6 literal). */
	hints.ai_family = pf;

	xgetaddrinfo(remoteaddr, port, &hints, &addr);
	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, IPPROTO_TCP);
		if (sock < 0) {
			perror("socket");
			continue;
		}
		ret = setsockopt(sock, IPPROTO_TCP, TCP_ULP, "mptcp",
				 sizeof("mptcp"));
		if (ret == -1 && errno == EOPNOTSUPP)
			test_pass = true;
		close(sock);

		if (test_pass)
			break;
		if (!ret)
			fprintf(stderr,
				"setsockopt(TCP_ULP) returned 0\n");
		else
			perror("setsockopt(TCP_ULP)");
	}

	freeaddrinfo(addr);
	return test_pass;
}

/*
 * Connect to remoteaddr:port with a socket of protocol proto.
 * Returns the connected fd, or -1 if every resolved address failed.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *a, *addr;
	int sock = -1;

	hints.ai_family = pf;

	xgetaddrinfo(remoteaddr, port, &hints, &addr);
	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto);
		if (sock < 0) {
			perror("socket");
			continue;
		}

		if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("connect()");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);
	return sock;
}

/*
 * Single write() of a random size (at most 64KiB - 1, capped to len;
 * 0 means len). Returns write()'s result, which is negative on error.
 */
static ssize_t do_rnd_write(const int fd, char *buf, const size_t len)
{
	size_t do_w = rand() & 0xffff;
	ssize_t bw;

	if (do_w == 0 || do_w > len)
		do_w = len;

	bw = write(fd, buf, do_w);
	if (bw < 0)
		perror("write");

	return bw;
}

/* Write all len bytes. Returns len, or 0 if a write() fails. */
static size_t do_write(const int fd, char *buf, const size_t len)
{
	size_t offset = 0;

	while (offset < len) {
		size_t written;
		ssize_t bw;

		bw = write(fd, buf + offset, len - offset);
		if (bw < 0) {
			perror("write");
			return 0;
		}

		written = (size_t)bw;
		offset += written;
	}

	return offset;
}

/* Single read() of a random size in [1, min(len, 0xffff)]. */
static ssize_t do_rnd_read(const int fd, char *buf, const size_t len)
{
	size_t cap = rand();

	cap &= 0xffff;

	if (cap == 0)
		cap = 1;
	else if (cap > len)
		cap = len;

	return read(fd, buf, cap);
}

/* Best effort: set O_NONBLOCK on fd, silently ignoring errors. */
static void set_nonblock(int fd)
{
	int flags = fcntl(fd, F_GETFL);

	if (flags == -1)
		return;

	fcntl(fd, F_SETFL, flags | O_NONBLOCK);
}

/*
 * Interleaved copy: infd -> peerfd and peerfd -> outfd, driven by poll().
 * Returns 0 on success (peerfd is closed), non-zero error codes otherwise.
 */
static int copyfd_io_poll(int infd, int peerfd, int outfd)
{
	struct pollfd fds = {
		.fd = peerfd,
		.events = POLLIN | POLLOUT,
	};
	/* Data read from infd but not yet sent: wbuf[woff, woff + wlen).
	 * wlen is signed so that a failing read() (-1) is detected.
	 */
	size_t woff = 0;
	ssize_t wlen = 0;
	char wbuf[8192];

	set_nonblock(peerfd);

	for (;;) {
		char rbuf[8192];
		ssize_t len;

		if (fds.events == 0)
			break;

		switch (poll(&fds, 1, poll_timeout)) {
		case -1:
			if (errno == EINTR)
				continue;
			perror("poll");
			return 1;
		case 0:
			fprintf(stderr, "%s: poll timed out (events: "
				"POLLIN %u, POLLOUT %u)\n", __func__,
				fds.events & POLLIN, fds.events & POLLOUT);
			return 2;
		}

		if (fds.revents & POLLIN) {
			len = do_rnd_read(peerfd, rbuf, sizeof(rbuf));
			if (len == 0) {
				/* no more data to receive:
				 * peer has closed its write side
				 */
				fds.events &= ~POLLIN;

				if ((fds.events & POLLOUT) == 0)
					/* and nothing more to send */
					break;

				/* Else, still have data to transmit */
			} else if (len < 0) {
				perror("read");
				return 3;
			} else {
				do_write(outfd, rbuf, len);
			}
		}

		if (fds.revents & POLLOUT) {
			if (wlen == 0) {
				woff = 0;
				wlen = read(infd, wbuf, sizeof(wbuf));
			}

			if (wlen > 0) {
				ssize_t bw;

				bw = do_rnd_write(peerfd, wbuf + woff, wlen);
				if (bw < 0)
					return 111;

				woff += bw;
				wlen -= bw;
			} else if (wlen == 0) {
				/* We have no more data to send. */
				fds.events &= ~POLLOUT;

				if ((fds.events & POLLIN) == 0)
					/* ... and peer also closed already */
					break;

				/* ... but we still receive.
				 * Close our write side.
				 */
				shutdown(peerfd, SHUT_WR);
			} else {
				/* read(infd) failed: clear the error marker
				 * so that a retry after EINTR reads again.
				 */
				wlen = 0;
				if (errno == EINTR)
					continue;
				perror("read");
				return 4;
			}
		}

		if (fds.revents & (POLLERR | POLLNVAL)) {
			fprintf(stderr, "Unexpected revents: "
				"POLLERR/POLLNVAL(%x)\n", fds.revents);
			return 5;
		}
	}

	close(peerfd);
	return 0;
}

/*
 * Copy infd to outfd until EOF using random-sized reads.
 * Returns 0 at EOF, negative on read error, positive on a short write.
 */
static int do_recvfile(int infd, int outfd)
{
	ssize_t r;

	do {
		char buf[16384];

		r = do_rnd_read(infd, buf, sizeof(buf));
		if (r > 0) {
			if (write(outfd, buf, r) != r)
				break;
		} else if (r < 0) {
			perror("read");
		}
	} while (r > 0);

	return (int)r;
}

/*
 * Map the first size bytes of infd and write them to outfd.
 * Returns 0 on success, non-zero on failure (the unwritten byte count).
 */
static int do_mmap(int infd, int outfd, unsigned int size)
{
	char *inbuf;
	ssize_t ret = 0, off = 0;
	size_t rem;

	/* mmap() rejects a zero length; an empty input has nothing to send. */
	if (size == 0)
		return 0;

	inbuf = mmap(NULL, size, PROT_READ, MAP_SHARED, infd, 0);
	if (inbuf == MAP_FAILED) {
		perror("mmap");
		return 1;
	}

	rem = size;

	while (rem > 0) {
		ret = write(outfd, inbuf + off, rem);

		if (ret < 0) {
			perror("write");
			break;
		}

		off += ret;
		rem -= ret;
	}

	munmap(inbuf, size);
	return rem;
}

/*
 * Size of fd, which must be a regular file no larger than INT_MAX.
 * Returns the size, or a negative value on error.
 */
static int get_infd_size(int fd)
{
	struct stat sb;
	ssize_t count;
	int err;

	err = fstat(fd, &sb);
	if (err < 0) {
		perror("fstat");
		return -1;
	}

	if ((sb.st_mode & S_IFMT) != S_IFREG) {
		fprintf(stderr, "%s: stdin is not a regular file\n", __func__);
		return -2;
	}

	count = sb.st_size;
	if (count > INT_MAX) {
		fprintf(stderr, "File too large: %zd\n", count);
		return -3;
	}

	return (int)count;
}

/* sendfile() count bytes from infd to outfd. Returns 0 on success, 3 on error. */
static int do_sendfile(int infd, int outfd, unsigned int count)
{
	while (count > 0) {
		ssize_t r;

		r = sendfile(outfd, infd, NULL, count);
		if (r < 0) {
			perror("sendfile");
			return 3;
		}
		if (r == 0) {
			/* Input ended early (e.g. file shrank): avoid spinning forever. */
			fprintf(stderr, "%s: unexpected end of input\n", __func__);
			return 3;
		}

		count -= r;
	}

	return 0;
}

/*
 * Listener: receive until EOF, then send the input file.
 * Client: send the input file, shut down the write side, then receive.
 */
static int copyfd_io_mmap(int infd, int peerfd, int outfd,
			  unsigned int size)
{
	int err;

	if (listen_mode) {
		err = do_recvfile(peerfd, outfd);
		if (err)
			return err;

		err = do_mmap(infd, peerfd, size);
	} else {
		err = do_mmap(infd, peerfd, size);
		if (err)
			return err;

		shutdown(peerfd, SHUT_WR);

		err = do_recvfile(peerfd, outfd);
	}

	return err;
}

/* Same sequencing as copyfd_io_mmap(), using sendfile() to transmit. */
static int copyfd_io_sendfile(int infd, int peerfd, int outfd,
			      unsigned int size)
{
	int err;

	if (listen_mode) {
		err = do_recvfile(peerfd, outfd);
		if (err)
			return err;

		err = do_sendfile(infd, peerfd, size);
	} else {
		err = do_sendfile(infd, peerfd, size);
		if (err)
			return err;

		/* The listener reads until EOF before replying: without
		 * this shutdown both sides would wait on each other.
		 */
		shutdown(peerfd, SHUT_WR);

		err = do_recvfile(peerfd, outfd);
	}

	return err;
}

/* Dispatch to the copy routine selected with -m. */
static int copyfd_io(int infd, int peerfd, int outfd)
{
	int file_size;

	switch (cfg_mode) {
	case CFG_MODE_POLL:
		return copyfd_io_poll(infd, peerfd, outfd);
	case CFG_MODE_MMAP:
		file_size = get_infd_size(infd);
		if (file_size < 0)
			return file_size;
		return copyfd_io_mmap(infd, peerfd, outfd, file_size);
	case CFG_MODE_SENDFILE:
		file_size = get_infd_size(infd);
		if (file_size < 0)
			return file_size;
		return copyfd_io_sendfile(infd, peerfd, outfd, file_size);
	}

	fprintf(stderr, "Invalid mode %d\n", cfg_mode);

	die_usage();
	return 1;
}

/* Sanity-check the address returned by accept() against family. */
static void check_sockaddr(int family, struct sockaddr_storage *ss,
			   socklen_t salen)
{
	struct sockaddr_in6 *sin6;
	struct sockaddr_in *sin;
	socklen_t wanted_size = 0;

	switch (family) {
	case AF_INET:
		wanted_size = sizeof(*sin);
		sin = (void *)ss;
		if (!sin->sin_port)
			fprintf(stderr, "accept: something wrong: ip connection from port 0\n");
		break;
	case AF_INET6:
		wanted_size = sizeof(*sin6);
		sin6 = (void *)ss;
		if (!sin6->sin6_port)
			fprintf(stderr, "accept: something wrong: ipv6 connection from port 0\n");
		break;
	default:
		fprintf(stderr, "accept: Unknown pf %d, salen %u\n",
			family, (unsigned int)salen);
		return;
	}

	if (salen != wanted_size)
		fprintf(stderr, "accept: size mismatch, got %d expected %d\n",
			(int)salen, (int)wanted_size);

	if (ss->ss_family != family)
		fprintf(stderr, "accept: pf mismatch, expect %d, ss_family is %d\n",
			family, (int)ss->ss_family);
}

/* Verify getpeername() on fd matches the address reported by accept(). */
static void check_getpeername(int fd, struct sockaddr_storage *ss, socklen_t salen)
{
	struct sockaddr_storage peerss;
	socklen_t peersalen = sizeof(peerss);

	if (getpeername(fd, (struct sockaddr *)&peerss, &peersalen) < 0) {
		perror("getpeername");
		return;
	}

	if (peersalen != salen) {
		fprintf(stderr, "%s: %d vs %d\n", __func__,
			(int)peersalen, (int)salen);
		return;
	}

	if (memcmp(ss, &peerss, peersalen)) {
		char a[INET6_ADDRSTRLEN];
		char b[INET6_ADDRSTRLEN];
		char c[INET6_ADDRSTRLEN];
		char d[INET6_ADDRSTRLEN];

		xgetnameinfo((struct sockaddr *)ss, salen,
			     a, sizeof(a), b, sizeof(b));

		xgetnameinfo((struct sockaddr *)&peerss, peersalen,
			     c, sizeof(c), d, sizeof(d));

		fprintf(stderr, "%s: memcmp failure: accept %s vs peername %s, %s vs %s salen %d vs %d\n",
			__func__, a, c, b, d, (int)peersalen, (int)salen);
	}
}

/* Verify getpeername() on a connected fd matches cfg_host/cfg_port. */
static void check_getpeername_connect(int fd)
{
	struct sockaddr_storage ss;
	socklen_t salen = sizeof(ss);
	char a[INET6_ADDRSTRLEN];
	char b[INET6_ADDRSTRLEN];

	if (getpeername(fd, (struct sockaddr *)&ss, &salen) < 0) {
		perror("getpeername");
		return;
	}

	xgetnameinfo((struct sockaddr *)&ss, salen,
		     a, sizeof(a), b, sizeof(b));

	if (strcmp(cfg_host, a) || strcmp(cfg_port, b))
		fprintf(stderr, "%s: %s vs %s, %s vs %s\n", __func__,
			cfg_host, a, cfg_port, b);
}

/* Close fd with 50% probability. */
static void maybe_close(int fd)
{
	unsigned int r = rand();

	if (r & 1)
		close(fd);
}

/* Listener: wait for one connection and run the selected copy mode on it. */
int main_loop_s(int listensock)
{
	struct sockaddr_storage ss;
	struct pollfd polls;
	socklen_t salen;
	int remotesock;

	polls.fd = listensock;
	polls.events = POLLIN;

	switch (poll(&polls, 1, poll_timeout)) {
	case -1:
		perror("poll");
		return 1;
	case 0:
		fprintf(stderr, "%s: timed out\n", __func__);
		close(listensock);
		return 2;
	}

	salen = sizeof(ss);
	remotesock = accept(listensock, (struct sockaddr *)&ss, &salen);
	if (remotesock >= 0) {
		maybe_close(listensock);
		check_sockaddr(pf, &ss, salen);
		check_getpeername(remotesock, &ss, salen);

		return copyfd_io(0, remotesock, 1);
	}

	perror("accept");

	return 1;
}

/* Seed rand() from /dev/urandom, falling back to the pid. */
static void init_rng(void)
{
	/* Always initialized, so srand() never sees an indeterminate value. */
	unsigned int seed = (unsigned int)getpid();
	int fd = open("/dev/urandom", O_RDONLY);

	if (fd >= 0) {
		unsigned int rnd;

		if (read(fd, &rnd, sizeof(rnd)) == (ssize_t)sizeof(rnd))
			seed = rnd;
		close(fd);
	}

	srand(seed);
}

/* Client: connect, apply socket options and run the selected copy mode. */
int main_loop(void)
{
	int fd;

	/* listener is ready. */
	fd = sock_connect_mptcp(cfg_host, cfg_port, cfg_sock_proto);
	if (fd < 0)
		return 2;

	check_getpeername_connect(fd);

	if (cfg_rcvbuf)
		set_rcvbuf(fd, cfg_rcvbuf);
	if (cfg_sndbuf)
		set_sndbuf(fd, cfg_sndbuf);

	return copyfd_io(0, fd, 1);
}

/* Parse the -s argument (case-insensitive); exits on unknown input. */
int parse_proto(const char *proto)
{
	if (!strcasecmp(proto, "MPTCP"))
		return IPPROTO_MPTCP;
	if (!strcasecmp(proto, "TCP"))
		return IPPROTO_TCP;

	fprintf(stderr, "Unknown protocol: %s\n.", proto);
	die_usage();

	/* silence compiler warning */
	return 0;
}

/* Parse the -m argument (case-insensitive); exits on unknown input. */
int parse_mode(const char *mode)
{
	if (!strcasecmp(mode, "poll"))
		return CFG_MODE_POLL;
	if (!strcasecmp(mode, "mmap"))
		return CFG_MODE_MMAP;
	if (!strcasecmp(mode, "sendfile"))
		return CFG_MODE_SENDFILE;

	fprintf(stderr, "Unknown test mode: %s\n", mode);
	fprintf(stderr, "Supported modes are:\n");
	fprintf(stderr, "\t\t\"poll\" - interleaved read/write using poll()\n");
	fprintf(stderr, "\t\t\"mmap\" - send entire input file (mmap+write), then read response (-l will read input first)\n");
	fprintf(stderr, "\t\t\"sendfile\" - send entire input file (sendfile), then read response (-l will read input first)\n");

	die_usage();

	/* silence compiler warning */
	return 0;
}

/*
 * Parse a non-negative buffer size (-S / -R) in [0, INT_MAX], any base
 * accepted by strtoul(). Exits on malformed or out-of-range input.
 */
static int parse_int(const char *size)
{
	unsigned long s;
	char *end;

	errno = 0;

	s = strtoul(size, &end, 0);

	if (errno) {
		fprintf(stderr, "Invalid buffer size %s (%s)\n",
			size, strerror(errno));
		die_usage();
	}

	/* Reject empty strings and trailing garbage such as "12k". */
	if (end == size || *end != '\0') {
		fprintf(stderr, "Invalid buffer size %s (%s)\n",
			size, strerror(EINVAL));
		die_usage();
	}

	if (s > INT_MAX) {
		fprintf(stderr, "Invalid buffer size %s (%s)\n",
			size, strerror(ERANGE));
		die_usage();
	}

	return (int)s;
}

/* Parse command line options into the cfg_* globals; exits on misuse. */
static void parse_opts(int argc, char **argv)
{
	int c;

	while ((c = getopt(argc, argv, "6lp:s:hut:m:S:R:")) != -1) {
		switch (c) {
		case 'l':
			listen_mode = true;
			break;
		case 'p':
			cfg_port = optarg;
			break;
		case 's':
			cfg_sock_proto = parse_proto(optarg);
			break;
		case 'h':
			die_usage();
			break;
		case 'u':
			tcpulp_audit = true;
			break;
		case '6':
			pf = AF_INET6;
			break;
		case 't': {
			int secs = atoi(optarg);

			/* Non-positive values, or values whose millisecond
			 * form would overflow an int, mean "no timeout".
			 */
			if (secs > 0 && secs <= INT_MAX / 1000)
				poll_timeout = secs * 1000;
			else
				poll_timeout = -1;
			break;
		}
		case 'm':
			cfg_mode = parse_mode(optarg);
			break;
		case 'S':
			cfg_sndbuf = parse_int(optarg);
			break;
		case 'R':
			cfg_rcvbuf = parse_int(optarg);
			break;
		default:
			/* getopt() already reported the unknown option. */
			die_usage();
			break;
		}
	}

	if (optind + 1 != argc)
		die_usage();
	cfg_host = argv[optind];

	if (strchr(cfg_host, ':'))
		pf = AF_INET6;
}

int main(int argc, char *argv[])
{
	init_rng();

	parse_opts(argc, argv);

	if (tcpulp_audit)
		return sock_test_tcpulp(cfg_host, cfg_port) ? 0 : 1;

	if (listen_mode) {
		int fd = sock_listen_mptcp(cfg_host, cfg_port);

		if (fd < 0)
			return 1;

		if (cfg_rcvbuf)
			set_rcvbuf(fd, cfg_rcvbuf);
		if (cfg_sndbuf)
			set_sndbuf(fd, cfg_sndbuf);

		return main_loop_s(fd);
	}

	return main_loop();
}