1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * vsock_diag_test - vsock_diag.ko test suite
4  *
5  * Copyright (C) 2017 Red Hat, Inc.
6  *
7  * Author: Stefan Hajnoczi <stefanha@redhat.com>
8  */
9 
10 #include <getopt.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <string.h>
14 #include <errno.h>
15 #include <unistd.h>
16 #include <sys/stat.h>
17 #include <sys/types.h>
18 #include <linux/list.h>
19 #include <linux/net.h>
20 #include <linux/netlink.h>
21 #include <linux/sock_diag.h>
22 #include <linux/vm_sockets_diag.h>
23 #include <netinet/tcp.h>
24 
25 #include "timeout.h"
26 #include "control.h"
27 #include "util.h"
28 
/*
 * Per-socket status: one node of the snapshot list built by
 * read_vsock_stat(), holding a private copy of the vsock_diag_msg the
 * kernel reported for a single socket.
 */
struct vsock_stat {
	struct list_head list;		/* linkage in the snapshot list */
	struct vsock_diag_msg msg;	/* copied out of the netlink reply */
};
34 
/*
 * Map a socket type reported by vsock_diag to a printable name.
 * Only stream and datagram types are recognised; anything else yields
 * "INVALID TYPE".
 */
static const char *sock_type_str(int type)
{
	if (type == SOCK_STREAM)
		return "STREAM";
	if (type == SOCK_DGRAM)
		return "DGRAM";

	return "INVALID TYPE";
}
46 
47 static const char *sock_state_str(int state)
48 {
49 	switch (state) {
50 	case TCP_CLOSE:
51 		return "UNCONNECTED";
52 	case TCP_SYN_SENT:
53 		return "CONNECTING";
54 	case TCP_ESTABLISHED:
55 		return "CONNECTED";
56 	case TCP_CLOSING:
57 		return "DISCONNECTING";
58 	case TCP_LISTEN:
59 		return "LISTEN";
60 	default:
61 		return "INVALID STATE";
62 	}
63 }
64 
/*
 * Render the vdiag_shutdown bitmask as text.  Bit 0 is the receive side
 * and bit 1 the send side; values outside 0..3 are printed as "0".
 */
static const char *sock_shutdown_str(int shutdown)
{
	static const char *const names[] = {
		"0",
		"RCV_SHUTDOWN",
		"SEND_SHUTDOWN",
		"RCV_SHUTDOWN | SEND_SHUTDOWN",
	};

	if (shutdown < 0 || shutdown >= (int)(sizeof(names) / sizeof(names[0])))
		return "0";

	return names[shutdown];
}
78 
/*
 * Write a vsock address as "<cid>:<port>" to @fp.  The wildcard values
 * VMADDR_CID_ANY and VMADDR_PORT_ANY are each rendered as '*'.
 * No trailing separator or newline is written.
 */
static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
{
	if (cid != VMADDR_CID_ANY)
		fprintf(fp, "%u:", cid);
	else
		fputs("*:", fp);

	if (port != VMADDR_PORT_ANY)
		fprintf(fp, "%u", port);
	else
		fputc('*', fp);
}
91 
92 static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
93 {
94 	print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
95 	fprintf(fp, " ");
96 	print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
97 	fprintf(fp, " %s %s %s %u\n",
98 		sock_type_str(st->msg.vdiag_type),
99 		sock_state_str(st->msg.vdiag_state),
100 		sock_shutdown_str(st->msg.vdiag_shutdown),
101 		st->msg.vdiag_ino);
102 }
103 
104 static void print_vsock_stats(FILE *fp, struct list_head *head)
105 {
106 	struct vsock_stat *st;
107 
108 	list_for_each_entry(st, head, list)
109 		print_vsock_stat(fp, st);
110 }
111 
112 static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
113 {
114 	struct vsock_stat *st;
115 	struct stat stat;
116 
117 	if (fstat(fd, &stat) < 0) {
118 		perror("fstat");
119 		exit(EXIT_FAILURE);
120 	}
121 
122 	list_for_each_entry(st, head, list)
123 		if (st->msg.vdiag_ino == stat.st_ino)
124 			return st;
125 
126 	fprintf(stderr, "cannot find fd %d\n", fd);
127 	exit(EXIT_FAILURE);
128 }
129 
/*
 * Fail the test if the snapshot list @head contains any socket.
 * On failure, the offending entries are dumped to stderr first.
 */
static void check_no_sockets(struct list_head *head)
{
	if (list_empty(head))
		return;

	fprintf(stderr, "expected no sockets\n");
	print_vsock_stats(stderr, head);
	/* Use EXIT_FAILURE like every other check helper in this file. */
	exit(EXIT_FAILURE);
}
138 
139 static void check_num_sockets(struct list_head *head, int expected)
140 {
141 	struct list_head *node;
142 	int n = 0;
143 
144 	list_for_each(node, head)
145 		n++;
146 
147 	if (n != expected) {
148 		fprintf(stderr, "expected %d sockets, found %d\n",
149 			expected, n);
150 		print_vsock_stats(stderr, head);
151 		exit(EXIT_FAILURE);
152 	}
153 }
154 
155 static void check_socket_state(struct vsock_stat *st, __u8 state)
156 {
157 	if (st->msg.vdiag_state != state) {
158 		fprintf(stderr, "expected socket state %#x, got %#x\n",
159 			state, st->msg.vdiag_state);
160 		exit(EXIT_FAILURE);
161 	}
162 }
163 
/*
 * Send a SOCK_DIAG_BY_FAMILY dump request for AF_VSOCK sockets on the
 * NETLINK_SOCK_DIAG socket @fd.  The send is retried on EINTR; any other
 * failure terminates the process.
 */
static void send_req(int fd)
{
	/* nl_pid is left as 0, which addresses the kernel. */
	struct sockaddr_nl kernel_addr = {
		.nl_family = AF_NETLINK,
	};
	struct {
		struct nlmsghdr nlh;
		struct vsock_diag_req vreq;
	} req = {
		.nlh = {
			.nlmsg_len = sizeof(req),
			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
			.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
		},
		.vreq = {
			.sdiag_family = AF_VSOCK,
			/* state filter mask; all bits set so presumably no state is filtered out */
			.vdiag_states = ~(__u32)0,
		},
	};
	struct iovec iov = {
		.iov_base = &req,
		.iov_len = sizeof(req),
	};
	struct msghdr msg = {
		.msg_name = &kernel_addr,
		.msg_namelen = sizeof(kernel_addr),
		.msg_iov = &iov,
		.msg_iovlen = 1,
	};
	ssize_t ret;

	do {
		ret = sendmsg(fd, &msg, 0);
	} while (ret < 0 && errno == EINTR);

	if (ret < 0) {
		perror("sendmsg");
		exit(EXIT_FAILURE);
	}
}
206 
/*
 * Receive one datagram from netlink socket @fd into @buf (at most @len
 * bytes).  Retries on EINTR.  Returns the byte count recvmsg() reported,
 * which may be 0.  Any other error terminates the process.
 */
static ssize_t recv_resp(int fd, void *buf, size_t len)
{
	struct sockaddr_nl peer = {
		.nl_family = AF_NETLINK,
	};
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = len,
	};
	struct msghdr msg = {
		.msg_name = &peer,
		.msg_namelen = sizeof(peer),
		.msg_iov = &iov,
		.msg_iovlen = 1,
	};

	for (;;) {
		ssize_t n = recvmsg(fd, &msg, 0);

		if (n >= 0)
			return n;
		if (errno != EINTR)
			break;
	}

	perror("recvmsg");
	exit(EXIT_FAILURE);
}
235 
236 static void add_vsock_stat(struct list_head *sockets,
237 			   const struct vsock_diag_msg *resp)
238 {
239 	struct vsock_stat *st;
240 
241 	st = malloc(sizeof(*st));
242 	if (!st) {
243 		perror("malloc");
244 		exit(EXIT_FAILURE);
245 	}
246 
247 	st->msg = *resp;
248 	list_add_tail(&st->list, sockets);
249 }
250 
/*
 * Walk the netlink messages contained in one received datagram starting
 * at @h (@len bytes), appending every vsock_diag_msg to @sockets.
 * Returns 1 once NLMSG_DONE is seen, or 0 if the datagram was consumed
 * without reaching it.  Error replies, unexpected message types and
 * truncated payloads terminate the process.
 */
static int parse_vsock_diag_batch(struct list_head *sockets,
				  const struct nlmsghdr *h, ssize_t len)
{
	for (; NLMSG_OK(h, len); h = NLMSG_NEXT(h, len)) {
		if (h->nlmsg_type == NLMSG_DONE)
			return 1;

		if (h->nlmsg_type == NLMSG_ERROR) {
			const struct nlmsgerr *err = NLMSG_DATA(h);

			if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err))) {
				fprintf(stderr, "NLMSG_ERROR\n");
			} else {
				/* The kernel reports a negative errno value. */
				errno = -err->error;
				perror("NLMSG_ERROR");
			}

			exit(EXIT_FAILURE);
		}

		if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
			fprintf(stderr, "unexpected nlmsg_type %#x\n",
				h->nlmsg_type);
			exit(EXIT_FAILURE);
		}

		if (h->nlmsg_len <
		    NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
			fprintf(stderr, "short vsock_diag_msg\n");
			exit(EXIT_FAILURE);
		}

		add_vsock_stat(sockets, NLMSG_DATA(h));
	}

	return 0;
}

/*
 * Read vsock stats into a list.
 *
 * Opens a NETLINK_SOCK_DIAG socket, sends a dump request for AF_VSOCK and
 * appends one entry per reported socket to @sockets.  Reading stops at
 * NLMSG_DONE or when recvmsg() returns 0.  Any error terminates the process.
 */
static void read_vsock_stat(struct list_head *sockets)
{
	/* long elements keep the buffer aligned for struct nlmsghdr access. */
	long buf[8192 / sizeof(long)];
	int fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);

	if (fd < 0) {
		perror("socket");
		exit(EXIT_FAILURE);
	}

	send_req(fd);

	for (;;) {
		/* recv_resp() never returns a negative value. */
		ssize_t len = recv_resp(fd, buf, sizeof(buf));

		if (len == 0)
			break;
		if ((size_t)len < sizeof(struct nlmsghdr)) {
			fprintf(stderr, "short read of %zd bytes\n", len);
			exit(EXIT_FAILURE);
		}
		if (parse_vsock_diag_batch(sockets,
					   (const struct nlmsghdr *)buf, len))
			break;
	}

	close(fd);
}
318 
319 static void free_sock_stat(struct list_head *sockets)
320 {
321 	struct vsock_stat *st;
322 	struct vsock_stat *next;
323 
324 	list_for_each_entry_safe(st, next, sockets, list)
325 		free(st);
326 }
327 
328 static void test_no_sockets(const struct test_opts *opts)
329 {
330 	LIST_HEAD(sockets);
331 
332 	read_vsock_stat(&sockets);
333 
334 	check_no_sockets(&sockets);
335 }
336 
337 static void test_listen_socket_server(const struct test_opts *opts)
338 {
339 	union {
340 		struct sockaddr sa;
341 		struct sockaddr_vm svm;
342 	} addr = {
343 		.svm = {
344 			.svm_family = AF_VSOCK,
345 			.svm_port = 1234,
346 			.svm_cid = VMADDR_CID_ANY,
347 		},
348 	};
349 	LIST_HEAD(sockets);
350 	struct vsock_stat *st;
351 	int fd;
352 
353 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
354 
355 	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
356 		perror("bind");
357 		exit(EXIT_FAILURE);
358 	}
359 
360 	if (listen(fd, 1) < 0) {
361 		perror("listen");
362 		exit(EXIT_FAILURE);
363 	}
364 
365 	read_vsock_stat(&sockets);
366 
367 	check_num_sockets(&sockets, 1);
368 	st = find_vsock_stat(&sockets, fd);
369 	check_socket_state(st, TCP_LISTEN);
370 
371 	close(fd);
372 	free_sock_stat(&sockets);
373 }
374 
375 static void test_connect_client(const struct test_opts *opts)
376 {
377 	int fd;
378 	LIST_HEAD(sockets);
379 	struct vsock_stat *st;
380 
381 	fd = vsock_stream_connect(opts->peer_cid, 1234);
382 	if (fd < 0) {
383 		perror("connect");
384 		exit(EXIT_FAILURE);
385 	}
386 
387 	read_vsock_stat(&sockets);
388 
389 	check_num_sockets(&sockets, 1);
390 	st = find_vsock_stat(&sockets, fd);
391 	check_socket_state(st, TCP_ESTABLISHED);
392 
393 	control_expectln("DONE");
394 	control_writeln("DONE");
395 
396 	close(fd);
397 	free_sock_stat(&sockets);
398 }
399 
400 static void test_connect_server(const struct test_opts *opts)
401 {
402 	struct vsock_stat *st;
403 	LIST_HEAD(sockets);
404 	int client_fd;
405 
406 	client_fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
407 	if (client_fd < 0) {
408 		perror("accept");
409 		exit(EXIT_FAILURE);
410 	}
411 
412 	read_vsock_stat(&sockets);
413 
414 	check_num_sockets(&sockets, 1);
415 	st = find_vsock_stat(&sockets, client_fd);
416 	check_socket_state(st, TCP_ESTABLISHED);
417 
418 	control_writeln("DONE");
419 	control_expectln("DONE");
420 
421 	close(client_fd);
422 	free_sock_stat(&sockets);
423 }
424 
425 static struct test_case test_cases[] = {
426 	{
427 		.name = "No sockets",
428 		.run_server = test_no_sockets,
429 	},
430 	{
431 		.name = "Listen socket",
432 		.run_server = test_listen_socket_server,
433 	},
434 	{
435 		.name = "Connect",
436 		.run_client = test_connect_client,
437 		.run_server = test_connect_server,
438 	},
439 	{},
440 };
441 
/* No short options: everything is a long option dispatched on .val in main(). */
static const char optstring[] = "";

/* Fields in order: name, has_arg, flag, val. */
static const struct option longopts[] = {
	{ "control-host", required_argument, NULL, 'H' },
	{ "control-port", required_argument, NULL, 'P' },
	{ "mode",         required_argument, NULL, 'm' },
	{ "peer-cid",     required_argument, NULL, 'p' },
	{ "list",         no_argument,       NULL, 'l' },
	{ "skip",         required_argument, NULL, 's' },
	{ "help",         no_argument,       NULL, '?' },
	{ NULL,           0,                 NULL, 0 },
};
481 
/* Print the help text to stderr and exit with EXIT_FAILURE.  Does not return. */
static void usage(void)
{
	/* The text contains no conversion specifiers, so fputs() prints it verbatim. */
	fputs("Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n"
	      "\n"
	      "  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
	      "  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
	      "\n"
	      "Run vsock_diag.ko tests.  Must be launched in both\n"
	      "guest and host.  One side must use --mode=client and\n"
	      "the other side must use --mode=server.\n"
	      "\n"
	      "A TCP control socket connection is used to coordinate tests\n"
	      "between the client and the server.  The server requires a\n"
	      "listen address and the client requires an address to\n"
	      "connect to.\n"
	      "\n"
	      "The CID of the other side must be given with --peer-cid=<cid>.\n"
	      "\n"
	      "Options:\n"
	      "  --help                 This help message\n"
	      "  --control-host <host>  Server IP address to connect to\n"
	      "  --control-port <port>  Server port to listen on/connect to\n"
	      "  --mode client|server   Server or client mode\n"
	      "  --peer-cid <cid>       CID of the other side\n"
	      "  --list                 List of tests that will be executed\n"
	      "  --skip <test_id>       Test ID to skip;\n"
	      "                         use multiple --skip options to skip more tests\n",
	      stderr);
	exit(EXIT_FAILURE);
}
512 
513 int main(int argc, char **argv)
514 {
515 	const char *control_host = NULL;
516 	const char *control_port = NULL;
517 	struct test_opts opts = {
518 		.mode = TEST_MODE_UNSET,
519 		.peer_cid = VMADDR_CID_ANY,
520 	};
521 
522 	init_signals();
523 
524 	for (;;) {
525 		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
526 
527 		if (opt == -1)
528 			break;
529 
530 		switch (opt) {
531 		case 'H':
532 			control_host = optarg;
533 			break;
534 		case 'm':
535 			if (strcmp(optarg, "client") == 0)
536 				opts.mode = TEST_MODE_CLIENT;
537 			else if (strcmp(optarg, "server") == 0)
538 				opts.mode = TEST_MODE_SERVER;
539 			else {
540 				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
541 				return EXIT_FAILURE;
542 			}
543 			break;
544 		case 'p':
545 			opts.peer_cid = parse_cid(optarg);
546 			break;
547 		case 'P':
548 			control_port = optarg;
549 			break;
550 		case 'l':
551 			list_tests(test_cases);
552 			break;
553 		case 's':
554 			skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
555 				  optarg);
556 			break;
557 		case '?':
558 		default:
559 			usage();
560 		}
561 	}
562 
563 	if (!control_port)
564 		usage();
565 	if (opts.mode == TEST_MODE_UNSET)
566 		usage();
567 	if (opts.peer_cid == VMADDR_CID_ANY)
568 		usage();
569 
570 	if (!control_host) {
571 		if (opts.mode != TEST_MODE_SERVER)
572 			usage();
573 		control_host = "0.0.0.0";
574 	}
575 
576 	control_init(control_host, control_port,
577 		     opts.mode == TEST_MODE_SERVER);
578 
579 	run_tests(test_cases, &opts);
580 
581 	control_cleanup();
582 	return EXIT_SUCCESS;
583 }
584