1 // SPDX-License-Identifier: GPL-2.0
2 
3 #define _GNU_SOURCE
4 
5 #include <assert.h>
6 #include <errno.h>
7 #include <limits.h>
8 #include <string.h>
9 #include <stdarg.h>
10 #include <stdbool.h>
11 #include <stdint.h>
12 #include <inttypes.h>
13 #include <stdio.h>
14 #include <stdlib.h>
15 #include <strings.h>
16 #include <unistd.h>
17 
18 #include <sys/socket.h>
19 #include <sys/types.h>
20 #include <sys/wait.h>
21 
22 #include <netdb.h>
23 #include <netinet/in.h>
24 
25 #include <linux/tcp.h>
26 
/*
 * Address family used when resolving, listening and connecting.
 * Defaults to IPv4; parse_opts() switches it to AF_INET6 for "-6".
 */
static int pf = AF_INET;
28 
/* Fallback values, used only when the included headers do not define them. */
#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif
#ifndef SOL_MPTCP
#define SOL_MPTCP 284
#endif
35 
/*
 * Local fallback definitions of the MPTCP socket-option structures and
 * option numbers, compiled only when MPTCP_INFO is not already defined.
 * NOTE(review): presumably these mirror the kernel's <linux/mptcp.h> UAPI
 * layout and must not be reordered - confirm against the target kernel.
 */
#ifndef MPTCP_INFO
/* Payload of getsockopt(SOL_MPTCP, MPTCP_INFO). */
struct mptcp_info {
	__u8	mptcpi_subflows;
	__u8	mptcpi_add_addr_signal;
	__u8	mptcpi_add_addr_accepted;
	__u8	mptcpi_subflows_max;
	__u8	mptcpi_add_addr_signal_max;
	__u8	mptcpi_add_addr_accepted_max;
	__u32	mptcpi_flags;
	__u32	mptcpi_token;
	__u64	mptcpi_write_seq;
	__u64	mptcpi_snd_una;
	__u64	mptcpi_rcv_nxt;
	__u8	mptcpi_local_addr_used;
	__u8	mptcpi_local_addr_max;
	__u8	mptcpi_csum_enabled;
};

/* Header placed in front of the per-subflow array for the per-subflow options. */
struct mptcp_subflow_data {
	__u32		size_subflow_data;		/* size of this structure in userspace */
	__u32		num_subflows;			/* must be 0, set by kernel */
	__u32		size_kernel;			/* must be 0, set by kernel */
	__u32		size_user;			/* size of one element in data[] */
} __attribute__((aligned(8)));

/* One element of the MPTCP_SUBFLOW_ADDRS array: local then remote address. */
struct mptcp_subflow_addrs {
	union {
		__kernel_sa_family_t sa_family;
		struct sockaddr sa_local;
		struct sockaddr_in sin_local;
		struct sockaddr_in6 sin6_local;
		struct __kernel_sockaddr_storage ss_local;
	};
	union {
		struct sockaddr sa_remote;
		struct sockaddr_in sin_remote;
		struct sockaddr_in6 sin6_remote;
		struct __kernel_sockaddr_storage ss_remote;
	};
};

/* Option names used with getsockopt(fd, SOL_MPTCP, ...). */
#define MPTCP_INFO		1
#define MPTCP_TCPINFO		2
#define MPTCP_SUBFLOW_ADDRS	3
#endif
81 
/* Per-connection bookkeeping carried between successive do_getsockopts() calls. */
struct so_state {
	/* MPTCP_INFO baseline; stored while the saved mptcpi_write_seq is still 0 */
	struct mptcp_info mi;
	/* latest mptcpi_rcv_nxt minus the baseline's mptcpi_rcv_nxt */
	uint64_t mptcpi_rcv_delta;
	/* tcpi_bytes_received seen when the caller expected 0 received bytes */
	uint64_t tcpi_rcv_delta;
};
87 
/* Report msg together with the current errno text, then exit with status 1. */
static void die_perror(const char *msg)
{
	const int status = 1;

	perror(msg);
	exit(status);
}
93 
/* Print the usage line to stderr and exit with status r. */
static void die_usage(int r)
{
	static const char usage[] = "Usage: mptcp_sockopt [-6]\n";

	fputs(usage, stderr);
	exit(r);
}
99 
/*
 * Print a printf-style message to stderr, append a newline and exit
 * with status 1. Never returns.
 */
static void xerror(const char *fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	vfprintf(stderr, fmt, args);
	va_end(args);

	fputs("\n", stderr);
	exit(1);
}
110 
/*
 * Map a getaddrinfo() error code to a message: EAI_SYSTEM is described
 * via the current errno, every other code via gai_strerror().
 */
static const char *getxinfo_strerr(int err)
{
	return err == EAI_SYSTEM ? strerror(errno) : gai_strerror(err);
}
118 
/*
 * getaddrinfo() wrapper that never returns an error: on failure it prints
 * a "Fatal:" diagnostic to stderr and exits with status 1.
 */
static void xgetaddrinfo(const char *node, const char *service,
			 const struct addrinfo *hints,
			 struct addrinfo **res)
{
	int rc = getaddrinfo(node, service, hints, res);

	if (rc == 0)
		return;

	fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
		node ? node : "", service ? service : "",
		getxinfo_strerr(rc));
	exit(1);
}
133 
/*
 * Create an MPTCP socket bound to listenaddr:port and put it in the
 * listening state.
 *
 * listenaddr must be a numeric address (AI_NUMERICHOST); the address
 * family is taken from the global pf. Every resolved address is tried in
 * turn until one can be bound.
 *
 * Returns the listening descriptor. Does not return on failure: the
 * process exits with status 1.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
	};
	struct addrinfo *a, *addr;
	/* start invalid so an empty address list is reported as a failure */
	int sock = -1;
	int one = 1;

	hints.ai_family = pf;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, IPPROTO_MPTCP);
		if (sock < 0)
			continue;

		/* best effort: a failure here is reported but not fatal */
		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
				     sizeof(one)))
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create listen socket");

	if (listen(sock, 20))
		die_perror("listen");

	return sock;
}
179 
/*
 * Resolve remoteaddr:port for the address family in pf and connect a
 * socket of protocol proto to it.
 *
 * Addresses for which socket() fails are skipped; a connect() failure is
 * fatal. Returns the connected descriptor, or exits with status 1.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo *results, *cur;
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	int fd = -1;

	hints.ai_family = pf;
	xgetaddrinfo(remoteaddr, port, &hints, &results);

	cur = results;
	while (cur) {
		fd = socket(cur->ai_family, cur->ai_socktype, proto);
		if (fd >= 0) {
			if (connect(fd, cur->ai_addr, cur->ai_addrlen) != 0)
				die_perror("connect");
			break; /* connected */
		}
		cur = cur->ai_next;
	}

	if (fd < 0)
		xerror("could not create connect socket");

	freeaddrinfo(results);
	return fd;
}
210 
211 static void parse_opts(int argc, char **argv)
212 {
213 	int c;
214 
215 	while ((c = getopt(argc, argv, "h6")) != -1) {
216 		switch (c) {
217 		case 'h':
218 			die_usage(0);
219 			break;
220 		case '6':
221 			pf = AF_INET6;
222 			break;
223 		default:
224 			die_usage(1);
225 			break;
226 		}
227 	}
228 }
229 
230 static void do_getsockopt_bogus_sf_data(int fd, int optname)
231 {
232 	struct mptcp_subflow_data good_data;
233 	struct bogus_data {
234 		struct mptcp_subflow_data d;
235 		char buf[2];
236 	} bd;
237 	socklen_t olen, _olen;
238 	int ret;
239 
240 	memset(&bd, 0, sizeof(bd));
241 	memset(&good_data, 0, sizeof(good_data));
242 
243 	olen = sizeof(good_data);
244 	good_data.size_subflow_data = olen;
245 
246 	ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &olen);
247 	assert(ret < 0); /* 0 size_subflow_data */
248 	assert(olen == sizeof(good_data));
249 
250 	bd.d = good_data;
251 
252 	ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &olen);
253 	assert(ret == 0);
254 	assert(olen == sizeof(good_data));
255 	assert(bd.d.num_subflows == 1);
256 	assert(bd.d.size_kernel > 0);
257 	assert(bd.d.size_user == 0);
258 
259 	bd.d = good_data;
260 	_olen = rand() % olen;
261 	olen = _olen;
262 	ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &olen);
263 	assert(ret < 0);	/* bogus olen */
264 	assert(olen == _olen);	/* must be unchanged */
265 
266 	bd.d = good_data;
267 	olen = sizeof(good_data);
268 	bd.d.size_kernel = 1;
269 	ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &olen);
270 	assert(ret < 0); /* size_kernel not 0 */
271 
272 	bd.d = good_data;
273 	olen = sizeof(good_data);
274 	bd.d.num_subflows = 1;
275 	ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &olen);
276 	assert(ret < 0); /* num_subflows not 0 */
277 
278 	/* forward compat check: larger struct mptcp_subflow_data on 'old' kernel */
279 	bd.d = good_data;
280 	olen = sizeof(bd);
281 	bd.d.size_subflow_data = sizeof(bd);
282 
283 	ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &olen);
284 	assert(ret == 0);
285 
286 	/* olen must be truncated to real data size filled by kernel: */
287 	assert(olen == sizeof(good_data));
288 
289 	assert(bd.d.size_subflow_data == sizeof(bd));
290 
291 	bd.d = good_data;
292 	bd.d.size_subflow_data += 1;
293 	bd.d.size_user = 1;
294 	olen = bd.d.size_subflow_data + 1;
295 	_olen = olen;
296 
297 	ret = getsockopt(fd, SOL_MPTCP, optname, &bd, &_olen);
298 	assert(ret == 0);
299 
300 	/* no truncation, kernel should have filled 1 byte of optname payload in buf[1]: */
301 	assert(olen == _olen);
302 
303 	assert(bd.d.size_subflow_data == sizeof(good_data) + 1);
304 	assert(bd.buf[0] == 0);
305 }
306 
307 static void do_getsockopt_mptcp_info(struct so_state *s, int fd, size_t w)
308 {
309 	struct mptcp_info i;
310 	socklen_t olen;
311 	int ret;
312 
313 	olen = sizeof(i);
314 	ret = getsockopt(fd, SOL_MPTCP, MPTCP_INFO, &i, &olen);
315 
316 	if (ret < 0)
317 		die_perror("getsockopt MPTCP_INFO");
318 
319 	assert(olen == sizeof(i));
320 
321 	if (s->mi.mptcpi_write_seq == 0)
322 		s->mi = i;
323 
324 	assert(s->mi.mptcpi_write_seq + w == i.mptcpi_write_seq);
325 
326 	s->mptcpi_rcv_delta = i.mptcpi_rcv_nxt - s->mi.mptcpi_rcv_nxt;
327 }
328 
329 static void do_getsockopt_tcp_info(struct so_state *s, int fd, size_t r, size_t w)
330 {
331 	struct my_tcp_info {
332 		struct mptcp_subflow_data d;
333 		struct tcp_info ti[2];
334 	} ti;
335 	int ret, tries = 5;
336 	socklen_t olen;
337 
338 	do {
339 		memset(&ti, 0, sizeof(ti));
340 
341 		ti.d.size_subflow_data = sizeof(struct mptcp_subflow_data);
342 		ti.d.size_user = sizeof(struct tcp_info);
343 		olen = sizeof(ti);
344 
345 		ret = getsockopt(fd, SOL_MPTCP, MPTCP_TCPINFO, &ti, &olen);
346 		if (ret < 0)
347 			xerror("getsockopt MPTCP_TCPINFO (tries %d, %m)");
348 
349 		assert(olen <= sizeof(ti));
350 		assert(ti.d.size_user == ti.d.size_kernel);
351 		assert(ti.d.size_user == sizeof(struct tcp_info));
352 		assert(ti.d.num_subflows == 1);
353 
354 		assert(olen > (socklen_t)sizeof(struct mptcp_subflow_data));
355 		olen -= sizeof(struct mptcp_subflow_data);
356 		assert(olen == sizeof(struct tcp_info));
357 
358 		if (ti.ti[0].tcpi_bytes_sent == w &&
359 		    ti.ti[0].tcpi_bytes_received == r)
360 			goto done;
361 
362 		if (r == 0 && ti.ti[0].tcpi_bytes_sent == w &&
363 		    ti.ti[0].tcpi_bytes_received) {
364 			s->tcpi_rcv_delta = ti.ti[0].tcpi_bytes_received;
365 			goto done;
366 		}
367 
368 		/* wait and repeat, might be that tx is still ongoing */
369 		sleep(1);
370 	} while (tries-- > 0);
371 
372 	xerror("tcpi_bytes_sent %" PRIu64 ", want %zu. tcpi_bytes_received %" PRIu64 ", want %zu",
373 		ti.ti[0].tcpi_bytes_sent, w, ti.ti[0].tcpi_bytes_received, r);
374 
375 done:
376 	do_getsockopt_bogus_sf_data(fd, MPTCP_TCPINFO);
377 }
378 
379 static void do_getsockopt_subflow_addrs(int fd)
380 {
381 	struct sockaddr_storage remote, local;
382 	socklen_t olen, rlen, llen;
383 	int ret;
384 	struct my_addrs {
385 		struct mptcp_subflow_data d;
386 		struct mptcp_subflow_addrs addr[2];
387 	} addrs;
388 
389 	memset(&addrs, 0, sizeof(addrs));
390 	memset(&local, 0, sizeof(local));
391 	memset(&remote, 0, sizeof(remote));
392 
393 	addrs.d.size_subflow_data = sizeof(struct mptcp_subflow_data);
394 	addrs.d.size_user = sizeof(struct mptcp_subflow_addrs);
395 	olen = sizeof(addrs);
396 
397 	ret = getsockopt(fd, SOL_MPTCP, MPTCP_SUBFLOW_ADDRS, &addrs, &olen);
398 	if (ret < 0)
399 		die_perror("getsockopt MPTCP_SUBFLOW_ADDRS");
400 
401 	assert(olen <= sizeof(addrs));
402 	assert(addrs.d.size_user == addrs.d.size_kernel);
403 	assert(addrs.d.size_user == sizeof(struct mptcp_subflow_addrs));
404 	assert(addrs.d.num_subflows == 1);
405 
406 	assert(olen > (socklen_t)sizeof(struct mptcp_subflow_data));
407 	olen -= sizeof(struct mptcp_subflow_data);
408 	assert(olen == sizeof(struct mptcp_subflow_addrs));
409 
410 	llen = sizeof(local);
411 	ret = getsockname(fd, (struct sockaddr *)&local, &llen);
412 	if (ret < 0)
413 		die_perror("getsockname");
414 	rlen = sizeof(remote);
415 	ret = getpeername(fd, (struct sockaddr *)&remote, &rlen);
416 	if (ret < 0)
417 		die_perror("getpeername");
418 
419 	assert(rlen > 0);
420 	assert(rlen == llen);
421 
422 	assert(remote.ss_family == local.ss_family);
423 
424 	assert(memcmp(&local, &addrs.addr[0].ss_local, sizeof(local)) == 0);
425 	assert(memcmp(&remote, &addrs.addr[0].ss_remote, sizeof(remote)) == 0);
426 
427 	memset(&addrs, 0, sizeof(addrs));
428 
429 	addrs.d.size_subflow_data = sizeof(struct mptcp_subflow_data);
430 	addrs.d.size_user = sizeof(sa_family_t);
431 	olen = sizeof(addrs.d) + sizeof(sa_family_t);
432 
433 	ret = getsockopt(fd, SOL_MPTCP, MPTCP_SUBFLOW_ADDRS, &addrs, &olen);
434 	assert(ret == 0);
435 	assert(olen == sizeof(addrs.d) + sizeof(sa_family_t));
436 
437 	assert(addrs.addr[0].sa_family == pf);
438 	assert(addrs.addr[0].sa_family == local.ss_family);
439 
440 	assert(memcmp(&local, &addrs.addr[0].ss_local, sizeof(local)) != 0);
441 	assert(memcmp(&remote, &addrs.addr[0].ss_remote, sizeof(remote)) != 0);
442 
443 	do_getsockopt_bogus_sf_data(fd, MPTCP_SUBFLOW_ADDRS);
444 }
445 
/*
 * Run all getsockopt checks on fd, in order: MPTCP_INFO, MPTCP_TCPINFO,
 * MPTCP_SUBFLOW_ADDRS. r and w are the byte counts the caller expects to
 * have been received and written on the connection.
 */
static void do_getsockopts(struct so_state *s, int fd, size_t r, size_t w)
{
	do_getsockopt_mptcp_info(s, fd, w);
	do_getsockopt_tcp_info(s, fd, r, w);
	do_getsockopt_subflow_addrs(fd);
}
454 
455 static void connect_one_server(int fd, int pipefd)
456 {
457 	char buf[4096], buf2[4096];
458 	size_t len, i, total;
459 	struct so_state s;
460 	bool eof = false;
461 	ssize_t ret;
462 
463 	memset(&s, 0, sizeof(s));
464 
465 	len = rand() % (sizeof(buf) - 1);
466 
467 	if (len < 128)
468 		len = 128;
469 
470 	for (i = 0; i < len ; i++) {
471 		buf[i] = rand() % 26;
472 		buf[i] += 'A';
473 	}
474 
475 	buf[i] = '\n';
476 
477 	do_getsockopts(&s, fd, 0, 0);
478 
479 	/* un-block server */
480 	ret = read(pipefd, buf2, 4);
481 	assert(ret == 4);
482 	close(pipefd);
483 
484 	assert(strncmp(buf2, "xmit", 4) == 0);
485 
486 	ret = write(fd, buf, len);
487 	if (ret < 0)
488 		die_perror("write");
489 
490 	if (ret != (ssize_t)len)
491 		xerror("short write");
492 
493 	total = 0;
494 	do {
495 		ret = read(fd, buf2 + total, sizeof(buf2) - total);
496 		if (ret < 0)
497 			die_perror("read");
498 		if (ret == 0) {
499 			eof = true;
500 			break;
501 		}
502 
503 		total += ret;
504 	} while (total < len);
505 
506 	if (total != len)
507 		xerror("total %lu, len %lu eof %d\n", total, len, eof);
508 
509 	if (memcmp(buf, buf2, len))
510 		xerror("data corruption");
511 
512 	if (s.tcpi_rcv_delta)
513 		assert(s.tcpi_rcv_delta <= total);
514 
515 	do_getsockopts(&s, fd, ret, ret);
516 
517 	if (eof)
518 		total += 1; /* sequence advances due to FIN */
519 
520 	assert(s.mptcpi_rcv_delta == (uint64_t)total);
521 	close(fd);
522 }
523 
524 static void process_one_client(int fd, int pipefd)
525 {
526 	ssize_t ret, ret2, ret3;
527 	struct so_state s;
528 	char buf[4096];
529 
530 	memset(&s, 0, sizeof(s));
531 	do_getsockopts(&s, fd, 0, 0);
532 
533 	ret = write(pipefd, "xmit", 4);
534 	assert(ret == 4);
535 
536 	ret = read(fd, buf, sizeof(buf));
537 	if (ret < 0)
538 		die_perror("read");
539 
540 	assert(s.mptcpi_rcv_delta <= (uint64_t)ret);
541 
542 	if (s.tcpi_rcv_delta)
543 		assert(s.tcpi_rcv_delta == (uint64_t)ret);
544 
545 	ret2 = write(fd, buf, ret);
546 	if (ret2 < 0)
547 		die_perror("write");
548 
549 	/* wait for hangup */
550 	ret3 = read(fd, buf, 1);
551 	if (ret3 != 0)
552 		xerror("expected EOF, got %lu", ret3);
553 
554 	do_getsockopts(&s, fd, ret, ret2);
555 	if (s.mptcpi_rcv_delta != (uint64_t)ret + 1)
556 		xerror("mptcpi_rcv_delta %" PRIu64 ", expect %" PRIu64, s.mptcpi_rcv_delta, ret + 1, s.mptcpi_rcv_delta - ret);
557 	close(fd);
558 }
559 
/* accept() one connection on listening socket s; exits with status 1 on failure. */
static int xaccept(int s)
{
	int conn = accept(s, NULL, NULL);

	if (conn >= 0)
		return conn;

	die_perror("accept");
	return conn;
}
569 
570 static int server(int pipefd)
571 {
572 	int fd = -1, r;
573 
574 	switch (pf) {
575 	case AF_INET:
576 		fd = sock_listen_mptcp("127.0.0.1", "15432");
577 		break;
578 	case AF_INET6:
579 		fd = sock_listen_mptcp("::1", "15432");
580 		break;
581 	default:
582 		xerror("Unknown pf %d\n", pf);
583 		break;
584 	}
585 
586 	r = write(pipefd, "conn", 4);
587 	assert(r == 4);
588 
589 	alarm(15);
590 	r = xaccept(fd);
591 
592 	process_one_client(r, pipefd);
593 
594 	return 0;
595 }
596 
597 static int client(int pipefd)
598 {
599 	int fd = -1;
600 
601 	alarm(15);
602 
603 	switch (pf) {
604 	case AF_INET:
605 		fd = sock_connect_mptcp("127.0.0.1", "15432", IPPROTO_MPTCP);
606 		break;
607 	case AF_INET6:
608 		fd = sock_connect_mptcp("::1", "15432", IPPROTO_MPTCP);
609 		break;
610 	default:
611 		xerror("Unknown pf %d\n", pf);
612 	}
613 
614 	connect_one_server(fd, pipefd);
615 
616 	return 0;
617 }
618 
/* fork() wrapper: returns the fork() result, exits with status 1 if it failed. */
static pid_t xfork(void)
{
	pid_t child = fork();

	if (child >= 0)
		return child;

	die_perror("fork");
	return child;
}
628 
/*
 * Turn a waitpid() status into a test result for the child named what.
 *
 * Returns the child's exit status (0 on success; a non-zero status is
 * also logged to stderr). A child killed or stopped by a signal is fatal
 * via xerror(); any other status yields 111.
 */
static int rcheck(int wstatus, const char *what)
{
	int code;

	if (!WIFEXITED(wstatus)) {
		if (WIFSIGNALED(wstatus))
			xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
		else if (WIFSTOPPED(wstatus))
			xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));

		return 111;
	}

	code = WEXITSTATUS(wstatus);
	if (code != 0)
		fprintf(stderr, "%s exited, status=%d\n", what, code);

	return code;
}
644 
/*
 * Fork a server and a client process that talk over a loopback MPTCP
 * connection and synchronise through a pipe, then collect both results.
 *
 * Returns the server's result if it is non-zero, otherwise the client's.
 */
int main(int argc, char *argv[])
{
	int e1, e2, wstatus;
	pid_t s, c, ret;
	int pipefds[2];
	char token[4];
	ssize_t n;

	parse_opts(argc, argv);

	e1 = pipe(pipefds);
	if (e1 < 0)
		die_perror("pipe");

	s = xfork();
	if (s == 0)
		return server(pipefds[1]);

	close(pipefds[1]);

	/*
	 * wait until server bound a socket; the token goes into its own
	 * buffer rather than into the variable holding read()'s result
	 */
	n = read(pipefds[0], token, sizeof(token));
	assert(n == (ssize_t)sizeof(token));

	c = xfork();
	if (c == 0)
		return client(pipefds[0]);

	close(pipefds[0]);

	ret = waitpid(s, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e1 = rcheck(wstatus, "server");
	ret = waitpid(c, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e2 = rcheck(wstatus, "client");

	return e1 ? e1 : e2;
}
684