1 // SPDX-License-Identifier: GPL-2.0
2 
3 #define _GNU_SOURCE
4 
5 #include <errno.h>
6 #include <limits.h>
7 #include <fcntl.h>
8 #include <string.h>
9 #include <stdarg.h>
10 #include <stdbool.h>
11 #include <stdint.h>
12 #include <stdio.h>
13 #include <stdlib.h>
14 #include <strings.h>
15 #include <signal.h>
16 #include <unistd.h>
17 
18 #include <sys/poll.h>
19 #include <sys/sendfile.h>
20 #include <sys/stat.h>
21 #include <sys/socket.h>
22 #include <sys/types.h>
23 #include <sys/mman.h>
24 
25 #include <netdb.h>
26 #include <netinet/in.h>
27 
28 #include <linux/tcp.h>
29 #include <linux/time_types.h>
30 
extern int optind;

#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif
#ifndef TCP_ULP
#define TCP_ULP 31
#endif

/* poll() timeout in milliseconds (set with -t); -1 means wait forever */
static int  poll_timeout = 10 * 1000;
/* -l: accept a connection instead of initiating one */
static bool listen_mode;
/*
 * Written by the SIGUSR1 handler.  An object stored to from an
 * asynchronous signal handler must be volatile sig_atomic_t for the
 * access to be well defined.
 */
static volatile sig_atomic_t quit;

/* -m: how data is moved between stdin/stdout and the socket */
enum cfg_mode {
	CFG_MODE_POLL,
	CFG_MODE_MMAP,
	CFG_MODE_SENDFILE,
};

/* -P: optional MSG_PEEK behaviour of the receive path */
enum cfg_peek {
	CFG_NONE_PEEK,
	CFG_WITH_PEEK,
	CFG_AFTER_PEEK,
};

static enum cfg_mode cfg_mode = CFG_MODE_POLL;
static enum cfg_peek cfg_peek = CFG_NONE_PEEK;
static const char *cfg_host;
static const char *cfg_port	= "12000";
static int cfg_sock_proto	= IPPROTO_MPTCP;
static bool tcpulp_audit;	/* -u */
static int pf = AF_INET;	/* -6, or a ':' in the host, selects AF_INET6 */
static int cfg_sndbuf;		/* -S; 0 leaves the kernel default */
static int cfg_rcvbuf;		/* -R; 0 leaves the kernel default */
static bool cfg_join;		/* -j */
static bool cfg_remove;		/* -r */
static unsigned int cfg_do_w;	/* -r: maximum bytes per write() */
static int cfg_wait;		/* delay passed to usleep(), in microseconds */
static uint32_t cfg_mark;	/* -M: SO_MARK value; 0 means unset */

/* cmsg types requested with -c */
struct cfg_cmsg_types {
	unsigned int cmsg_enabled:1;
	unsigned int timestampns:1;
};

static struct cfg_cmsg_types cfg_cmsg_types;
77 
/* Print command-line help to stderr and exit with status 1. */
static void die_usage(void)
{
	fprintf(stderr, "Usage: mptcp_connect [-6] [-u] [-s MPTCP|TCP] [-p port] [-m mode] "
		"[-l] [-w sec] connect_address\n");
	fprintf(stderr, "\t-6 use ipv6\n");
	fprintf(stderr, "\t-t num -- set poll timeout to num seconds (<= 0: no timeout)\n");
	fprintf(stderr, "\t-S num -- set SO_SNDBUF to num\n");
	fprintf(stderr, "\t-R num -- set SO_RCVBUF to num\n");
	fprintf(stderr, "\t-p num -- use port num\n");
	fprintf(stderr, "\t-s [MPTCP|TCP] -- use mptcp(default) or tcp sockets\n");
	fprintf(stderr, "\t-m [poll|mmap|sendfile] -- use poll(default)/mmap+write/sendfile\n");
	fprintf(stderr, "\t-M mark -- set socket packet mark\n");
	fprintf(stderr, "\t-u -- check mptcp ulp\n");
	fprintf(stderr, "\t-w num -- wait num sec before closing the socket\n");
	fprintf(stderr, "\t-c cmsg -- test cmsg type <cmsg> (comma separated; supported: TIMESTAMPNS)\n");
	fprintf(stderr,
		"\t-P [saveWithPeek|saveAfterPeek] -- save data with/after MSG_PEEK from tcp socket\n");
	fprintf(stderr, "\t-l -- listen mode, accept one incoming connection\n");
	fprintf(stderr, "\t-j -- join test: short first write followed by a pause (selects poll mode)\n");
	fprintf(stderr,
		"\t-r num -- remove test: write at most num bytes per call, pausing in between; 0 selects 50 (selects poll mode)\n");
	fprintf(stderr, "\t-h -- print this help\n");
	exit(1);
}
97 
/*
 * Report a printf-style message on stderr and terminate with status 1.
 * The caller supplies any trailing newline.
 */
static void xerror(const char *fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	vfprintf(stderr, fmt, args);
	va_end(args);

	exit(1);
}
107 
108 static void handle_signal(int nr)
109 {
110 	quit = true;
111 }
112 
/*
 * Translate a getaddrinfo()/getnameinfo() error code into a message.
 * EAI_SYSTEM means the real cause is stored in errno.
 */
static const char *getxinfo_strerr(int err)
{
	return err == EAI_SYSTEM ? strerror(errno) : gai_strerror(err);
}
120 
/*
 * Numeric reverse lookup of addr into host/serv (NI_NUMERICHOST and
 * NI_NUMERICSERV); prints the error and exits with status 1 on failure.
 */
static void xgetnameinfo(const struct sockaddr *addr, socklen_t addrlen,
			 char *host, socklen_t hostlen,
			 char *serv, socklen_t servlen)
{
	const int flags = NI_NUMERICHOST | NI_NUMERICSERV;
	int rc;

	rc = getnameinfo(addr, addrlen, host, hostlen, serv, servlen, flags);
	if (!rc)
		return;

	fprintf(stderr, "Fatal: getnameinfo: %s\n", getxinfo_strerr(rc));
	exit(1);
}
136 
/*
 * getaddrinfo() that never returns on failure: the error is printed
 * together with node:service and the process exits with status 1.
 * On success *res must be released with freeaddrinfo() by the caller.
 */
static void xgetaddrinfo(const char *node, const char *service,
			 const struct addrinfo *hints,
			 struct addrinfo **res)
{
	int rc = getaddrinfo(node, service, hints, res);
	const char *why;

	if (rc == 0)
		return;

	why = getxinfo_strerr(rc);
	fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
		node ? node : "", service ? service : "", why);
	exit(1);
}
151 
/* Set SO_RCVBUF on fd to size; print the error and exit(1) on failure. */
static void set_rcvbuf(int fd, unsigned int size)
{
	if (setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &size, sizeof(size)) == 0)
		return;

	perror("set SO_RCVBUF");
	exit(1);
}
162 
/* Set SO_SNDBUF on fd to size; print the error and exit(1) on failure. */
static void set_sndbuf(int fd, unsigned int size)
{
	if (setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &size, sizeof(size)) == 0)
		return;

	perror("set SO_SNDBUF");
	exit(1);
}
173 
/* Set the SO_MARK packet mark on fd; print the error and exit(1) on failure. */
static void set_mark(int fd, uint32_t mark)
{
	if (setsockopt(fd, SOL_SOCKET, SO_MARK, &mark, sizeof(mark)) == 0)
		return;

	perror("set SO_MARK");
	exit(1);
}
184 
/*
 * Create a listening socket (protocol cfg_sock_proto, family pf) bound to
 * listenaddr:port.  Each getaddrinfo() result is tried in turn until one
 * binds; SO_REUSEADDR is requested on every candidate.
 *
 * Returns the listening fd (backlog 20), or -1 on failure after printing
 * a diagnostic.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST,
		.ai_family = pf,
	};
	struct addrinfo *a, *addr;
	int sock = -1;	/* stays -1 when no candidate can be bound */
	int one = 1;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, cfg_sock_proto);
		if (sock < 0) {
			perror("socket");
			continue;
		}

		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
				     sizeof(one)))
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0) {
		fprintf(stderr, "Could not create listen socket\n");
		return sock;
	}

	if (listen(sock, 20)) {
		perror("listen");
		close(sock);
		return -1;
	}

	return sock;
}
235 
/*
 * Audit (-u): attaching the "mptcp" ULP to a plain TCP socket via
 * setsockopt(TCP_ULP) must fail with EOPNOTSUPP.  Each resolved address
 * of remoteaddr:port (family pf) is tried until one passes.
 *
 * Returns true if the check passed for some address, false otherwise.
 */
static bool sock_test_tcpulp(const char * const remoteaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *a, *addr;
	int sock = -1, ret = 0;
	bool test_pass = false;

	/* same family selection as the connect/listen paths */
	hints.ai_family = pf;

	xgetaddrinfo(remoteaddr, port, &hints, &addr);
	for (a = addr; a; a = a->ai_next) {
		int err;

		sock = socket(a->ai_family, a->ai_socktype, IPPROTO_TCP);
		if (sock < 0) {
			perror("socket");
			continue;
		}
		ret = setsockopt(sock, IPPROTO_TCP, TCP_ULP, "mptcp",
				 sizeof("mptcp"));
		/* close() may clobber errno: keep the setsockopt() cause */
		err = errno;
		close(sock);

		if (ret == -1 && err == EOPNOTSUPP) {
			test_pass = true;
			break;
		}

		if (!ret) {
			fprintf(stderr,
				"setsockopt(TCP_ULP) returned 0\n");
		} else {
			errno = err;
			perror("setsockopt(TCP_ULP)");
		}
	}

	freeaddrinfo(addr);
	return test_pass;
}
272 
/*
 * Connect to remoteaddr:port with a stream socket of protocol proto,
 * trying each resolved address (family pf) in order.  A non-zero cfg_mark
 * is applied to every socket before connect().
 *
 * Returns the connected fd, or -1 if every candidate failed.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
		.ai_family = pf,
	};
	struct addrinfo *res, *cur;
	int fd = -1;

	xgetaddrinfo(remoteaddr, port, &hints, &res);

	for (cur = res; cur != NULL; cur = cur->ai_next) {
		fd = socket(cur->ai_family, cur->ai_socktype, proto);
		if (fd < 0) {
			perror("socket");
			continue;
		}

		if (cfg_mark)
			set_mark(fd, cfg_mark);

		if (connect(fd, cur->ai_addr, cur->ai_addrlen) == 0)
			break;	/* connected */

		perror("connect()");
		close(fd);
		fd = -1;
	}

	freeaddrinfo(res);
	return fd;
}
307 
308 static size_t do_rnd_write(const int fd, char *buf, const size_t len)
309 {
310 	static bool first = true;
311 	unsigned int do_w;
312 	ssize_t bw;
313 
314 	do_w = rand() & 0xffff;
315 	if (do_w == 0 || do_w > len)
316 		do_w = len;
317 
318 	if (cfg_join && first && do_w > 100)
319 		do_w = 100;
320 
321 	if (cfg_remove && do_w > cfg_do_w)
322 		do_w = cfg_do_w;
323 
324 	bw = write(fd, buf, do_w);
325 	if (bw < 0)
326 		perror("write");
327 
328 	/* let the join handshake complete, before going on */
329 	if (cfg_join && first) {
330 		usleep(200000);
331 		first = false;
332 	}
333 
334 	if (cfg_remove)
335 		usleep(200000);
336 
337 	return bw;
338 }
339 
/*
 * Write all len bytes of buf to fd, resuming after short writes.
 * Returns len on success, or 0 (after printing the error) if write() fails.
 */
static size_t do_write(const int fd, char *buf, const size_t len)
{
	size_t done;

	for (done = 0; done < len; ) {
		ssize_t n = write(fd, buf + done, len - done);

		if (n < 0) {
			perror("write");
			return 0;
		}
		done += (size_t)n;
	}

	return done;
}
360 
361 static void process_cmsg(struct msghdr *msgh)
362 {
363 	struct __kernel_timespec ts;
364 	bool ts_found = false;
365 	struct cmsghdr *cmsg;
366 
367 	for (cmsg = CMSG_FIRSTHDR(msgh); cmsg ; cmsg = CMSG_NXTHDR(msgh, cmsg)) {
368 		if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SO_TIMESTAMPNS_NEW) {
369 			memcpy(&ts, CMSG_DATA(cmsg), sizeof(ts));
370 			ts_found = true;
371 			continue;
372 		}
373 	}
374 
375 	if (cfg_cmsg_types.timestampns) {
376 		if (!ts_found)
377 			xerror("TIMESTAMPNS not present\n");
378 	}
379 }
380 
381 static ssize_t do_recvmsg_cmsg(const int fd, char *buf, const size_t len)
382 {
383 	char msg_buf[8192];
384 	struct iovec iov = {
385 		.iov_base = buf,
386 		.iov_len = len,
387 	};
388 	struct msghdr msg = {
389 		.msg_iov = &iov,
390 		.msg_iovlen = 1,
391 		.msg_control = msg_buf,
392 		.msg_controllen = sizeof(msg_buf),
393 	};
394 	int flags = 0;
395 	int ret = recvmsg(fd, &msg, flags);
396 
397 	if (ret <= 0)
398 		return ret;
399 
400 	if (msg.msg_controllen && !cfg_cmsg_types.cmsg_enabled)
401 		xerror("got %lu bytes of cmsg data, expected 0\n",
402 		       (unsigned long)msg.msg_controllen);
403 
404 	if (msg.msg_controllen == 0 && cfg_cmsg_types.cmsg_enabled)
405 		xerror("%s\n", "got no cmsg data");
406 
407 	if (msg.msg_controllen)
408 		process_cmsg(&msg);
409 
410 	return ret;
411 }
412 
413 static ssize_t do_rnd_read(const int fd, char *buf, const size_t len)
414 {
415 	int ret = 0;
416 	char tmp[16384];
417 	size_t cap = rand();
418 
419 	cap &= 0xffff;
420 
421 	if (cap == 0)
422 		cap = 1;
423 	else if (cap > len)
424 		cap = len;
425 
426 	if (cfg_peek == CFG_WITH_PEEK) {
427 		ret = recv(fd, buf, cap, MSG_PEEK);
428 		ret = (ret < 0) ? ret : read(fd, tmp, ret);
429 	} else if (cfg_peek == CFG_AFTER_PEEK) {
430 		ret = recv(fd, buf, cap, MSG_PEEK);
431 		ret = (ret < 0) ? ret : read(fd, buf, cap);
432 	} else if (cfg_cmsg_types.cmsg_enabled) {
433 		ret = do_recvmsg_cmsg(fd, buf, cap);
434 	} else {
435 		ret = read(fd, buf, cap);
436 	}
437 
438 	return ret;
439 }
440 
/* Best effort: add O_NONBLOCK to fd's status flags; failures are ignored. */
static void set_nonblock(int fd)
{
	int fl = fcntl(fd, F_GETFL);

	if (fl != -1)
		fcntl(fd, F_SETFL, fl | O_NONBLOCK);
}
450 
451 static int copyfd_io_poll(int infd, int peerfd, int outfd)
452 {
453 	struct pollfd fds = {
454 		.fd = peerfd,
455 		.events = POLLIN | POLLOUT,
456 	};
457 	unsigned int woff = 0, wlen = 0;
458 	char wbuf[8192];
459 
460 	set_nonblock(peerfd);
461 
462 	for (;;) {
463 		char rbuf[8192];
464 		ssize_t len;
465 
466 		if (fds.events == 0)
467 			break;
468 
469 		switch (poll(&fds, 1, poll_timeout)) {
470 		case -1:
471 			if (errno == EINTR)
472 				continue;
473 			perror("poll");
474 			return 1;
475 		case 0:
476 			fprintf(stderr, "%s: poll timed out (events: "
477 				"POLLIN %u, POLLOUT %u)\n", __func__,
478 				fds.events & POLLIN, fds.events & POLLOUT);
479 			return 2;
480 		}
481 
482 		if (fds.revents & POLLIN) {
483 			len = do_rnd_read(peerfd, rbuf, sizeof(rbuf));
484 			if (len == 0) {
485 				/* no more data to receive:
486 				 * peer has closed its write side
487 				 */
488 				fds.events &= ~POLLIN;
489 
490 				if ((fds.events & POLLOUT) == 0)
491 					/* and nothing more to send */
492 					break;
493 
494 			/* Else, still have data to transmit */
495 			} else if (len < 0) {
496 				perror("read");
497 				return 3;
498 			}
499 
500 			do_write(outfd, rbuf, len);
501 		}
502 
503 		if (fds.revents & POLLOUT) {
504 			if (wlen == 0) {
505 				woff = 0;
506 				wlen = read(infd, wbuf, sizeof(wbuf));
507 			}
508 
509 			if (wlen > 0) {
510 				ssize_t bw;
511 
512 				bw = do_rnd_write(peerfd, wbuf + woff, wlen);
513 				if (bw < 0)
514 					return 111;
515 
516 				woff += bw;
517 				wlen -= bw;
518 			} else if (wlen == 0) {
519 				/* We have no more data to send. */
520 				fds.events &= ~POLLOUT;
521 
522 				if ((fds.events & POLLIN) == 0)
523 					/* ... and peer also closed already */
524 					break;
525 
526 				/* ... but we still receive.
527 				 * Close our write side, ev. give some time
528 				 * for address notification and/or checking
529 				 * the current status
530 				 */
531 				if (cfg_wait)
532 					usleep(cfg_wait);
533 				shutdown(peerfd, SHUT_WR);
534 			} else {
535 				if (errno == EINTR)
536 					continue;
537 				perror("read");
538 				return 4;
539 			}
540 		}
541 
542 		if (fds.revents & (POLLERR | POLLNVAL)) {
543 			fprintf(stderr, "Unexpected revents: "
544 				"POLLERR/POLLNVAL(%x)\n", fds.revents);
545 			return 5;
546 		}
547 	}
548 
549 	/* leave some time for late join/announce */
550 	if (cfg_join || cfg_remove)
551 		usleep(cfg_wait);
552 
553 	close(peerfd);
554 	return 0;
555 }
556 
/*
 * Copy everything readable from infd to outfd (random sized reads via
 * do_rnd_read()) until EOF, a read error, or a short write.
 *
 * Returns 0 at EOF, the negative read result on read error, or the
 * positive size of the chunk that could not be written completely.
 */
static int do_recvfile(int infd, int outfd)
{
	for (;;) {
		char chunk[16384];
		ssize_t n = do_rnd_read(infd, chunk, sizeof(chunk));

		if (n < 0) {
			perror("read");
			return (int)n;
		}
		if (n == 0)
			return 0;
		if (write(outfd, chunk, n) != n)
			return (int)n;
	}
}
575 
/*
 * Send the first size bytes of the file open on infd to outfd by mapping
 * it read-only and writing straight from the mapping.
 *
 * Returns 0 when everything was written (trivially so for size 0), 1 if
 * the mapping failed, or the number of bytes still unsent after a write()
 * error.
 */
static int do_mmap(int infd, int outfd, unsigned int size)
{
	char *inbuf;
	size_t off = 0;
	size_t rem = size;

	/* mmap() rejects zero-length mappings with EINVAL: nothing to send */
	if (size == 0)
		return 0;

	inbuf = mmap(NULL, size, PROT_READ, MAP_SHARED, infd, 0);
	if (inbuf == MAP_FAILED) {
		perror("mmap");
		return 1;
	}

	while (rem > 0) {
		ssize_t ret = write(outfd, inbuf + off, rem);

		if (ret < 0) {
			perror("write");
			break;
		}

		off += ret;
		rem -= ret;
	}

	munmap(inbuf, size);
	return rem;
}
604 
/*
 * Return the size of the regular file open on fd, for the mmap/sendfile
 * modes that send the whole input in one go.
 *
 * Returns -1 if fstat() fails, -2 if fd is not a regular file, and -3 if
 * the file is larger than INT_MAX bytes.
 */
static int get_infd_size(int fd)
{
	struct stat sb;
	off_t count;	/* off_t: a narrower type could truncate before the check */
	int err;

	err = fstat(fd, &sb);
	if (err < 0) {
		perror("fstat");
		return -1;
	}

	if (!S_ISREG(sb.st_mode)) {
		fprintf(stderr, "%s: stdin is not a regular file\n", __func__);
		return -2;
	}

	count = sb.st_size;
	if (count > INT_MAX) {
		fprintf(stderr, "File too large: %lld\n", (long long)count);
		return -3;
	}

	return (int)count;
}
630 
/*
 * Send count bytes from infd (starting at its current file offset) to
 * outfd with sendfile(), looping over partial transfers.
 *
 * Returns 0 on success, or 3 if sendfile() fails or infd hits EOF before
 * count bytes were sent.
 */
static int do_sendfile(int infd, int outfd, unsigned int count)
{
	while (count > 0) {
		ssize_t r;

		r = sendfile(outfd, infd, NULL, count);
		if (r < 0) {
			perror("sendfile");
			return 3;
		}
		if (r == 0) {
			/* input ended early (e.g. file shrank): don't spin forever */
			fprintf(stderr, "sendfile: unexpected EOF, %u bytes left\n",
				count);
			return 3;
		}

		count -= r;
	}

	return 0;
}
647 
648 static int copyfd_io_mmap(int infd, int peerfd, int outfd,
649 			  unsigned int size)
650 {
651 	int err;
652 
653 	if (listen_mode) {
654 		err = do_recvfile(peerfd, outfd);
655 		if (err)
656 			return err;
657 
658 		err = do_mmap(infd, peerfd, size);
659 	} else {
660 		err = do_mmap(infd, peerfd, size);
661 		if (err)
662 			return err;
663 
664 		shutdown(peerfd, SHUT_WR);
665 
666 		err = do_recvfile(peerfd, outfd);
667 	}
668 
669 	return err;
670 }
671 
672 static int copyfd_io_sendfile(int infd, int peerfd, int outfd,
673 			      unsigned int size)
674 {
675 	int err;
676 
677 	if (listen_mode) {
678 		err = do_recvfile(peerfd, outfd);
679 		if (err)
680 			return err;
681 
682 		err = do_sendfile(infd, peerfd, size);
683 	} else {
684 		err = do_sendfile(infd, peerfd, size);
685 		if (err)
686 			return err;
687 		err = do_recvfile(peerfd, outfd);
688 	}
689 
690 	return err;
691 }
692 
693 static int copyfd_io(int infd, int peerfd, int outfd)
694 {
695 	int file_size;
696 
697 	switch (cfg_mode) {
698 	case CFG_MODE_POLL:
699 		return copyfd_io_poll(infd, peerfd, outfd);
700 	case CFG_MODE_MMAP:
701 		file_size = get_infd_size(infd);
702 		if (file_size < 0)
703 			return file_size;
704 		return copyfd_io_mmap(infd, peerfd, outfd, file_size);
705 	case CFG_MODE_SENDFILE:
706 		file_size = get_infd_size(infd);
707 		if (file_size < 0)
708 			return file_size;
709 		return copyfd_io_sendfile(infd, peerfd, outfd, file_size);
710 	}
711 
712 	fprintf(stderr, "Invalid mode %d\n", cfg_mode);
713 
714 	die_usage();
715 	return 1;
716 }
717 
/*
 * Sanity-check the peer address returned by accept(): a non-zero port,
 * the address length expected for pf, and a matching ss_family.
 * Problems are only reported on stderr.
 */
static void check_sockaddr(int pf, struct sockaddr_storage *ss,
			   socklen_t salen)
{
	struct sockaddr_in6 *sin6;
	struct sockaddr_in *sin;
	socklen_t wanted_size = 0;

	switch (pf) {
	case AF_INET:
		wanted_size = sizeof(*sin);
		sin = (void *)ss;
		if (!sin->sin_port)
			fprintf(stderr, "accept: something wrong: ip connection from port 0\n");
		break;
	case AF_INET6:
		wanted_size = sizeof(*sin6);
		sin6 = (void *)ss;
		if (!sin6->sin6_port)
			fprintf(stderr, "accept: something wrong: ipv6 connection from port 0\n");
		break;
	default:
		fprintf(stderr, "accept: Unknown pf %d, salen %u\n", pf, salen);
		return;
	}

	if (salen != wanted_size)
		fprintf(stderr, "accept: size mismatch, got %d expected %d\n",
			(int)salen, (int)wanted_size);

	/* arguments follow the format: expected family first */
	if (ss->ss_family != pf)
		fprintf(stderr, "accept: pf mismatch, expect %d, ss_family is %d\n",
			pf, (int)ss->ss_family);
}
751 
/*
 * Compare the address accept() reported (ss/salen) with what getpeername()
 * returns for fd.  Any mismatch is printed on stderr; nothing is returned.
 */
static void check_getpeername(int fd, struct sockaddr_storage *ss, socklen_t salen)
{
	struct sockaddr_storage peer;
	socklen_t peerlen = sizeof(peer);
	char acc_host[INET6_ADDRSTRLEN], acc_serv[INET6_ADDRSTRLEN];
	char peer_host[INET6_ADDRSTRLEN], peer_serv[INET6_ADDRSTRLEN];

	if (getpeername(fd, (struct sockaddr *)&peer, &peerlen) < 0) {
		perror("getpeername");
		return;
	}

	if (peerlen != salen) {
		fprintf(stderr, "%s: %d vs %d\n", __func__, peerlen, salen);
		return;
	}

	if (memcmp(ss, &peer, peerlen) == 0)
		return;

	xgetnameinfo((struct sockaddr *)ss, salen,
		     acc_host, sizeof(acc_host), acc_serv, sizeof(acc_serv));
	xgetnameinfo((struct sockaddr *)&peer, peerlen,
		     peer_host, sizeof(peer_host), peer_serv, sizeof(peer_serv));

	fprintf(stderr, "%s: memcmp failure: accept %s vs peername %s, %s vs %s salen %d vs %d\n",
		__func__, acc_host, peer_host, acc_serv, peer_serv, peerlen, salen);
}
783 
784 static void check_getpeername_connect(int fd)
785 {
786 	struct sockaddr_storage ss;
787 	socklen_t salen = sizeof(ss);
788 	char a[INET6_ADDRSTRLEN];
789 	char b[INET6_ADDRSTRLEN];
790 
791 	if (getpeername(fd, (struct sockaddr *)&ss, &salen) < 0) {
792 		perror("getpeername");
793 		return;
794 	}
795 
796 	xgetnameinfo((struct sockaddr *)&ss, salen,
797 		     a, sizeof(a), b, sizeof(b));
798 
799 	if (strcmp(cfg_host, a) || strcmp(cfg_port, b))
800 		fprintf(stderr, "%s: %s vs %s, %s vs %s\n", __func__,
801 			cfg_host, a, cfg_port, b);
802 }
803 
804 static void maybe_close(int fd)
805 {
806 	unsigned int r = rand();
807 
808 	if (!(cfg_join || cfg_remove) && (r & 1))
809 		close(fd);
810 }
811 
812 int main_loop_s(int listensock)
813 {
814 	struct sockaddr_storage ss;
815 	struct pollfd polls;
816 	socklen_t salen;
817 	int remotesock;
818 
819 	polls.fd = listensock;
820 	polls.events = POLLIN;
821 
822 	switch (poll(&polls, 1, poll_timeout)) {
823 	case -1:
824 		perror("poll");
825 		return 1;
826 	case 0:
827 		fprintf(stderr, "%s: timed out\n", __func__);
828 		close(listensock);
829 		return 2;
830 	}
831 
832 	salen = sizeof(ss);
833 	remotesock = accept(listensock, (struct sockaddr *)&ss, &salen);
834 	if (remotesock >= 0) {
835 		maybe_close(listensock);
836 		check_sockaddr(pf, &ss, salen);
837 		check_getpeername(remotesock, &ss, salen);
838 
839 		return copyfd_io(0, remotesock, 1);
840 	}
841 
842 	perror("accept");
843 
844 	return 1;
845 }
846 
/*
 * Seed rand() from /dev/urandom.  If the device cannot be opened or fully
 * read, fall back to the process id, so the seed is always initialized.
 */
static void init_rng(void)
{
	unsigned int seed = (unsigned int)getpid();
	int fd = open("/dev/urandom", O_RDONLY);

	if (fd >= 0) {
		unsigned int rnd;

		if (read(fd, &rnd, sizeof(rnd)) == (ssize_t)sizeof(rnd))
			seed = rnd;
		close(fd);
	}

	srand(seed);
}
862 
/* setsockopt() that prints the error and exits with status 1 on failure. */
static void xsetsockopt(int fd, int level, int optname, const void *optval, socklen_t optlen)
{
	if (setsockopt(fd, level, optname, optval, optlen) == 0)
		return;

	perror("setsockopt");
	exit(1);
}
873 
874 static void apply_cmsg_types(int fd, const struct cfg_cmsg_types *cmsg)
875 {
876 	static const unsigned int on = 1;
877 
878 	if (cmsg->timestampns)
879 		xsetsockopt(fd, SOL_SOCKET, SO_TIMESTAMPNS_NEW, &on, sizeof(on));
880 }
881 
882 static void parse_cmsg_types(const char *type)
883 {
884 	char *next = strchr(type, ',');
885 	unsigned int len = 0;
886 
887 	cfg_cmsg_types.cmsg_enabled = 1;
888 
889 	if (next) {
890 		parse_cmsg_types(next + 1);
891 		len = next - type;
892 	} else {
893 		len = strlen(type);
894 	}
895 
896 	if (strncmp(type, "TIMESTAMPNS", len) == 0) {
897 		cfg_cmsg_types.timestampns = 1;
898 		return;
899 	}
900 
901 	fprintf(stderr, "Unrecognized cmsg option %s\n", type);
902 	exit(1);
903 }
904 
905 int main_loop(void)
906 {
907 	int fd;
908 
909 	/* listener is ready. */
910 	fd = sock_connect_mptcp(cfg_host, cfg_port, cfg_sock_proto);
911 	if (fd < 0)
912 		return 2;
913 
914 	check_getpeername_connect(fd);
915 
916 	if (cfg_rcvbuf)
917 		set_rcvbuf(fd, cfg_rcvbuf);
918 	if (cfg_sndbuf)
919 		set_sndbuf(fd, cfg_sndbuf);
920 	if (cfg_cmsg_types.cmsg_enabled)
921 		apply_cmsg_types(fd, &cfg_cmsg_types);
922 
923 	return copyfd_io(0, fd, 1);
924 }
925 
/*
 * Map the -s argument ("MPTCP" or "TCP", case-insensitive) to an
 * IPPROTO_* value.  Anything else prints an error plus the usage text and
 * exits.
 */
int parse_proto(const char *proto)
{
	if (!strcasecmp(proto, "MPTCP"))
		return IPPROTO_MPTCP;
	if (!strcasecmp(proto, "TCP"))
		return IPPROTO_TCP;

	fprintf(stderr, "Unknown protocol: %s.\n", proto);
	die_usage();

	/* not reached: die_usage() exits; silences the missing-return warning */
	return 0;
}
939 
940 int parse_mode(const char *mode)
941 {
942 	if (!strcasecmp(mode, "poll"))
943 		return CFG_MODE_POLL;
944 	if (!strcasecmp(mode, "mmap"))
945 		return CFG_MODE_MMAP;
946 	if (!strcasecmp(mode, "sendfile"))
947 		return CFG_MODE_SENDFILE;
948 
949 	fprintf(stderr, "Unknown test mode: %s\n", mode);
950 	fprintf(stderr, "Supported modes are:\n");
951 	fprintf(stderr, "\t\t\"poll\" - interleaved read/write using poll()\n");
952 	fprintf(stderr, "\t\t\"mmap\" - send entire input file (mmap+write), then read response (-l will read input first)\n");
953 	fprintf(stderr, "\t\t\"sendfile\" - send entire input file (sendfile), then read response (-l will read input first)\n");
954 
955 	die_usage();
956 
957 	/* silence compiler warning */
958 	return 0;
959 }
960 
961 int parse_peek(const char *mode)
962 {
963 	if (!strcasecmp(mode, "saveWithPeek"))
964 		return CFG_WITH_PEEK;
965 	if (!strcasecmp(mode, "saveAfterPeek"))
966 		return CFG_AFTER_PEEK;
967 
968 	fprintf(stderr, "Unknown: %s\n", mode);
969 	fprintf(stderr, "Supported MSG_PEEK mode are:\n");
970 	fprintf(stderr,
971 		"\t\t\"saveWithPeek\" - recv data with flags 'MSG_PEEK' and save the peek data into file\n");
972 	fprintf(stderr,
973 		"\t\t\"saveAfterPeek\" - read and save data into file after recv with flags 'MSG_PEEK'\n");
974 
975 	die_usage();
976 
977 	/* silence compiler warning */
978 	return 0;
979 }
980 
/*
 * Parse a non-negative integer option argument (the -S / -R buffer
 * sizes).  Any base accepted by strtoul() with base 0 is allowed
 * (decimal, 0x hex, leading-0 octal).  Empty input, trailing garbage or
 * values above INT_MAX print an error and exit via die_usage().
 */
static int parse_int(const char *size)
{
	unsigned long s;
	char *end;

	errno = 0;

	s = strtoul(size, &end, 0);

	if (errno) {
		fprintf(stderr, "Invalid buffer size %s (%s)\n",
			size, strerror(errno));
		die_usage();
	}

	/* strtoul() silently yields 0 / stops early on non-numeric input */
	if (end == size || *end != '\0') {
		fprintf(stderr, "Invalid buffer size %s (%s)\n",
			size, strerror(EINVAL));
		die_usage();
	}

	if (s > INT_MAX) {
		fprintf(stderr, "Invalid buffer size %s (%s)\n",
			size, strerror(ERANGE));
		die_usage();
	}

	return (int)s;
}
1003 
1004 static void parse_opts(int argc, char **argv)
1005 {
1006 	int c;
1007 
1008 	while ((c = getopt(argc, argv, "6jr:lp:s:hut:m:S:R:w:M:P:c:")) != -1) {
1009 		switch (c) {
1010 		case 'j':
1011 			cfg_join = true;
1012 			cfg_mode = CFG_MODE_POLL;
1013 			cfg_wait = 400000;
1014 			break;
1015 		case 'r':
1016 			cfg_remove = true;
1017 			cfg_mode = CFG_MODE_POLL;
1018 			cfg_wait = 400000;
1019 			cfg_do_w = atoi(optarg);
1020 			if (cfg_do_w <= 0)
1021 				cfg_do_w = 50;
1022 			break;
1023 		case 'l':
1024 			listen_mode = true;
1025 			break;
1026 		case 'p':
1027 			cfg_port = optarg;
1028 			break;
1029 		case 's':
1030 			cfg_sock_proto = parse_proto(optarg);
1031 			break;
1032 		case 'h':
1033 			die_usage();
1034 			break;
1035 		case 'u':
1036 			tcpulp_audit = true;
1037 			break;
1038 		case '6':
1039 			pf = AF_INET6;
1040 			break;
1041 		case 't':
1042 			poll_timeout = atoi(optarg) * 1000;
1043 			if (poll_timeout <= 0)
1044 				poll_timeout = -1;
1045 			break;
1046 		case 'm':
1047 			cfg_mode = parse_mode(optarg);
1048 			break;
1049 		case 'S':
1050 			cfg_sndbuf = parse_int(optarg);
1051 			break;
1052 		case 'R':
1053 			cfg_rcvbuf = parse_int(optarg);
1054 			break;
1055 		case 'w':
1056 			cfg_wait = atoi(optarg)*1000000;
1057 			break;
1058 		case 'M':
1059 			cfg_mark = strtol(optarg, NULL, 0);
1060 			break;
1061 		case 'P':
1062 			cfg_peek = parse_peek(optarg);
1063 			break;
1064 		case 'c':
1065 			parse_cmsg_types(optarg);
1066 			break;
1067 		}
1068 	}
1069 
1070 	if (optind + 1 != argc)
1071 		die_usage();
1072 	cfg_host = argv[optind];
1073 
1074 	if (strchr(cfg_host, ':'))
1075 		pf = AF_INET6;
1076 }
1077 
1078 int main(int argc, char *argv[])
1079 {
1080 	init_rng();
1081 
1082 	signal(SIGUSR1, handle_signal);
1083 	parse_opts(argc, argv);
1084 
1085 	if (tcpulp_audit)
1086 		return sock_test_tcpulp(cfg_host, cfg_port) ? 0 : 1;
1087 
1088 	if (listen_mode) {
1089 		int fd = sock_listen_mptcp(cfg_host, cfg_port);
1090 
1091 		if (fd < 0)
1092 			return 1;
1093 
1094 		if (cfg_rcvbuf)
1095 			set_rcvbuf(fd, cfg_rcvbuf);
1096 		if (cfg_sndbuf)
1097 			set_sndbuf(fd, cfg_sndbuf);
1098 		if (cfg_mark)
1099 			set_mark(fd, cfg_mark);
1100 		if (cfg_cmsg_types.cmsg_enabled)
1101 			apply_cmsg_types(fd, &cfg_cmsg_types);
1102 
1103 		return main_loop_s(fd);
1104 	}
1105 
1106 	return main_loop();
1107 }
1108