1 // SPDX-License-Identifier: GPL-2.0
2 
3 #define _GNU_SOURCE
4 
5 #include <assert.h>
6 #include <errno.h>
7 #include <fcntl.h>
8 #include <limits.h>
9 #include <string.h>
10 #include <stdarg.h>
11 #include <stdbool.h>
12 #include <stdint.h>
13 #include <inttypes.h>
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <strings.h>
17 #include <unistd.h>
18 #include <time.h>
19 
20 #include <sys/ioctl.h>
21 #include <sys/socket.h>
22 #include <sys/types.h>
23 #include <sys/wait.h>
24 
25 #include <netdb.h>
26 #include <netinet/in.h>
27 
28 #include <linux/tcp.h>
29 #include <linux/sockios.h>
30 
/* Fallback definitions for system headers that do not provide them. */
#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif
#ifndef SOL_MPTCP
#define SOL_MPTCP 284
#endif

/* Address family used by both endpoints; switched to AF_INET6 by -6. */
static int pf = AF_INET;
/* Socket protocol of the connecting (client) side; set by -t tcp|mptcp. */
static int proto_tx = IPPROTO_MPTCP;
/* Socket protocol of the listening (server) side; set by -r tcp|mptcp. */
static int proto_rx = IPPROTO_MPTCP;
41 
/* Print "msg: <description of errno>" on stderr and exit with status 1. */
static void die_perror(const char *msg)
{
	const int status = 1;

	perror(msg);
	exit(status);
}
47 
/* Print the command-line synopsis on stderr and exit with status r. */
static void die_usage(int r)
{
	static const char usage[] =
		"Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n";

	fputs(usage, stderr);
	exit(r);
}
53 
/*
 * Fatal error with a printf-style message: the formatted text is written to
 * stderr followed by a newline, then the process exits with status 1.
 */
static void xerror(const char *fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	vfprintf(stderr, fmt, args);
	va_end(args);

	putc('\n', stderr);
	exit(1);
}
64 
/*
 * Translate a getaddrinfo() error code into text. EAI_SYSTEM means the
 * real cause is in errno, so that case is described via strerror().
 */
static const char *getxinfo_strerr(int err)
{
	return err == EAI_SYSTEM ? strerror(errno) : gai_strerror(err);
}
72 
/*
 * getaddrinfo() that never fails: on error it prints
 * "Fatal: getaddrinfo(node:service): reason" to stderr and exits with 1.
 * On success *res holds the list, which the caller must freeaddrinfo().
 */
static void xgetaddrinfo(const char *node, const char *service,
			 const struct addrinfo *hints,
			 struct addrinfo **res)
{
	const char *reason;
	int rc = getaddrinfo(node, service, hints, res);

	if (rc == 0)
		return;

	reason = getxinfo_strerr(rc);
	fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
		node != NULL ? node : "",
		service != NULL ? service : "",
		reason);
	exit(1);
}
87 
/*
 * Create a listening socket bound to listenaddr:port.
 *
 * listenaddr must be a numeric host (AI_NUMERICHOST); the address family is
 * taken from the global pf and the socket protocol from proto_rx (the hints
 * only ask getaddrinfo() for TCP stream addresses).  Every returned address
 * is tried in turn and the first one that binds is used.
 *
 * Returns the listening fd; exits the process if no address could be bound
 * or listen() fails.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
	};
	struct addrinfo *a, *addr;
	int sock = -1;	/* stays -1 unless some address binds */
	int one = 1;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto_rx);
		if (sock < 0)
			continue;

		/* best effort; presumably lets re-runs rebind the fixed port */
		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
				     sizeof(one)))
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create listen socket");

	if (listen(sock, 20))
		die_perror("listen");

	return sock;
}
133 
/*
 * Connect to remoteaddr:port using socket protocol proto (e.g. IPPROTO_TCP
 * or IPPROTO_MPTCP) and address family pf.
 *
 * Every address returned by getaddrinfo() is tried in order; a failed
 * connect() is reported and the next candidate is attempted.  Returns the
 * connected fd, or exits the process if no address could be connected.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *a, *addr;
	int sock = -1;

	xgetaddrinfo(remoteaddr, port, &hints, &addr);
	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto);
		if (sock < 0)
			continue;

		if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		/* report, release the socket and try the next address */
		perror("connect");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create connect socket");

	return sock;
}
164 
/*
 * Map a protocol name ("tcp" or "mptcp", case-insensitive) to its
 * IPPROTO_* number.  Any other string prints the usage text and exits(1).
 */
static int protostr_to_num(const char *s)
{
	static const struct {
		const char *name;
		int proto;
	} protos[] = {
		{ "tcp", IPPROTO_TCP },
		{ "mptcp", IPPROTO_MPTCP },
	};
	size_t idx;

	for (idx = 0; idx < sizeof(protos) / sizeof(protos[0]); idx++) {
		if (!strcasecmp(s, protos[idx].name))
			return protos[idx].proto;
	}

	die_usage(1);
	return 0;
}
175 
176 static void parse_opts(int argc, char **argv)
177 {
178 	int c;
179 
180 	while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
181 		switch (c) {
182 		case 'h':
183 			die_usage(0);
184 			break;
185 		case '6':
186 			pf = AF_INET6;
187 			break;
188 		case 't':
189 			proto_tx = protostr_to_num(optarg);
190 			break;
191 		case 'r':
192 			proto_rx = protostr_to_num(optarg);
193 			break;
194 		default:
195 			die_usage(1);
196 			break;
197 		}
198 	}
199 }
200 
/*
 * Wait until the send queue of fd drains, polling roughly once per
 * millisecond for at most timeout iterations (~timeout ms).
 *
 * total is the upper bound of bytes the caller expects to still be queued;
 * a TIOCOUTQ value above it is fatal.  The SIOCOUTQNSD value must never
 * exceed the TIOCOUTQ value.  Exits the process if data is still queued
 * when the iterations run out.
 */
static void wait_for_ack(int fd, int timeout, size_t total)
{
	int i;

	for (i = 0; i < timeout; i++) {
		int nsd = -1, ret, queued = -1;
		struct timespec req;

		ret = ioctl(fd, TIOCOUTQ, &queued);
		if (ret < 0)
			die_perror("TIOCOUTQ");

		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
		if (ret < 0)
			die_perror("SIOCOUTQNSD");

		/* queued and timeout are int: print them with %d, not %u */
		if ((size_t)queued > total)
			xerror("TIOCOUTQ %d, but only %zu expected\n", queued, total);
		assert(nsd <= queued);

		if (queued == 0)
			return;

		/* wait for peer to ack rx of all data */
		req.tv_sec = 0;
		req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
		nanosleep(&req, NULL);
	}

	xerror("still tx data queued after %d ms\n", timeout);
}
233 
/*
 * Client side of the test.  fd is the connected data socket and unixfd the
 * control channel to the server process.  Control tags are 4 bytes:
 *  - on "xmit": send len (a size_t, 128..4094) on unixfd, then len random
 *    upper-case letters on fd;
 *  - on "huge": send total (a size_t, between 1 MiB and 17 MiB) on unixfd,
 *    wait until the first message is acked, then stream total bytes on fd;
 *  - on "shut": wait until everything is acked, write one more byte,
 *    close fd and answer "closed".
 * Closes both fd and unixfd before returning.
 */
static void connect_one_server(int fd, int unixfd)
{
	size_t len, i, total, sent;
	/* zeroed: the bulk transfer sends the whole buffer, not just len bytes */
	char buf[4096] = { 0 }, buf2[4096];
	ssize_t ret;

	len = rand() % (sizeof(buf) - 1);

	if (len < 128)
		len = 128;

	for (i = 0; i < len ; i++) {
		buf[i] = rand() % 26;
		buf[i] += 'A';
	}

	buf[i] = '\n';

	/* un-block server */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);

	assert(strncmp(buf2, "xmit", 4) == 0);

	ret = write(unixfd, &len, sizeof(len));
	assert(ret == (ssize_t)sizeof(len));

	ret = write(fd, buf, len);
	if (ret < 0)
		die_perror("write");

	if (ret != (ssize_t)len)
		xerror("short write");

	/* check the read length before comparing buf2's contents */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "huge", 4) == 0);

	total = rand() % (16 * 1024 * 1024);
	total += (1 * 1024 * 1024);
	sent = total;

	ret = write(unixfd, &total, sizeof(total));
	assert(ret == (ssize_t)sizeof(total));

	wait_for_ack(fd, 5000, len);

	while (total > 0) {
		if (total > sizeof(buf))
			len = sizeof(buf);
		else
			len = total;

		ret = write(fd, buf, len);
		if (ret < 0)
			die_perror("write");
		total -= ret;

		/* we don't have to care about buf content, only
		 * number of total bytes sent
		 */
	}

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "shut", 4) == 0);

	wait_for_ack(fd, 5000, sent);

	ret = write(fd, buf, 1);
	assert(ret == 1);
	close(fd);
	ret = write(unixfd, "closed", 6);
	assert(ret == 6);

	close(unixfd);
}
310 
/*
 * Scan the control messages of msgh for the IPPROTO_TCP / TCP_CM_INQ
 * entry and copy its payload into *inqv.  Exits the process via xerror()
 * if no such control message is present.
 */
static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
{
	struct cmsghdr *hdr;

	for (hdr = CMSG_FIRSTHDR(msgh); hdr != NULL;
	     hdr = CMSG_NXTHDR(msgh, hdr)) {
		if (hdr->cmsg_level != IPPROTO_TCP ||
		    hdr->cmsg_type != TCP_CM_INQ)
			continue;

		memcpy(inqv, CMSG_DATA(hdr), sizeof(*inqv));
		return;
	}

	xerror("could not find TCP_CM_INQ cmsg type");
}
324 
/*
 * Server side of the test, run on the accepted socket fd and the control
 * channel unixfd shared with the client:
 *  1. request a small message ("xmit"), wait until FIONREAD reports all of
 *     it, then check TCP_CM_INQ after a 1-byte read and after draining;
 *  2. request a large transfer ("huge") and check on every read that the
 *     reported inq never exceeds the bytes still outstanding;
 *  3. tell the peer to finish ("shut"), wait for "closed", then check that
 *     both the trailing byte and the EOF read report inq == 1.
 * Any mismatch aborts via assert()/xerror(); fd is closed on return.
 */
static void process_one_client(int fd, int unixfd)
{
	unsigned int tcp_inq;
	size_t expect_len;
	char msg_buf[4096];
	char buf[4096];
	char tmp[16];
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = msg_buf,
		.msg_controllen = sizeof(msg_buf),
	};
	ssize_t ret, tot;

	ret = write(unixfd, "xmit", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	if (expect_len > sizeof(buf))
		xerror("expect len %zu exceeds buffer size", expect_len);

	/* poll (1ms steps) until the whole small message is queued */
	for (;;) {
		struct timespec req;
		unsigned int queued;

		ret = ioctl(fd, FIONREAD, &queued);
		if (ret < 0)
			die_perror("FIONREAD");
		if (queued > expect_len)
			xerror("FIONREAD returned %u, but only %zu expected\n",
			       queued, expect_len);
		if (queued == expect_len)
			break;

		req.tv_sec = 0;
		req.tv_nsec = 1000 * 1000ul;
		nanosleep(&req, NULL);
	}

	/* read one byte, expect cmsg to return expected - 1 */
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	if (msg.msg_controllen == 0)
		xerror("msg_controllen is 0");

	get_tcp_inq(&msg, &tcp_inq);

	assert((size_t)tcp_inq == (expect_len - 1));

	/*
	 * recvmsg() overwrites msg_controllen with the amount of control data
	 * it returned, so restore the full buffer size before every reuse.
	 */
	iov.iov_len = sizeof(buf);
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* should have gotten exact remainder of all pending data */
	assert(ret == (ssize_t)tcp_inq);

	/* should be 0, all drained */
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 0);

	/* request a large swath of data. */
	ret = write(unixfd, "huge", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	/* peer should send us a few mb of data */
	if (expect_len <= sizeof(buf))
		xerror("expect len %zu too small\n", expect_len);

	tot = 0;
	do {
		iov.iov_len = sizeof(buf);
		msg.msg_controllen = sizeof(msg_buf);
		ret = recvmsg(fd, &msg, 0);
		if (ret < 0)
			die_perror("recvmsg");
		/* otherwise a premature EOF would make this loop spin forever */
		if (ret == 0)
			xerror("unexpected EOF after %zd of %zu bytes",
			       tot, expect_len);

		tot += ret;
		/* keeps the unsigned "expect_len - tot" below from wrapping */
		if ((size_t)tot > expect_len)
			xerror("received %zd bytes, but only %zu expected",
			       tot, expect_len);

		get_tcp_inq(&msg, &tcp_inq);

		if (tcp_inq > expect_len - tot)
			xerror("inq %u, remaining %zu total_len %zu\n",
			       tcp_inq, expect_len - tot, expect_len);

		assert(tcp_inq <= expect_len - tot);
	} while ((size_t)tot < expect_len);

	ret = write(unixfd, "shut", 4);
	assert(ret == 4);

	/* wait for hangup. Should have received one more byte of data. */
	ret = read(unixfd, tmp, sizeof(tmp));
	assert(ret == 6);
	assert(strncmp(tmp, "closed", 6) == 0);

	sleep(1);

	iov.iov_len = 1;
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");
	assert(ret == 1);

	get_tcp_inq(&msg, &tcp_inq);

	/* tcp_inq should be 1 due to received fin. */
	assert(tcp_inq == 1);

	iov.iov_len = 1;
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* expect EOF */
	assert(ret == 0);
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 1);

	close(fd);
}
457 
/* accept() one connection on listening socket s, ignoring the peer address;
 * exits the process on failure. */
static int xaccept(int s)
{
	int conn = accept(s, NULL, NULL);

	if (conn >= 0)
		return conn;

	die_perror("accept");
	return conn;
}
467 
468 static int server(int unixfd)
469 {
470 	int fd = -1, r, on = 1;
471 
472 	switch (pf) {
473 	case AF_INET:
474 		fd = sock_listen_mptcp("127.0.0.1", "15432");
475 		break;
476 	case AF_INET6:
477 		fd = sock_listen_mptcp("::1", "15432");
478 		break;
479 	default:
480 		xerror("Unknown pf %d\n", pf);
481 		break;
482 	}
483 
484 	r = write(unixfd, "conn", 4);
485 	assert(r == 4);
486 
487 	alarm(15);
488 	r = xaccept(fd);
489 
490 	if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
491 		die_perror("setsockopt");
492 
493 	process_one_client(r, unixfd);
494 
495 	return 0;
496 }
497 
498 static int client(int unixfd)
499 {
500 	int fd = -1;
501 
502 	alarm(15);
503 
504 	switch (pf) {
505 	case AF_INET:
506 		fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
507 		break;
508 	case AF_INET6:
509 		fd = sock_connect_mptcp("::1", "15432", proto_tx);
510 		break;
511 	default:
512 		xerror("Unknown pf %d\n", pf);
513 	}
514 
515 	connect_one_server(fd, unixfd);
516 
517 	return 0;
518 }
519 
/*
 * Seed rand() from /dev/urandom.  If the device cannot be opened or a
 * full seed cannot be read, fall back to a time/pid mix so the seed is
 * never an uninitialized value (the fallback also differs between the
 * forked children because their pids differ).
 */
static void init_rng(void)
{
	unsigned int seed = (unsigned int)time(NULL) ^ (unsigned int)getpid();
	int fd = open("/dev/urandom", O_RDONLY);

	/* 0 is a valid descriptor, so test >= 0 rather than > 0 */
	if (fd >= 0) {
		unsigned int rnd;

		if (read(fd, &rnd, sizeof(rnd)) == (ssize_t)sizeof(rnd))
			seed = rnd;
		close(fd);
	}

	srand(seed);
}
535 
/*
 * fork() that exits the process on failure.  In the child, rand() is
 * re-seeded via init_rng() before returning 0; the parent gets the pid.
 */
static pid_t xfork(void)
{
	pid_t pid = fork();

	if (pid == 0) {
		init_rng();
		return pid;
	}

	if (pid < 0)
		die_perror("fork");

	return pid;
}
547 
/*
 * Decode a waitpid() status for the child named by what.
 * Returns 0 on a clean exit, or the non-zero exit code (after printing it).
 * Death or stop by a signal is fatal via xerror(); any other status
 * yields 111.
 */
static int rcheck(int wstatus, const char *what)
{
	int code;

	if (!WIFEXITED(wstatus)) {
		if (WIFSIGNALED(wstatus))
			xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
		else if (WIFSTOPPED(wstatus))
			xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));
		return 111;
	}

	code = WEXITSTATUS(wstatus);
	if (code != 0)
		fprintf(stderr, "%s exited, status=%d\n", what, code);

	return code;
}
563 
/*
 * Fork a server and a client child that talk over an AF_UNIX datagram
 * socketpair used as control channel.  The client is only started after the
 * server reports "conn" (its listener exists).  Returns the server's exit
 * code if non-zero, otherwise the client's.
 */
int main(int argc, char *argv[])
{
	int e1, e2, wstatus;
	pid_t s, c, ret;
	int unixfds[2];
	char hello[4];
	ssize_t n;

	parse_opts(argc, argv);

	e1 = socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds);
	if (e1 < 0)
		die_perror("socketpair");

	s = xfork();
	if (s == 0)
		return server(unixfds[1]);

	close(unixfds[1]);

	/* wait until server bound a socket: it sends "conn" when listening */
	n = read(unixfds[0], hello, sizeof(hello));
	assert(n == (ssize_t)sizeof(hello));
	assert(memcmp(hello, "conn", sizeof(hello)) == 0);

	c = xfork();
	if (c == 0)
		return client(unixfds[0]);

	close(unixfds[0]);

	ret = waitpid(s, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e1 = rcheck(wstatus, "server");
	ret = waitpid(c, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e2 = rcheck(wstatus, "client");

	return e1 ? e1 : e2;
}
603