1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * vsock_diag_test - vsock_diag.ko test suite
4  *
5  * Copyright (C) 2017 Red Hat, Inc.
6  *
7  * Author: Stefan Hajnoczi <stefanha@redhat.com>
8  */
9 
#include <getopt.h>
#include <stdio.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <limits.h>
#include <unistd.h>
#include <signal.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <linux/list.h>
#include <linux/net.h>
#include <linux/netlink.h>
#include <linux/sock_diag.h>
#include <netinet/tcp.h>
26 
27 #include "../../../include/uapi/linux/vm_sockets.h"
28 #include "../../../include/uapi/linux/vm_sockets_diag.h"
29 
30 #include "timeout.h"
31 #include "control.h"
32 
/* Role this process plays in the test run, selected with --mode. */
enum test_mode {
	TEST_MODE_UNSET = 0,	/* no --mode given (yet) */
	TEST_MODE_CLIENT = 1,	/* --mode=client */
	TEST_MODE_SERVER = 2,	/* --mode=server */
};
38 
/*
 * Per-socket status: one entry per vsock_diag_msg received from the
 * sock_diag dump, chained together through @list.
 */
struct vsock_stat {
	struct list_head list;		/* link in the list filled by read_vsock_stat() */
	struct vsock_diag_msg msg;	/* copy of the diag record for this socket */
};
44 
/*
 * Return a printable label for a socket type value (SOCK_*).
 * Types other than SOCK_DGRAM and SOCK_STREAM yield "INVALID TYPE".
 */
static const char *sock_type_str(int type)
{
	if (type == SOCK_STREAM)
		return "STREAM";
	if (type == SOCK_DGRAM)
		return "DGRAM";
	return "INVALID TYPE";
}
56 
57 static const char *sock_state_str(int state)
58 {
59 	switch (state) {
60 	case TCP_CLOSE:
61 		return "UNCONNECTED";
62 	case TCP_SYN_SENT:
63 		return "CONNECTING";
64 	case TCP_ESTABLISHED:
65 		return "CONNECTED";
66 	case TCP_CLOSING:
67 		return "DISCONNECTING";
68 	case TCP_LISTEN:
69 		return "LISTEN";
70 	default:
71 		return "INVALID STATE";
72 	}
73 }
74 
/*
 * Return a printable form of a shutdown mask: 1 is the receive side,
 * 2 the send side, 3 both. Any other value (including 0) prints as "0".
 */
static const char *sock_shutdown_str(int shutdown)
{
	static const char *const labels[] = {
		[1] = "RCV_SHUTDOWN",
		[2] = "SEND_SHUTDOWN",
		[3] = "RCV_SHUTDOWN | SEND_SHUTDOWN",
	};

	if (shutdown < 1 || shutdown > 3)
		return "0";
	return labels[shutdown];
}
88 
89 static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
90 {
91 	if (cid == VMADDR_CID_ANY)
92 		fprintf(fp, "*:");
93 	else
94 		fprintf(fp, "%u:", cid);
95 
96 	if (port == VMADDR_PORT_ANY)
97 		fprintf(fp, "*");
98 	else
99 		fprintf(fp, "%u", port);
100 }
101 
102 static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
103 {
104 	print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
105 	fprintf(fp, " ");
106 	print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
107 	fprintf(fp, " %s %s %s %u\n",
108 		sock_type_str(st->msg.vdiag_type),
109 		sock_state_str(st->msg.vdiag_state),
110 		sock_shutdown_str(st->msg.vdiag_shutdown),
111 		st->msg.vdiag_ino);
112 }
113 
114 static void print_vsock_stats(FILE *fp, struct list_head *head)
115 {
116 	struct vsock_stat *st;
117 
118 	list_for_each_entry(st, head, list)
119 		print_vsock_stat(fp, st);
120 }
121 
122 static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
123 {
124 	struct vsock_stat *st;
125 	struct stat stat;
126 
127 	if (fstat(fd, &stat) < 0) {
128 		perror("fstat");
129 		exit(EXIT_FAILURE);
130 	}
131 
132 	list_for_each_entry(st, head, list)
133 		if (st->msg.vdiag_ino == stat.st_ino)
134 			return st;
135 
136 	fprintf(stderr, "cannot find fd %d\n", fd);
137 	exit(EXIT_FAILURE);
138 }
139 
/*
 * Fail the test unless @head is empty. On failure, list the unexpected
 * sockets on stderr and exit with EXIT_FAILURE, like the other check_*()
 * helpers.
 */
static void check_no_sockets(struct list_head *head)
{
	if (!list_empty(head)) {
		fprintf(stderr, "expected no sockets\n");
		print_vsock_stats(stderr, head);
		exit(EXIT_FAILURE);
	}
}
148 
149 static void check_num_sockets(struct list_head *head, int expected)
150 {
151 	struct list_head *node;
152 	int n = 0;
153 
154 	list_for_each(node, head)
155 		n++;
156 
157 	if (n != expected) {
158 		fprintf(stderr, "expected %d sockets, found %d\n",
159 			expected, n);
160 		print_vsock_stats(stderr, head);
161 		exit(EXIT_FAILURE);
162 	}
163 }
164 
165 static void check_socket_state(struct vsock_stat *st, __u8 state)
166 {
167 	if (st->msg.vdiag_state != state) {
168 		fprintf(stderr, "expected socket state %#x, got %#x\n",
169 			state, st->msg.vdiag_state);
170 		exit(EXIT_FAILURE);
171 	}
172 }
173 
174 static void send_req(int fd)
175 {
176 	struct sockaddr_nl nladdr = {
177 		.nl_family = AF_NETLINK,
178 	};
179 	struct {
180 		struct nlmsghdr nlh;
181 		struct vsock_diag_req vreq;
182 	} req = {
183 		.nlh = {
184 			.nlmsg_len = sizeof(req),
185 			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
186 			.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
187 		},
188 		.vreq = {
189 			.sdiag_family = AF_VSOCK,
190 			.vdiag_states = ~(__u32)0,
191 		},
192 	};
193 	struct iovec iov = {
194 		.iov_base = &req,
195 		.iov_len = sizeof(req),
196 	};
197 	struct msghdr msg = {
198 		.msg_name = &nladdr,
199 		.msg_namelen = sizeof(nladdr),
200 		.msg_iov = &iov,
201 		.msg_iovlen = 1,
202 	};
203 
204 	for (;;) {
205 		if (sendmsg(fd, &msg, 0) < 0) {
206 			if (errno == EINTR)
207 				continue;
208 
209 			perror("sendmsg");
210 			exit(EXIT_FAILURE);
211 		}
212 
213 		return;
214 	}
215 }
216 
217 static ssize_t recv_resp(int fd, void *buf, size_t len)
218 {
219 	struct sockaddr_nl nladdr = {
220 		.nl_family = AF_NETLINK,
221 	};
222 	struct iovec iov = {
223 		.iov_base = buf,
224 		.iov_len = len,
225 	};
226 	struct msghdr msg = {
227 		.msg_name = &nladdr,
228 		.msg_namelen = sizeof(nladdr),
229 		.msg_iov = &iov,
230 		.msg_iovlen = 1,
231 	};
232 	ssize_t ret;
233 
234 	do {
235 		ret = recvmsg(fd, &msg, 0);
236 	} while (ret < 0 && errno == EINTR);
237 
238 	if (ret < 0) {
239 		perror("recvmsg");
240 		exit(EXIT_FAILURE);
241 	}
242 
243 	return ret;
244 }
245 
246 static void add_vsock_stat(struct list_head *sockets,
247 			   const struct vsock_diag_msg *resp)
248 {
249 	struct vsock_stat *st;
250 
251 	st = malloc(sizeof(*st));
252 	if (!st) {
253 		perror("malloc");
254 		exit(EXIT_FAILURE);
255 	}
256 
257 	st->msg = *resp;
258 	list_add_tail(&st->list, sockets);
259 }
260 
261 /*
262  * Read vsock stats into a list.
263  */
264 static void read_vsock_stat(struct list_head *sockets)
265 {
266 	long buf[8192 / sizeof(long)];
267 	int fd;
268 
269 	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
270 	if (fd < 0) {
271 		perror("socket");
272 		exit(EXIT_FAILURE);
273 	}
274 
275 	send_req(fd);
276 
277 	for (;;) {
278 		const struct nlmsghdr *h;
279 		ssize_t ret;
280 
281 		ret = recv_resp(fd, buf, sizeof(buf));
282 		if (ret == 0)
283 			goto done;
284 		if (ret < sizeof(*h)) {
285 			fprintf(stderr, "short read of %zd bytes\n", ret);
286 			exit(EXIT_FAILURE);
287 		}
288 
289 		h = (struct nlmsghdr *)buf;
290 
291 		while (NLMSG_OK(h, ret)) {
292 			if (h->nlmsg_type == NLMSG_DONE)
293 				goto done;
294 
295 			if (h->nlmsg_type == NLMSG_ERROR) {
296 				const struct nlmsgerr *err = NLMSG_DATA(h);
297 
298 				if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
299 					fprintf(stderr, "NLMSG_ERROR\n");
300 				else {
301 					errno = -err->error;
302 					perror("NLMSG_ERROR");
303 				}
304 
305 				exit(EXIT_FAILURE);
306 			}
307 
308 			if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
309 				fprintf(stderr, "unexpected nlmsg_type %#x\n",
310 					h->nlmsg_type);
311 				exit(EXIT_FAILURE);
312 			}
313 			if (h->nlmsg_len <
314 			    NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
315 				fprintf(stderr, "short vsock_diag_msg\n");
316 				exit(EXIT_FAILURE);
317 			}
318 
319 			add_vsock_stat(sockets, NLMSG_DATA(h));
320 
321 			h = NLMSG_NEXT(h, ret);
322 		}
323 	}
324 
325 done:
326 	close(fd);
327 }
328 
329 static void free_sock_stat(struct list_head *sockets)
330 {
331 	struct vsock_stat *st;
332 	struct vsock_stat *next;
333 
334 	list_for_each_entry_safe(st, next, sockets, list)
335 		free(st);
336 }
337 
338 static void test_no_sockets(unsigned int peer_cid)
339 {
340 	LIST_HEAD(sockets);
341 
342 	read_vsock_stat(&sockets);
343 
344 	check_no_sockets(&sockets);
345 
346 	free_sock_stat(&sockets);
347 }
348 
349 static void test_listen_socket_server(unsigned int peer_cid)
350 {
351 	union {
352 		struct sockaddr sa;
353 		struct sockaddr_vm svm;
354 	} addr = {
355 		.svm = {
356 			.svm_family = AF_VSOCK,
357 			.svm_port = 1234,
358 			.svm_cid = VMADDR_CID_ANY,
359 		},
360 	};
361 	LIST_HEAD(sockets);
362 	struct vsock_stat *st;
363 	int fd;
364 
365 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
366 
367 	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
368 		perror("bind");
369 		exit(EXIT_FAILURE);
370 	}
371 
372 	if (listen(fd, 1) < 0) {
373 		perror("listen");
374 		exit(EXIT_FAILURE);
375 	}
376 
377 	read_vsock_stat(&sockets);
378 
379 	check_num_sockets(&sockets, 1);
380 	st = find_vsock_stat(&sockets, fd);
381 	check_socket_state(st, TCP_LISTEN);
382 
383 	close(fd);
384 	free_sock_stat(&sockets);
385 }
386 
387 static void test_connect_client(unsigned int peer_cid)
388 {
389 	union {
390 		struct sockaddr sa;
391 		struct sockaddr_vm svm;
392 	} addr = {
393 		.svm = {
394 			.svm_family = AF_VSOCK,
395 			.svm_port = 1234,
396 			.svm_cid = peer_cid,
397 		},
398 	};
399 	int fd;
400 	int ret;
401 	LIST_HEAD(sockets);
402 	struct vsock_stat *st;
403 
404 	control_expectln("LISTENING");
405 
406 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
407 
408 	timeout_begin(TIMEOUT);
409 	do {
410 		ret = connect(fd, &addr.sa, sizeof(addr.svm));
411 		timeout_check("connect");
412 	} while (ret < 0 && errno == EINTR);
413 	timeout_end();
414 
415 	if (ret < 0) {
416 		perror("connect");
417 		exit(EXIT_FAILURE);
418 	}
419 
420 	read_vsock_stat(&sockets);
421 
422 	check_num_sockets(&sockets, 1);
423 	st = find_vsock_stat(&sockets, fd);
424 	check_socket_state(st, TCP_ESTABLISHED);
425 
426 	control_expectln("DONE");
427 	control_writeln("DONE");
428 
429 	close(fd);
430 	free_sock_stat(&sockets);
431 }
432 
433 static void test_connect_server(unsigned int peer_cid)
434 {
435 	union {
436 		struct sockaddr sa;
437 		struct sockaddr_vm svm;
438 	} addr = {
439 		.svm = {
440 			.svm_family = AF_VSOCK,
441 			.svm_port = 1234,
442 			.svm_cid = VMADDR_CID_ANY,
443 		},
444 	};
445 	union {
446 		struct sockaddr sa;
447 		struct sockaddr_vm svm;
448 	} clientaddr;
449 	socklen_t clientaddr_len = sizeof(clientaddr.svm);
450 	LIST_HEAD(sockets);
451 	struct vsock_stat *st;
452 	int fd;
453 	int client_fd;
454 
455 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
456 
457 	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
458 		perror("bind");
459 		exit(EXIT_FAILURE);
460 	}
461 
462 	if (listen(fd, 1) < 0) {
463 		perror("listen");
464 		exit(EXIT_FAILURE);
465 	}
466 
467 	control_writeln("LISTENING");
468 
469 	timeout_begin(TIMEOUT);
470 	do {
471 		client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
472 		timeout_check("accept");
473 	} while (client_fd < 0 && errno == EINTR);
474 	timeout_end();
475 
476 	if (client_fd < 0) {
477 		perror("accept");
478 		exit(EXIT_FAILURE);
479 	}
480 	if (clientaddr.sa.sa_family != AF_VSOCK) {
481 		fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
482 			clientaddr.sa.sa_family);
483 		exit(EXIT_FAILURE);
484 	}
485 	if (clientaddr.svm.svm_cid != peer_cid) {
486 		fprintf(stderr, "expected peer CID %u from accept(2), got %u\n",
487 			peer_cid, clientaddr.svm.svm_cid);
488 		exit(EXIT_FAILURE);
489 	}
490 
491 	read_vsock_stat(&sockets);
492 
493 	check_num_sockets(&sockets, 2);
494 	find_vsock_stat(&sockets, fd);
495 	st = find_vsock_stat(&sockets, client_fd);
496 	check_socket_state(st, TCP_ESTABLISHED);
497 
498 	control_writeln("DONE");
499 	control_expectln("DONE");
500 
501 	close(client_fd);
502 	close(fd);
503 	free_sock_stat(&sockets);
504 }
505 
506 static struct {
507 	const char *name;
508 	void (*run_client)(unsigned int peer_cid);
509 	void (*run_server)(unsigned int peer_cid);
510 } test_cases[] = {
511 	{
512 		.name = "No sockets",
513 		.run_server = test_no_sockets,
514 	},
515 	{
516 		.name = "Listen socket",
517 		.run_server = test_listen_socket_server,
518 	},
519 	{
520 		.name = "Connect",
521 		.run_client = test_connect_client,
522 		.run_server = test_connect_server,
523 	},
524 	{},
525 };
526 
527 static void init_signals(void)
528 {
529 	struct sigaction act = {
530 		.sa_handler = sigalrm,
531 	};
532 
533 	sigaction(SIGALRM, &act, NULL);
534 	signal(SIGPIPE, SIG_IGN);
535 }
536 
/*
 * Parse @str as a decimal CID. Exit with EXIT_FAILURE and a "malformed CID"
 * message in any of these cases: the string is empty or contains no
 * digits, it has trailing characters, strtoul() reports an error, or the
 * value does not fit in an unsigned int. The last check stops a value such
 * as 4294967298 from silently wrapping to CID 2.
 */
static unsigned int parse_cid(const char *str)
{
	char *endptr = NULL;
	unsigned long int n;

	errno = 0;
	n = strtoul(str, &endptr, 10);
	if (errno || endptr == str || *endptr != '\0' || n > UINT_MAX) {
		fprintf(stderr, "malformed CID \"%s\"\n", str);
		exit(EXIT_FAILURE);
	}
	return (unsigned int)n;
}
550 
/* No short options: every option is long-only. */
static const char optstring[] = "";

/*
 * Long options. Each val is the character main() switches on;
 * --help maps to '?', the same value getopt_long() returns for an
 * unrecognised option.
 */
static const struct option longopts[] = {
	/* name           has_arg            flag  val */
	{ "control-host", required_argument, NULL, 'H' },
	{ "control-port", required_argument, NULL, 'P' },
	{ "mode",         required_argument, NULL, 'm' },
	{ "peer-cid",     required_argument, NULL, 'p' },
	{ "help",         no_argument,       NULL, '?' },
	{ NULL,           0,                 NULL, 0 },
};
580 
/* Print the command-line help to stderr and exit with EXIT_FAILURE. */
static void usage(void)
{
	/* The text contains no conversion specifiers, so fputs() prints it verbatim. */
	fputs("Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid>\n"
	      "\n"
	      "  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
	      "  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
	      "\n"
	      "Run vsock_diag.ko tests.  Must be launched in both\n"
	      "guest and host.  One side must use --mode=client and\n"
	      "the other side must use --mode=server.\n"
	      "\n"
	      "A TCP control socket connection is used to coordinate tests\n"
	      "between the client and the server.  The server requires a\n"
	      "listen address and the client requires an address to\n"
	      "connect to.\n"
	      "\n"
	      "The CID of the other side must be given with --peer-cid=<cid>.\n",
	      stderr);
	exit(EXIT_FAILURE);
}
600 
601 int main(int argc, char **argv)
602 {
603 	const char *control_host = NULL;
604 	const char *control_port = NULL;
605 	int mode = TEST_MODE_UNSET;
606 	unsigned int peer_cid = VMADDR_CID_ANY;
607 	int i;
608 
609 	init_signals();
610 
611 	for (;;) {
612 		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
613 
614 		if (opt == -1)
615 			break;
616 
617 		switch (opt) {
618 		case 'H':
619 			control_host = optarg;
620 			break;
621 		case 'm':
622 			if (strcmp(optarg, "client") == 0)
623 				mode = TEST_MODE_CLIENT;
624 			else if (strcmp(optarg, "server") == 0)
625 				mode = TEST_MODE_SERVER;
626 			else {
627 				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
628 				return EXIT_FAILURE;
629 			}
630 			break;
631 		case 'p':
632 			peer_cid = parse_cid(optarg);
633 			break;
634 		case 'P':
635 			control_port = optarg;
636 			break;
637 		case '?':
638 		default:
639 			usage();
640 		}
641 	}
642 
643 	if (!control_port)
644 		usage();
645 	if (mode == TEST_MODE_UNSET)
646 		usage();
647 	if (peer_cid == VMADDR_CID_ANY)
648 		usage();
649 
650 	if (!control_host) {
651 		if (mode != TEST_MODE_SERVER)
652 			usage();
653 		control_host = "0.0.0.0";
654 	}
655 
656 	control_init(control_host, control_port, mode == TEST_MODE_SERVER);
657 
658 	for (i = 0; test_cases[i].name; i++) {
659 		void (*run)(unsigned int peer_cid);
660 
661 		printf("%s...", test_cases[i].name);
662 		fflush(stdout);
663 
664 		if (mode == TEST_MODE_CLIENT)
665 			run = test_cases[i].run_client;
666 		else
667 			run = test_cases[i].run_server;
668 
669 		if (run)
670 			run(peer_cid);
671 
672 		printf("ok\n");
673 	}
674 
675 	control_cleanup();
676 	return EXIT_SUCCESS;
677 }
678