// SPDX-License-Identifier: GPL-2.0

/*
 * Self-test for the TCP_INQ control message on (MP)TCP sockets.
 *
 * main() forks a server and a client that talk over the loopback address
 * on port 15432 and keep in lock-step through an AF_UNIX datagram socket
 * pair.  The server enables TCP_INQ on the accepted socket and verifies that
 * the TCP_CM_INQ value delivered with every recvmsg() matches the amount of
 * data it expects to be pending; the client uses TIOCOUTQ to make sure
 * everything it sent has left the send queue before moving on.
 */

#define _GNU_SOURCE

#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <string.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stdint.h>
#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
#include <strings.h>
#include <unistd.h>
#include <time.h>

#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/wait.h>

#include <netdb.h>
#include <netinet/in.h>

#include <linux/tcp.h>
#include <linux/sockios.h>

#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif
#ifndef SOL_MPTCP
#define SOL_MPTCP 284
#endif

/* address family and per-direction protocol, set by parse_opts() */
static int pf = AF_INET;
static int proto_tx = IPPROTO_MPTCP;
static int proto_rx = IPPROTO_MPTCP;

/* Print msg together with the current errno description and exit(1). */
static void die_perror(const char *msg)
{
	perror(msg);
	exit(1);
}

/* Print the usage line on stderr and exit with status r. */
static void die_usage(int r)
{
	fprintf(stderr, "Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n");
	exit(r);
}

/* printf-style fatal error: message plus newline on stderr, then exit(1). */
static void xerror(const char *fmt, ...)
{
	va_list ap;

	va_start(ap, fmt);
	vfprintf(stderr, fmt, ap);
	va_end(ap);
	fputc('\n', stderr);
	exit(1);
}

/* Translate a getaddrinfo() error code into a human readable string. */
static const char *getxinfo_strerr(int err)
{
	if (err == EAI_SYSTEM)
		return strerror(errno);

	return gai_strerror(err);
}

/* getaddrinfo() wrapper that reports the failure and exits on error. */
static void xgetaddrinfo(const char *node, const char *service,
			 const struct addrinfo *hints,
			 struct addrinfo **res)
{
	int err = getaddrinfo(node, service, hints, res);

	if (err) {
		const char *errstr = getxinfo_strerr(err);

		fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
			node ? node : "", service ? service : "", errstr);
		exit(1);
	}
}

/*
 * Create a socket of protocol proto_rx bound to listenaddr:port and put it
 * into listening state.  Returns the listening fd; exits on failure.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	int sock = -1;
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
	};
	struct addrinfo *a, *addr;
	int one = 1;

	hints.ai_family = pf;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto_rx);
		if (sock < 0)
			continue;

		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
				     sizeof(one)))
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create listen socket");

	if (listen(sock, 20))
		die_perror("listen");

	return sock;
}

/*
 * Create a socket of the given protocol and connect it to remoteaddr:port.
 * Returns the connected fd; exits if no socket could be created or if
 * connect() fails.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *a, *addr;
	int sock = -1;

	hints.ai_family = pf;

	xgetaddrinfo(remoteaddr, port, &hints, &addr);
	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto);
		if (sock < 0)
			continue;

		if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		die_perror("connect");
	}

	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create connect socket");

	return sock;
}

/* Map "tcp"/"mptcp" (case-insensitive) to the IPPROTO_* value, else usage. */
static int protostr_to_num(const char *s)
{
	if (strcasecmp(s, "tcp") == 0)
		return IPPROTO_TCP;
	if (strcasecmp(s, "mptcp") == 0)
		return IPPROTO_MPTCP;

	die_usage(1);
	return 0;
}

/* Parse the command line into pf, proto_tx and proto_rx. */
static void parse_opts(int argc, char **argv)
{
	int c;

	while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
		switch (c) {
		case 'h':
			die_usage(0);
			break;
		case '6':
			pf = AF_INET6;
			break;
		case 't':
			proto_tx = protostr_to_num(optarg);
			break;
		case 'r':
			proto_rx = protostr_to_num(optarg);
			break;
		default:
			die_usage(1);
			break;
		}
	}
}

/*
 * Poll TIOCOUTQ once per millisecond, for up to timeout iterations, until
 * the send queue of fd is empty.  total is the largest amount of data that
 * may legitimately still be queued; exceeding it, or running out of
 * iterations, is fatal.
 */
static void wait_for_ack(int fd, int timeout, size_t total)
{
	int i;

	for (i = 0; i < timeout; i++) {
		int nsd = -1, ret, queued = -1;
		struct timespec req;

		ret = ioctl(fd, TIOCOUTQ, &queued);
		if (ret < 0)
			die_perror("TIOCOUTQ");

		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
		if (ret < 0)
			die_perror("SIOCOUTQNSD");

		if ((size_t)queued > total)
			xerror("TIOCOUTQ %d, but only %zu expected\n", queued, total);
		assert(nsd <= queued);

		if (queued == 0)
			return;

		/* wait for peer to ack rx of all data */
		req.tv_sec = 0;
		req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
		nanosleep(&req, NULL);
	}

	xerror("still tx data queued after %d ms\n", timeout);
}

/*
 * Client side of the test.  fd is the connected (MP)TCP socket, unixfd the
 * control channel to the server.  The sequence is driven by 4-byte tokens
 * received from the server:
 *   "xmit" - announce and send a short random-length message
 *   "huge" - announce and send between 1 and 17 MiB of data
 *   "shut" - send one last byte, close fd and report "closed"
 */
static void connect_one_server(int fd, int unixfd)
{
	size_t len, i, total, sent;
	char buf[4096], buf2[4096];
	ssize_t ret;

	len = rand() % (sizeof(buf) - 1);

	if (len < 128)
		len = 128;

	for (i = 0; i < len ; i++) {
		buf[i] = rand() % 26;
		buf[i] += 'A';
	}

	buf[i] = '\n';

	/* un-block server */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);

	assert(strncmp(buf2, "xmit", 4) == 0);

	ret = write(unixfd, &len, sizeof(len));
	assert(ret == (ssize_t)sizeof(len));

	ret = write(fd, buf, len);
	if (ret < 0)
		die_perror("write");

	if (ret != (ssize_t)len)
		xerror("short write");

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "huge", 4) == 0);

	total = rand() % (16 * 1024 * 1024);
	total += (1 * 1024 * 1024);
	sent = total;

	ret = write(unixfd, &total, sizeof(total));
	assert(ret == (ssize_t)sizeof(total));

	wait_for_ack(fd, 5000, len);

	while (total > 0) {
		if (total > sizeof(buf))
			len = sizeof(buf);
		else
			len = total;

		ret = write(fd, buf, len);
		if (ret < 0)
			die_perror("write");
		total -= ret;

		/* we don't have to care about buf content, only
		 * number of total bytes sent
		 */
	}

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "shut", 4) == 0);

	wait_for_ack(fd, 5000, sent);

	ret = write(fd, buf, 1);
	assert(ret == 1);
	close(fd);
	ret = write(unixfd, "closed", 6);
	assert(ret == 6);

	close(unixfd);
}

/*
 * Copy the TCP_CM_INQ value found in the control data of msgh into *inqv.
 * Exits if no such control message is present.
 */
static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
{
	struct cmsghdr *cmsg;

	for (cmsg = CMSG_FIRSTHDR(msgh); cmsg ; cmsg = CMSG_NXTHDR(msgh, cmsg)) {
		if (cmsg->cmsg_level == IPPROTO_TCP && cmsg->cmsg_type == TCP_CM_INQ) {
			memcpy(inqv, CMSG_DATA(cmsg), sizeof(*inqv));
			return;
		}
	}

	xerror("could not find TCP_CM_INQ cmsg type");
}

/*
 * recvmsg() wrapper that exits on error.
 *
 * msg_controllen is a value-result field: per recvmsg(2) it is updated on
 * return to the length of the control data actually stored.  Re-arm it with
 * the full size of the control buffer (controllen) before every call so a
 * previous call cannot shrink the space available to the next one.
 */
static ssize_t xrecvmsg(int fd, struct msghdr *msg, size_t controllen)
{
	ssize_t ret;

	msg->msg_controllen = controllen;

	ret = recvmsg(fd, msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	return ret;
}

/*
 * Server side of the test.  fd is the accepted socket with TCP_INQ already
 * enabled, unixfd the control channel to the client.  Drives the client
 * through the "xmit", "huge" and "shut" steps and checks the TCP_CM_INQ
 * value returned with each recvmsg() against the amount of data expected
 * to be pending.
 */
static void process_one_client(int fd, int unixfd)
{
	unsigned int tcp_inq;
	size_t expect_len;
	char msg_buf[4096];
	char buf[4096];
	char tmp[16];
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = msg_buf,
		.msg_controllen = sizeof(msg_buf),
	};
	ssize_t ret, tot;

	ret = write(unixfd, "xmit", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	if (expect_len > sizeof(buf))
		xerror("expect len %zu exceeds buffer size", expect_len);

	/* wait until the whole short message sits in the receive queue */
	for (;;) {
		struct timespec req;
		unsigned int queued;

		ret = ioctl(fd, FIONREAD, &queued);
		if (ret < 0)
			die_perror("FIONREAD");
		if (queued > expect_len)
			xerror("FIONREAD returned %u, but only %zu expected\n",
			       queued, expect_len);
		if (queued == expect_len)
			break;

		req.tv_sec = 0;
		req.tv_nsec = 1000 * 1000ul;
		nanosleep(&req, NULL);
	}

	/* read one byte, expect cmsg to return expected - 1 */
	ret = xrecvmsg(fd, &msg, sizeof(msg_buf));

	if (msg.msg_controllen == 0)
		xerror("msg_controllen is 0");

	get_tcp_inq(&msg, &tcp_inq);

	assert((size_t)tcp_inq == (expect_len - 1));

	iov.iov_len = sizeof(buf);
	ret = xrecvmsg(fd, &msg, sizeof(msg_buf));

	/* should have gotten exact remainder of all pending data */
	assert(ret == (ssize_t)tcp_inq);

	/* should be 0, all drained */
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 0);

	/* request a large swath of data. */
	ret = write(unixfd, "huge", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	/* peer should send us a few mb of data */
	if (expect_len <= sizeof(buf))
		xerror("expect len %zu too small\n", expect_len);

	tot = 0;
	do {
		iov.iov_len = sizeof(buf);
		ret = xrecvmsg(fd, &msg, sizeof(msg_buf));

		tot += ret;

		get_tcp_inq(&msg, &tcp_inq);

		/* pending data can never exceed what is still outstanding */
		if (tcp_inq > expect_len - tot)
			xerror("inq %u, remaining %zu total_len %zu\n",
			       tcp_inq, expect_len - (size_t)tot, expect_len);

		assert(tcp_inq <= expect_len - tot);
	} while ((size_t)tot < expect_len);

	ret = write(unixfd, "shut", 4);
	assert(ret == 4);

	/* wait for hangup. Should have received one more byte of data. */
	ret = read(unixfd, tmp, sizeof(tmp));
	assert(ret == 6);
	assert(strncmp(tmp, "closed", 6) == 0);

	sleep(1);

	iov.iov_len = 1;
	ret = xrecvmsg(fd, &msg, sizeof(msg_buf));
	assert(ret == 1);

	get_tcp_inq(&msg, &tcp_inq);

	/* tcp_inq should be 1 due to received fin. */
	assert(tcp_inq == 1);

	iov.iov_len = 1;
	ret = xrecvmsg(fd, &msg, sizeof(msg_buf));

	/* expect EOF */
	assert(ret == 0);
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 1);

	close(fd);
}

/* accept() wrapper that exits on error. */
static int xaccept(int s)
{
	int fd = accept(s, NULL, 0);

	if (fd < 0)
		die_perror("accept");

	return fd;
}

/*
 * Server process: listen on the loopback address for the selected family,
 * tell the parent via "conn" that the socket is bound, accept one
 * connection, enable TCP_INQ on it and run process_one_client().
 */
static int server(int unixfd)
{
	int fd = -1, r, on = 1;

	switch (pf) {
	case AF_INET:
		fd = sock_listen_mptcp("127.0.0.1", "15432");
		break;
	case AF_INET6:
		fd = sock_listen_mptcp("::1", "15432");
		break;
	default:
		xerror("Unknown pf %d\n", pf);
		break;
	}

	r = write(unixfd, "conn", 4);
	assert(r == 4);

	/* SIGALRM terminates the process if the test takes too long */
	alarm(15);
	r = xaccept(fd);

	if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
		die_perror("setsockopt");

	process_one_client(r, unixfd);

	return 0;
}

/*
 * Client process: connect to the server on the loopback address for the
 * selected family using proto_tx and run connect_one_server().
 */
static int client(int unixfd)
{
	int fd = -1;

	/* SIGALRM terminates the process if the test takes too long */
	alarm(15);

	switch (pf) {
	case AF_INET:
		fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
		break;
	case AF_INET6:
		fd = sock_connect_mptcp("::1", "15432", proto_tx);
		break;
	default:
		xerror("Unknown pf %d\n", pf);
	}

	connect_one_server(fd, unixfd);

	return 0;
}

/*
 * Seed rand().  The seed is taken from /dev/urandom; if that cannot be
 * opened or does not deliver a full seed, fall back to a value derived from
 * the current time and the pid so the seed is never read uninitialised.
 */
static void init_rng(void)
{
	unsigned int seed = 0;
	bool seeded = false;
	int fd = open("/dev/urandom", O_RDONLY);

	if (fd >= 0) {
		ssize_t ret = read(fd, &seed, sizeof(seed));

		seeded = (ret == (ssize_t)sizeof(seed));
		close(fd);
	}

	if (!seeded)
		seed = (unsigned int)time(NULL) ^ (unsigned int)getpid();

	srand(seed);
}

/* fork() wrapper: exits on failure, re-seeds the RNG in the child. */
static pid_t xfork(void)
{
	pid_t p = fork();

	if (p < 0)
		die_perror("fork");
	else if (p == 0)
		init_rng();

	return p;
}

/*
 * Turn a waitpid() status into an exit code.  Returns the child's exit
 * status if it exited normally (printing a note when non-zero); exits via
 * xerror() if the child was killed or stopped by a signal; returns 111 for
 * any other status.
 */
static int rcheck(int wstatus, const char *what)
{
	if (WIFEXITED(wstatus)) {
		if (WEXITSTATUS(wstatus) == 0)
			return 0;
		fprintf(stderr, "%s exited, status=%d\n", what, WEXITSTATUS(wstatus));
		return WEXITSTATUS(wstatus);
	} else if (WIFSIGNALED(wstatus)) {
		xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
	} else if (WIFSTOPPED(wstatus)) {
		xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));
	}

	return 111;
}

/*
 * Fork the server, wait for its "conn" notification, fork the client, then
 * reap both.  The exit status is the server's if non-zero, otherwise the
 * client's.
 */
int main(int argc, char *argv[])
{
	int e1, e2, wstatus;
	pid_t s, c, ret;
	int unixfds[2];
	char ready[4];
	ssize_t n;

	parse_opts(argc, argv);

	e1 = socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds);
	if (e1 < 0)
		die_perror("socketpair");

	s = xfork();
	if (s == 0)
		return server(unixfds[1]);

	close(unixfds[1]);

	/* wait until server bound a socket */
	n = read(unixfds[0], ready, sizeof(ready));
	assert(n == 4);

	c = xfork();
	if (c == 0)
		return client(unixfds[0]);

	close(unixfds[0]);

	ret = waitpid(s, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e1 = rcheck(wstatus, "server");
	ret = waitpid(c, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e2 = rcheck(wstatus, "client");

	return e1 ? e1 : e2;
}