1 /*
2  * vsock_diag_test - vsock_diag.ko test suite
3  *
4  * Copyright (C) 2017 Red Hat, Inc.
5  *
6  * Author: Stefan Hajnoczi <stefanha@redhat.com>
7  *
8  * This program is free software; you can redistribute it and/or
9  * modify it under the terms of the GNU General Public License
10  * as published by the Free Software Foundation; version 2
11  * of the License.
12  */
13 
#include <getopt.h>
#include <stdio.h>
#include <stdbool.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <limits.h>
#include <unistd.h>
#include <signal.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <linux/list.h>
#include <linux/net.h>
#include <linux/netlink.h>
#include <linux/sock_diag.h>
#include <netinet/tcp.h>

#include "../../../include/uapi/linux/vm_sockets.h"
#include "../../../include/uapi/linux/vm_sockets_diag.h"

#include "timeout.h"
#include "control.h"
36 
/* Role selected with --mode; UNSET means the option was not given. */
enum test_mode {
	TEST_MODE_UNSET = 0,
	TEST_MODE_CLIENT = 1,
	TEST_MODE_SERVER = 2,
};
42 
43 /* Per-socket status */
44 struct vsock_stat {
45 	struct list_head list;
46 	struct vsock_diag_msg msg;
47 };
48 
/* Return the printable name of a socket type, or "INVALID TYPE". */
static const char *sock_type_str(int type)
{
	if (type == SOCK_STREAM)
		return "STREAM";
	if (type == SOCK_DGRAM)
		return "DGRAM";
	return "INVALID TYPE";
}
60 
61 static const char *sock_state_str(int state)
62 {
63 	switch (state) {
64 	case TCP_CLOSE:
65 		return "UNCONNECTED";
66 	case TCP_SYN_SENT:
67 		return "CONNECTING";
68 	case TCP_ESTABLISHED:
69 		return "CONNECTED";
70 	case TCP_CLOSING:
71 		return "DISCONNECTING";
72 	case TCP_LISTEN:
73 		return "LISTEN";
74 	default:
75 		return "INVALID STATE";
76 	}
77 }
78 
/*
 * Return a printable form of the vdiag_shutdown bit mask (bit 0 = receive,
 * bit 1 = send).  Zero and any out-of-range value print as "0".
 */
static const char *sock_shutdown_str(int shutdown)
{
	static const char *const names[] = {
		[1] = "RCV_SHUTDOWN",
		[2] = "SEND_SHUTDOWN",
		[3] = "RCV_SHUTDOWN | SEND_SHUTDOWN",
	};

	if (shutdown < 1 || shutdown > 3)
		return "0";

	return names[shutdown];
}
92 
/*
 * Write "<cid>:<port>" to @fp, rendering VMADDR_CID_ANY and
 * VMADDR_PORT_ANY as the wildcard "*".  No trailing newline.
 */
static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
{
	if (cid != VMADDR_CID_ANY)
		fprintf(fp, "%u:", cid);
	else
		fprintf(fp, "*:");

	if (port != VMADDR_PORT_ANY)
		fprintf(fp, "%u", port);
	else
		fprintf(fp, "*");
}
105 
106 static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
107 {
108 	print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
109 	fprintf(fp, " ");
110 	print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
111 	fprintf(fp, " %s %s %s %u\n",
112 		sock_type_str(st->msg.vdiag_type),
113 		sock_state_str(st->msg.vdiag_state),
114 		sock_shutdown_str(st->msg.vdiag_shutdown),
115 		st->msg.vdiag_ino);
116 }
117 
118 static void print_vsock_stats(FILE *fp, struct list_head *head)
119 {
120 	struct vsock_stat *st;
121 
122 	list_for_each_entry(st, head, list)
123 		print_vsock_stat(fp, st);
124 }
125 
126 static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
127 {
128 	struct vsock_stat *st;
129 	struct stat stat;
130 
131 	if (fstat(fd, &stat) < 0) {
132 		perror("fstat");
133 		exit(EXIT_FAILURE);
134 	}
135 
136 	list_for_each_entry(st, head, list)
137 		if (st->msg.vdiag_ino == stat.st_ino)
138 			return st;
139 
140 	fprintf(stderr, "cannot find fd %d\n", fd);
141 	exit(EXIT_FAILURE);
142 }
143 
/*
 * Fail the test unless @head is empty.  On failure, the unexpected sockets
 * are dumped to stderr before exiting.
 */
static void check_no_sockets(struct list_head *head)
{
	if (!list_empty(head)) {
		fprintf(stderr, "expected no sockets\n");
		print_vsock_stats(stderr, head);
		/* Same exit convention as every other check in this file. */
		exit(EXIT_FAILURE);
	}
}
152 
153 static void check_num_sockets(struct list_head *head, int expected)
154 {
155 	struct list_head *node;
156 	int n = 0;
157 
158 	list_for_each(node, head)
159 		n++;
160 
161 	if (n != expected) {
162 		fprintf(stderr, "expected %d sockets, found %d\n",
163 			expected, n);
164 		print_vsock_stats(stderr, head);
165 		exit(EXIT_FAILURE);
166 	}
167 }
168 
169 static void check_socket_state(struct vsock_stat *st, __u8 state)
170 {
171 	if (st->msg.vdiag_state != state) {
172 		fprintf(stderr, "expected socket state %#x, got %#x\n",
173 			state, st->msg.vdiag_state);
174 		exit(EXIT_FAILURE);
175 	}
176 }
177 
178 static void send_req(int fd)
179 {
180 	struct sockaddr_nl nladdr = {
181 		.nl_family = AF_NETLINK,
182 	};
183 	struct {
184 		struct nlmsghdr nlh;
185 		struct vsock_diag_req vreq;
186 	} req = {
187 		.nlh = {
188 			.nlmsg_len = sizeof(req),
189 			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
190 			.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
191 		},
192 		.vreq = {
193 			.sdiag_family = AF_VSOCK,
194 			.vdiag_states = ~(__u32)0,
195 		},
196 	};
197 	struct iovec iov = {
198 		.iov_base = &req,
199 		.iov_len = sizeof(req),
200 	};
201 	struct msghdr msg = {
202 		.msg_name = &nladdr,
203 		.msg_namelen = sizeof(nladdr),
204 		.msg_iov = &iov,
205 		.msg_iovlen = 1,
206 	};
207 
208 	for (;;) {
209 		if (sendmsg(fd, &msg, 0) < 0) {
210 			if (errno == EINTR)
211 				continue;
212 
213 			perror("sendmsg");
214 			exit(EXIT_FAILURE);
215 		}
216 
217 		return;
218 	}
219 }
220 
/*
 * Receive one netlink datagram from @fd into @buf (at most @len bytes) and
 * return its length.  EINTR is retried; any other recvmsg() failure
 * terminates the process, so the return value is never negative.
 */
static ssize_t recv_resp(int fd, void *buf, size_t len)
{
	struct sockaddr_nl sender = {
		.nl_family = AF_NETLINK,
	};
	struct iovec vec = {
		.iov_base = buf,
		.iov_len = len,
	};
	struct msghdr hdr = {
		.msg_name = &sender,
		.msg_namelen = sizeof(sender),
		.msg_iov = &vec,
		.msg_iovlen = 1,
	};

	for (;;) {
		ssize_t nread = recvmsg(fd, &hdr, 0);

		if (nread >= 0)
			return nread;
		if (errno != EINTR)
			break;
	}

	perror("recvmsg");
	exit(EXIT_FAILURE);
}
249 
250 static void add_vsock_stat(struct list_head *sockets,
251 			   const struct vsock_diag_msg *resp)
252 {
253 	struct vsock_stat *st;
254 
255 	st = malloc(sizeof(*st));
256 	if (!st) {
257 		perror("malloc");
258 		exit(EXIT_FAILURE);
259 	}
260 
261 	st->msg = *resp;
262 	list_add_tail(&st->list, sockets);
263 }
264 
265 /*
266  * Read vsock stats into a list.
267  */
268 static void read_vsock_stat(struct list_head *sockets)
269 {
270 	long buf[8192 / sizeof(long)];
271 	int fd;
272 
273 	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
274 	if (fd < 0) {
275 		perror("socket");
276 		exit(EXIT_FAILURE);
277 	}
278 
279 	send_req(fd);
280 
281 	for (;;) {
282 		const struct nlmsghdr *h;
283 		ssize_t ret;
284 
285 		ret = recv_resp(fd, buf, sizeof(buf));
286 		if (ret == 0)
287 			goto done;
288 		if (ret < sizeof(*h)) {
289 			fprintf(stderr, "short read of %zd bytes\n", ret);
290 			exit(EXIT_FAILURE);
291 		}
292 
293 		h = (struct nlmsghdr *)buf;
294 
295 		while (NLMSG_OK(h, ret)) {
296 			if (h->nlmsg_type == NLMSG_DONE)
297 				goto done;
298 
299 			if (h->nlmsg_type == NLMSG_ERROR) {
300 				const struct nlmsgerr *err = NLMSG_DATA(h);
301 
302 				if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
303 					fprintf(stderr, "NLMSG_ERROR\n");
304 				else {
305 					errno = -err->error;
306 					perror("NLMSG_ERROR");
307 				}
308 
309 				exit(EXIT_FAILURE);
310 			}
311 
312 			if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
313 				fprintf(stderr, "unexpected nlmsg_type %#x\n",
314 					h->nlmsg_type);
315 				exit(EXIT_FAILURE);
316 			}
317 			if (h->nlmsg_len <
318 			    NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
319 				fprintf(stderr, "short vsock_diag_msg\n");
320 				exit(EXIT_FAILURE);
321 			}
322 
323 			add_vsock_stat(sockets, NLMSG_DATA(h));
324 
325 			h = NLMSG_NEXT(h, ret);
326 		}
327 	}
328 
329 done:
330 	close(fd);
331 }
332 
333 static void free_sock_stat(struct list_head *sockets)
334 {
335 	struct vsock_stat *st;
336 	struct vsock_stat *next;
337 
338 	list_for_each_entry_safe(st, next, sockets, list)
339 		free(st);
340 }
341 
342 static void test_no_sockets(unsigned int peer_cid)
343 {
344 	LIST_HEAD(sockets);
345 
346 	read_vsock_stat(&sockets);
347 
348 	check_no_sockets(&sockets);
349 
350 	free_sock_stat(&sockets);
351 }
352 
353 static void test_listen_socket_server(unsigned int peer_cid)
354 {
355 	union {
356 		struct sockaddr sa;
357 		struct sockaddr_vm svm;
358 	} addr = {
359 		.svm = {
360 			.svm_family = AF_VSOCK,
361 			.svm_port = 1234,
362 			.svm_cid = VMADDR_CID_ANY,
363 		},
364 	};
365 	LIST_HEAD(sockets);
366 	struct vsock_stat *st;
367 	int fd;
368 
369 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
370 
371 	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
372 		perror("bind");
373 		exit(EXIT_FAILURE);
374 	}
375 
376 	if (listen(fd, 1) < 0) {
377 		perror("listen");
378 		exit(EXIT_FAILURE);
379 	}
380 
381 	read_vsock_stat(&sockets);
382 
383 	check_num_sockets(&sockets, 1);
384 	st = find_vsock_stat(&sockets, fd);
385 	check_socket_state(st, TCP_LISTEN);
386 
387 	close(fd);
388 	free_sock_stat(&sockets);
389 }
390 
391 static void test_connect_client(unsigned int peer_cid)
392 {
393 	union {
394 		struct sockaddr sa;
395 		struct sockaddr_vm svm;
396 	} addr = {
397 		.svm = {
398 			.svm_family = AF_VSOCK,
399 			.svm_port = 1234,
400 			.svm_cid = peer_cid,
401 		},
402 	};
403 	int fd;
404 	int ret;
405 	LIST_HEAD(sockets);
406 	struct vsock_stat *st;
407 
408 	control_expectln("LISTENING");
409 
410 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
411 
412 	timeout_begin(TIMEOUT);
413 	do {
414 		ret = connect(fd, &addr.sa, sizeof(addr.svm));
415 		timeout_check("connect");
416 	} while (ret < 0 && errno == EINTR);
417 	timeout_end();
418 
419 	if (ret < 0) {
420 		perror("connect");
421 		exit(EXIT_FAILURE);
422 	}
423 
424 	read_vsock_stat(&sockets);
425 
426 	check_num_sockets(&sockets, 1);
427 	st = find_vsock_stat(&sockets, fd);
428 	check_socket_state(st, TCP_ESTABLISHED);
429 
430 	control_expectln("DONE");
431 	control_writeln("DONE");
432 
433 	close(fd);
434 	free_sock_stat(&sockets);
435 }
436 
437 static void test_connect_server(unsigned int peer_cid)
438 {
439 	union {
440 		struct sockaddr sa;
441 		struct sockaddr_vm svm;
442 	} addr = {
443 		.svm = {
444 			.svm_family = AF_VSOCK,
445 			.svm_port = 1234,
446 			.svm_cid = VMADDR_CID_ANY,
447 		},
448 	};
449 	union {
450 		struct sockaddr sa;
451 		struct sockaddr_vm svm;
452 	} clientaddr;
453 	socklen_t clientaddr_len = sizeof(clientaddr.svm);
454 	LIST_HEAD(sockets);
455 	struct vsock_stat *st;
456 	int fd;
457 	int client_fd;
458 
459 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
460 
461 	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
462 		perror("bind");
463 		exit(EXIT_FAILURE);
464 	}
465 
466 	if (listen(fd, 1) < 0) {
467 		perror("listen");
468 		exit(EXIT_FAILURE);
469 	}
470 
471 	control_writeln("LISTENING");
472 
473 	timeout_begin(TIMEOUT);
474 	do {
475 		client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
476 		timeout_check("accept");
477 	} while (client_fd < 0 && errno == EINTR);
478 	timeout_end();
479 
480 	if (client_fd < 0) {
481 		perror("accept");
482 		exit(EXIT_FAILURE);
483 	}
484 	if (clientaddr.sa.sa_family != AF_VSOCK) {
485 		fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
486 			clientaddr.sa.sa_family);
487 		exit(EXIT_FAILURE);
488 	}
489 	if (clientaddr.svm.svm_cid != peer_cid) {
490 		fprintf(stderr, "expected peer CID %u from accept(2), got %u\n",
491 			peer_cid, clientaddr.svm.svm_cid);
492 		exit(EXIT_FAILURE);
493 	}
494 
495 	read_vsock_stat(&sockets);
496 
497 	check_num_sockets(&sockets, 2);
498 	find_vsock_stat(&sockets, fd);
499 	st = find_vsock_stat(&sockets, client_fd);
500 	check_socket_state(st, TCP_ESTABLISHED);
501 
502 	control_writeln("DONE");
503 	control_expectln("DONE");
504 
505 	close(client_fd);
506 	close(fd);
507 	free_sock_stat(&sockets);
508 }
509 
510 static struct {
511 	const char *name;
512 	void (*run_client)(unsigned int peer_cid);
513 	void (*run_server)(unsigned int peer_cid);
514 } test_cases[] = {
515 	{
516 		.name = "No sockets",
517 		.run_server = test_no_sockets,
518 	},
519 	{
520 		.name = "Listen socket",
521 		.run_server = test_listen_socket_server,
522 	},
523 	{
524 		.name = "Connect",
525 		.run_client = test_connect_client,
526 		.run_server = test_connect_server,
527 	},
528 	{},
529 };
530 
531 static void init_signals(void)
532 {
533 	struct sigaction act = {
534 		.sa_handler = sigalrm,
535 	};
536 
537 	sigaction(SIGALRM, &act, NULL);
538 	signal(SIGPIPE, SIG_IGN);
539 }
540 
/*
 * Parse a decimal CID given on the command line.
 *
 * Exits with "malformed CID" in any of these cases:
 *  - @str contains no digits (strtoul would silently yield 0 for "");
 *  - @str has trailing characters;
 *  - @str contains a minus sign (strtoul would wrap the value);
 *  - the value overflows, or does not fit in the unsigned int return type.
 */
static unsigned int parse_cid(const char *str)
{
	char *endptr = NULL;
	unsigned long int n;

	errno = 0;
	n = strtoul(str, &endptr, 10);
	if (errno || endptr == str || *endptr != '\0' ||
	    strchr(str, '-') || n > UINT_MAX) {
		fprintf(stderr, "malformed CID \"%s\"\n", str);
		exit(EXIT_FAILURE);
	}
	return n;
}
554 
/* Only long options are accepted, so the short-option string is empty. */
static const char optstring[] = "";
/*
 * Long options.  Each val is the code getopt_long() hands back to main();
 * "help" uses '?', the same code returned for an unrecognised option.
 */
static const struct option longopts[] = {
	{ "control-host", required_argument, NULL, 'H' },
	{ "control-port", required_argument, NULL, 'P' },
	{ "mode",         required_argument, NULL, 'm' },
	{ "peer-cid",     required_argument, NULL, 'p' },
	{ "help",         no_argument,       NULL, '?' },
	{ NULL,           0,                 NULL, 0   },
};
584 
/* Print command-line help to stderr and exit with EXIT_FAILURE; never returns. */
static void usage(void)
{
	fputs("Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid>\n"
	      "\n"
	      "  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
	      "  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
	      "\n"
	      "Run vsock_diag.ko tests.  Must be launched in both\n"
	      "guest and host.  One side must use --mode=client and\n"
	      "the other side must use --mode=server.\n"
	      "\n"
	      "A TCP control socket connection is used to coordinate tests\n"
	      "between the client and the server.  The server requires a\n"
	      "listen address and the client requires an address to\n"
	      "connect to.\n"
	      "\n"
	      "The CID of the other side must be given with --peer-cid=<cid>.\n",
	      stderr);
	exit(EXIT_FAILURE);
}
604 
605 int main(int argc, char **argv)
606 {
607 	const char *control_host = NULL;
608 	const char *control_port = NULL;
609 	int mode = TEST_MODE_UNSET;
610 	unsigned int peer_cid = VMADDR_CID_ANY;
611 	int i;
612 
613 	init_signals();
614 
615 	for (;;) {
616 		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
617 
618 		if (opt == -1)
619 			break;
620 
621 		switch (opt) {
622 		case 'H':
623 			control_host = optarg;
624 			break;
625 		case 'm':
626 			if (strcmp(optarg, "client") == 0)
627 				mode = TEST_MODE_CLIENT;
628 			else if (strcmp(optarg, "server") == 0)
629 				mode = TEST_MODE_SERVER;
630 			else {
631 				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
632 				return EXIT_FAILURE;
633 			}
634 			break;
635 		case 'p':
636 			peer_cid = parse_cid(optarg);
637 			break;
638 		case 'P':
639 			control_port = optarg;
640 			break;
641 		case '?':
642 		default:
643 			usage();
644 		}
645 	}
646 
647 	if (!control_port)
648 		usage();
649 	if (mode == TEST_MODE_UNSET)
650 		usage();
651 	if (peer_cid == VMADDR_CID_ANY)
652 		usage();
653 
654 	if (!control_host) {
655 		if (mode != TEST_MODE_SERVER)
656 			usage();
657 		control_host = "0.0.0.0";
658 	}
659 
660 	control_init(control_host, control_port, mode == TEST_MODE_SERVER);
661 
662 	for (i = 0; test_cases[i].name; i++) {
663 		void (*run)(unsigned int peer_cid);
664 
665 		printf("%s...", test_cases[i].name);
666 		fflush(stdout);
667 
668 		if (mode == TEST_MODE_CLIENT)
669 			run = test_cases[i].run_client;
670 		else
671 			run = test_cases[i].run_server;
672 
673 		if (run)
674 			run(peer_cid);
675 
676 		printf("ok\n");
677 	}
678 
679 	control_cleanup();
680 	return EXIT_SUCCESS;
681 }
682