10b025033SStefan Hajnoczi /*
20b025033SStefan Hajnoczi  * vsock_diag_test - vsock_diag.ko test suite
30b025033SStefan Hajnoczi  *
40b025033SStefan Hajnoczi  * Copyright (C) 2017 Red Hat, Inc.
50b025033SStefan Hajnoczi  *
60b025033SStefan Hajnoczi  * Author: Stefan Hajnoczi <stefanha@redhat.com>
70b025033SStefan Hajnoczi  *
80b025033SStefan Hajnoczi  * This program is free software; you can redistribute it and/or
90b025033SStefan Hajnoczi  * modify it under the terms of the GNU General Public License
100b025033SStefan Hajnoczi  * as published by the Free Software Foundation; version 2
110b025033SStefan Hajnoczi  * of the License.
120b025033SStefan Hajnoczi  */
130b025033SStefan Hajnoczi 
140b025033SStefan Hajnoczi #include <getopt.h>
150b025033SStefan Hajnoczi #include <stdio.h>
160b025033SStefan Hajnoczi #include <stdbool.h>
170b025033SStefan Hajnoczi #include <stdlib.h>
180b025033SStefan Hajnoczi #include <string.h>
190b025033SStefan Hajnoczi #include <errno.h>
200b025033SStefan Hajnoczi #include <unistd.h>
210b025033SStefan Hajnoczi #include <signal.h>
220b025033SStefan Hajnoczi #include <sys/socket.h>
230b025033SStefan Hajnoczi #include <sys/stat.h>
240b025033SStefan Hajnoczi #include <sys/types.h>
250b025033SStefan Hajnoczi #include <linux/list.h>
260b025033SStefan Hajnoczi #include <linux/net.h>
270b025033SStefan Hajnoczi #include <linux/netlink.h>
280b025033SStefan Hajnoczi #include <linux/sock_diag.h>
290b025033SStefan Hajnoczi #include <netinet/tcp.h>
300b025033SStefan Hajnoczi 
310b025033SStefan Hajnoczi #include "../../../include/uapi/linux/vm_sockets.h"
320b025033SStefan Hajnoczi #include "../../../include/uapi/linux/vm_sockets_diag.h"
330b025033SStefan Hajnoczi 
340b025033SStefan Hajnoczi #include "timeout.h"
350b025033SStefan Hajnoczi #include "control.h"
360b025033SStefan Hajnoczi 
370b025033SStefan Hajnoczi enum test_mode {
380b025033SStefan Hajnoczi 	TEST_MODE_UNSET,
390b025033SStefan Hajnoczi 	TEST_MODE_CLIENT,
400b025033SStefan Hajnoczi 	TEST_MODE_SERVER
410b025033SStefan Hajnoczi };
420b025033SStefan Hajnoczi 
430b025033SStefan Hajnoczi /* Per-socket status */
440b025033SStefan Hajnoczi struct vsock_stat {
450b025033SStefan Hajnoczi 	struct list_head list;
460b025033SStefan Hajnoczi 	struct vsock_diag_msg msg;
470b025033SStefan Hajnoczi };
480b025033SStefan Hajnoczi 
490b025033SStefan Hajnoczi static const char *sock_type_str(int type)
500b025033SStefan Hajnoczi {
510b025033SStefan Hajnoczi 	switch (type) {
520b025033SStefan Hajnoczi 	case SOCK_DGRAM:
530b025033SStefan Hajnoczi 		return "DGRAM";
540b025033SStefan Hajnoczi 	case SOCK_STREAM:
550b025033SStefan Hajnoczi 		return "STREAM";
560b025033SStefan Hajnoczi 	default:
570b025033SStefan Hajnoczi 		return "INVALID TYPE";
580b025033SStefan Hajnoczi 	}
590b025033SStefan Hajnoczi }
600b025033SStefan Hajnoczi 
610b025033SStefan Hajnoczi static const char *sock_state_str(int state)
620b025033SStefan Hajnoczi {
630b025033SStefan Hajnoczi 	switch (state) {
640b025033SStefan Hajnoczi 	case TCP_CLOSE:
650b025033SStefan Hajnoczi 		return "UNCONNECTED";
660b025033SStefan Hajnoczi 	case TCP_SYN_SENT:
670b025033SStefan Hajnoczi 		return "CONNECTING";
680b025033SStefan Hajnoczi 	case TCP_ESTABLISHED:
690b025033SStefan Hajnoczi 		return "CONNECTED";
700b025033SStefan Hajnoczi 	case TCP_CLOSING:
710b025033SStefan Hajnoczi 		return "DISCONNECTING";
720b025033SStefan Hajnoczi 	case TCP_LISTEN:
730b025033SStefan Hajnoczi 		return "LISTEN";
740b025033SStefan Hajnoczi 	default:
750b025033SStefan Hajnoczi 		return "INVALID STATE";
760b025033SStefan Hajnoczi 	}
770b025033SStefan Hajnoczi }
780b025033SStefan Hajnoczi 
790b025033SStefan Hajnoczi static const char *sock_shutdown_str(int shutdown)
800b025033SStefan Hajnoczi {
810b025033SStefan Hajnoczi 	switch (shutdown) {
820b025033SStefan Hajnoczi 	case 1:
830b025033SStefan Hajnoczi 		return "RCV_SHUTDOWN";
840b025033SStefan Hajnoczi 	case 2:
850b025033SStefan Hajnoczi 		return "SEND_SHUTDOWN";
860b025033SStefan Hajnoczi 	case 3:
870b025033SStefan Hajnoczi 		return "RCV_SHUTDOWN | SEND_SHUTDOWN";
880b025033SStefan Hajnoczi 	default:
890b025033SStefan Hajnoczi 		return "0";
900b025033SStefan Hajnoczi 	}
910b025033SStefan Hajnoczi }
920b025033SStefan Hajnoczi 
930b025033SStefan Hajnoczi static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port)
940b025033SStefan Hajnoczi {
950b025033SStefan Hajnoczi 	if (cid == VMADDR_CID_ANY)
960b025033SStefan Hajnoczi 		fprintf(fp, "*:");
970b025033SStefan Hajnoczi 	else
980b025033SStefan Hajnoczi 		fprintf(fp, "%u:", cid);
990b025033SStefan Hajnoczi 
1000b025033SStefan Hajnoczi 	if (port == VMADDR_PORT_ANY)
1010b025033SStefan Hajnoczi 		fprintf(fp, "*");
1020b025033SStefan Hajnoczi 	else
1030b025033SStefan Hajnoczi 		fprintf(fp, "%u", port);
1040b025033SStefan Hajnoczi }
1050b025033SStefan Hajnoczi 
1060b025033SStefan Hajnoczi static void print_vsock_stat(FILE *fp, struct vsock_stat *st)
1070b025033SStefan Hajnoczi {
1080b025033SStefan Hajnoczi 	print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port);
1090b025033SStefan Hajnoczi 	fprintf(fp, " ");
1100b025033SStefan Hajnoczi 	print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port);
1110b025033SStefan Hajnoczi 	fprintf(fp, " %s %s %s %u\n",
1120b025033SStefan Hajnoczi 		sock_type_str(st->msg.vdiag_type),
1130b025033SStefan Hajnoczi 		sock_state_str(st->msg.vdiag_state),
1140b025033SStefan Hajnoczi 		sock_shutdown_str(st->msg.vdiag_shutdown),
1150b025033SStefan Hajnoczi 		st->msg.vdiag_ino);
1160b025033SStefan Hajnoczi }
1170b025033SStefan Hajnoczi 
1180b025033SStefan Hajnoczi static void print_vsock_stats(FILE *fp, struct list_head *head)
1190b025033SStefan Hajnoczi {
1200b025033SStefan Hajnoczi 	struct vsock_stat *st;
1210b025033SStefan Hajnoczi 
1220b025033SStefan Hajnoczi 	list_for_each_entry(st, head, list)
1230b025033SStefan Hajnoczi 		print_vsock_stat(fp, st);
1240b025033SStefan Hajnoczi }
1250b025033SStefan Hajnoczi 
1260b025033SStefan Hajnoczi static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd)
1270b025033SStefan Hajnoczi {
1280b025033SStefan Hajnoczi 	struct vsock_stat *st;
1290b025033SStefan Hajnoczi 	struct stat stat;
1300b025033SStefan Hajnoczi 
1310b025033SStefan Hajnoczi 	if (fstat(fd, &stat) < 0) {
1320b025033SStefan Hajnoczi 		perror("fstat");
1330b025033SStefan Hajnoczi 		exit(EXIT_FAILURE);
1340b025033SStefan Hajnoczi 	}
1350b025033SStefan Hajnoczi 
1360b025033SStefan Hajnoczi 	list_for_each_entry(st, head, list)
1370b025033SStefan Hajnoczi 		if (st->msg.vdiag_ino == stat.st_ino)
1380b025033SStefan Hajnoczi 			return st;
1390b025033SStefan Hajnoczi 
1400b025033SStefan Hajnoczi 	fprintf(stderr, "cannot find fd %d\n", fd);
1410b025033SStefan Hajnoczi 	exit(EXIT_FAILURE);
1420b025033SStefan Hajnoczi }
1430b025033SStefan Hajnoczi 
1440b025033SStefan Hajnoczi static void check_no_sockets(struct list_head *head)
1450b025033SStefan Hajnoczi {
1460b025033SStefan Hajnoczi 	if (!list_empty(head)) {
1470b025033SStefan Hajnoczi 		fprintf(stderr, "expected no sockets\n");
1480b025033SStefan Hajnoczi 		print_vsock_stats(stderr, head);
1490b025033SStefan Hajnoczi 		exit(1);
1500b025033SStefan Hajnoczi 	}
1510b025033SStefan Hajnoczi }
1520b025033SStefan Hajnoczi 
1530b025033SStefan Hajnoczi static void check_num_sockets(struct list_head *head, int expected)
1540b025033SStefan Hajnoczi {
1550b025033SStefan Hajnoczi 	struct list_head *node;
1560b025033SStefan Hajnoczi 	int n = 0;
1570b025033SStefan Hajnoczi 
1580b025033SStefan Hajnoczi 	list_for_each(node, head)
1590b025033SStefan Hajnoczi 		n++;
1600b025033SStefan Hajnoczi 
1610b025033SStefan Hajnoczi 	if (n != expected) {
1620b025033SStefan Hajnoczi 		fprintf(stderr, "expected %d sockets, found %d\n",
1630b025033SStefan Hajnoczi 			expected, n);
1640b025033SStefan Hajnoczi 		print_vsock_stats(stderr, head);
1650b025033SStefan Hajnoczi 		exit(EXIT_FAILURE);
1660b025033SStefan Hajnoczi 	}
1670b025033SStefan Hajnoczi }
1680b025033SStefan Hajnoczi 
1690b025033SStefan Hajnoczi static void check_socket_state(struct vsock_stat *st, __u8 state)
1700b025033SStefan Hajnoczi {
1710b025033SStefan Hajnoczi 	if (st->msg.vdiag_state != state) {
1720b025033SStefan Hajnoczi 		fprintf(stderr, "expected socket state %#x, got %#x\n",
1730b025033SStefan Hajnoczi 			state, st->msg.vdiag_state);
1740b025033SStefan Hajnoczi 		exit(EXIT_FAILURE);
1750b025033SStefan Hajnoczi 	}
1760b025033SStefan Hajnoczi }
1770b025033SStefan Hajnoczi 
1780b025033SStefan Hajnoczi static void send_req(int fd)
1790b025033SStefan Hajnoczi {
1800b025033SStefan Hajnoczi 	struct sockaddr_nl nladdr = {
1810b025033SStefan Hajnoczi 		.nl_family = AF_NETLINK,
1820b025033SStefan Hajnoczi 	};
1830b025033SStefan Hajnoczi 	struct {
1840b025033SStefan Hajnoczi 		struct nlmsghdr nlh;
1850b025033SStefan Hajnoczi 		struct vsock_diag_req vreq;
1860b025033SStefan Hajnoczi 	} req = {
1870b025033SStefan Hajnoczi 		.nlh = {
1880b025033SStefan Hajnoczi 			.nlmsg_len = sizeof(req),
1890b025033SStefan Hajnoczi 			.nlmsg_type = SOCK_DIAG_BY_FAMILY,
1900b025033SStefan Hajnoczi 			.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP,
1910b025033SStefan Hajnoczi 		},
1920b025033SStefan Hajnoczi 		.vreq = {
1930b025033SStefan Hajnoczi 			.sdiag_family = AF_VSOCK,
1940b025033SStefan Hajnoczi 			.vdiag_states = ~(__u32)0,
1950b025033SStefan Hajnoczi 		},
1960b025033SStefan Hajnoczi 	};
1970b025033SStefan Hajnoczi 	struct iovec iov = {
1980b025033SStefan Hajnoczi 		.iov_base = &req,
1990b025033SStefan Hajnoczi 		.iov_len = sizeof(req),
2000b025033SStefan Hajnoczi 	};
2010b025033SStefan Hajnoczi 	struct msghdr msg = {
2020b025033SStefan Hajnoczi 		.msg_name = &nladdr,
2030b025033SStefan Hajnoczi 		.msg_namelen = sizeof(nladdr),
2040b025033SStefan Hajnoczi 		.msg_iov = &iov,
2050b025033SStefan Hajnoczi 		.msg_iovlen = 1,
2060b025033SStefan Hajnoczi 	};
2070b025033SStefan Hajnoczi 
2080b025033SStefan Hajnoczi 	for (;;) {
2090b025033SStefan Hajnoczi 		if (sendmsg(fd, &msg, 0) < 0) {
2100b025033SStefan Hajnoczi 			if (errno == EINTR)
2110b025033SStefan Hajnoczi 				continue;
2120b025033SStefan Hajnoczi 
2130b025033SStefan Hajnoczi 			perror("sendmsg");
2140b025033SStefan Hajnoczi 			exit(EXIT_FAILURE);
2150b025033SStefan Hajnoczi 		}
2160b025033SStefan Hajnoczi 
2170b025033SStefan Hajnoczi 		return;
2180b025033SStefan Hajnoczi 	}
2190b025033SStefan Hajnoczi }
2200b025033SStefan Hajnoczi 
2210b025033SStefan Hajnoczi static ssize_t recv_resp(int fd, void *buf, size_t len)
2220b025033SStefan Hajnoczi {
2230b025033SStefan Hajnoczi 	struct sockaddr_nl nladdr = {
2240b025033SStefan Hajnoczi 		.nl_family = AF_NETLINK,
2250b025033SStefan Hajnoczi 	};
2260b025033SStefan Hajnoczi 	struct iovec iov = {
2270b025033SStefan Hajnoczi 		.iov_base = buf,
2280b025033SStefan Hajnoczi 		.iov_len = len,
2290b025033SStefan Hajnoczi 	};
2300b025033SStefan Hajnoczi 	struct msghdr msg = {
2310b025033SStefan Hajnoczi 		.msg_name = &nladdr,
2320b025033SStefan Hajnoczi 		.msg_namelen = sizeof(nladdr),
2330b025033SStefan Hajnoczi 		.msg_iov = &iov,
2340b025033SStefan Hajnoczi 		.msg_iovlen = 1,
2350b025033SStefan Hajnoczi 	};
2360b025033SStefan Hajnoczi 	ssize_t ret;
2370b025033SStefan Hajnoczi 
2380b025033SStefan Hajnoczi 	do {
2390b025033SStefan Hajnoczi 		ret = recvmsg(fd, &msg, 0);
2400b025033SStefan Hajnoczi 	} while (ret < 0 && errno == EINTR);
2410b025033SStefan Hajnoczi 
2420b025033SStefan Hajnoczi 	if (ret < 0) {
2430b025033SStefan Hajnoczi 		perror("recvmsg");
2440b025033SStefan Hajnoczi 		exit(EXIT_FAILURE);
2450b025033SStefan Hajnoczi 	}
2460b025033SStefan Hajnoczi 
2470b025033SStefan Hajnoczi 	return ret;
2480b025033SStefan Hajnoczi }
2490b025033SStefan Hajnoczi 
2500b025033SStefan Hajnoczi static void add_vsock_stat(struct list_head *sockets,
2510b025033SStefan Hajnoczi 			   const struct vsock_diag_msg *resp)
2520b025033SStefan Hajnoczi {
2530b025033SStefan Hajnoczi 	struct vsock_stat *st;
2540b025033SStefan Hajnoczi 
2550b025033SStefan Hajnoczi 	st = malloc(sizeof(*st));
2560b025033SStefan Hajnoczi 	if (!st) {
2570b025033SStefan Hajnoczi 		perror("malloc");
2580b025033SStefan Hajnoczi 		exit(EXIT_FAILURE);
2590b025033SStefan Hajnoczi 	}
2600b025033SStefan Hajnoczi 
2610b025033SStefan Hajnoczi 	st->msg = *resp;
2620b025033SStefan Hajnoczi 	list_add_tail(&st->list, sockets);
2630b025033SStefan Hajnoczi }
2640b025033SStefan Hajnoczi 
2650b025033SStefan Hajnoczi /*
2660b025033SStefan Hajnoczi  * Read vsock stats into a list.
2670b025033SStefan Hajnoczi  */
2680b025033SStefan Hajnoczi static void read_vsock_stat(struct list_head *sockets)
2690b025033SStefan Hajnoczi {
2700b025033SStefan Hajnoczi 	long buf[8192 / sizeof(long)];
2710b025033SStefan Hajnoczi 	int fd;
2720b025033SStefan Hajnoczi 
2730b025033SStefan Hajnoczi 	fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG);
2740b025033SStefan Hajnoczi 	if (fd < 0) {
2750b025033SStefan Hajnoczi 		perror("socket");
2760b025033SStefan Hajnoczi 		exit(EXIT_FAILURE);
2770b025033SStefan Hajnoczi 	}
2780b025033SStefan Hajnoczi 
2790b025033SStefan Hajnoczi 	send_req(fd);
2800b025033SStefan Hajnoczi 
2810b025033SStefan Hajnoczi 	for (;;) {
2820b025033SStefan Hajnoczi 		const struct nlmsghdr *h;
2830b025033SStefan Hajnoczi 		ssize_t ret;
2840b025033SStefan Hajnoczi 
2850b025033SStefan Hajnoczi 		ret = recv_resp(fd, buf, sizeof(buf));
2860b025033SStefan Hajnoczi 		if (ret == 0)
2870b025033SStefan Hajnoczi 			goto done;
2880b025033SStefan Hajnoczi 		if (ret < sizeof(*h)) {
2890b025033SStefan Hajnoczi 			fprintf(stderr, "short read of %zd bytes\n", ret);
2900b025033SStefan Hajnoczi 			exit(EXIT_FAILURE);
2910b025033SStefan Hajnoczi 		}
2920b025033SStefan Hajnoczi 
2930b025033SStefan Hajnoczi 		h = (struct nlmsghdr *)buf;
2940b025033SStefan Hajnoczi 
2950b025033SStefan Hajnoczi 		while (NLMSG_OK(h, ret)) {
2960b025033SStefan Hajnoczi 			if (h->nlmsg_type == NLMSG_DONE)
2970b025033SStefan Hajnoczi 				goto done;
2980b025033SStefan Hajnoczi 
2990b025033SStefan Hajnoczi 			if (h->nlmsg_type == NLMSG_ERROR) {
3000b025033SStefan Hajnoczi 				const struct nlmsgerr *err = NLMSG_DATA(h);
3010b025033SStefan Hajnoczi 
3020b025033SStefan Hajnoczi 				if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err)))
3030b025033SStefan Hajnoczi 					fprintf(stderr, "NLMSG_ERROR\n");
3040b025033SStefan Hajnoczi 				else {
3050b025033SStefan Hajnoczi 					errno = -err->error;
3060b025033SStefan Hajnoczi 					perror("NLMSG_ERROR");
3070b025033SStefan Hajnoczi 				}
3080b025033SStefan Hajnoczi 
3090b025033SStefan Hajnoczi 				exit(EXIT_FAILURE);
3100b025033SStefan Hajnoczi 			}
3110b025033SStefan Hajnoczi 
3120b025033SStefan Hajnoczi 			if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) {
3130b025033SStefan Hajnoczi 				fprintf(stderr, "unexpected nlmsg_type %#x\n",
3140b025033SStefan Hajnoczi 					h->nlmsg_type);
3150b025033SStefan Hajnoczi 				exit(EXIT_FAILURE);
3160b025033SStefan Hajnoczi 			}
3170b025033SStefan Hajnoczi 			if (h->nlmsg_len <
3180b025033SStefan Hajnoczi 			    NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) {
3190b025033SStefan Hajnoczi 				fprintf(stderr, "short vsock_diag_msg\n");
3200b025033SStefan Hajnoczi 				exit(EXIT_FAILURE);
3210b025033SStefan Hajnoczi 			}
3220b025033SStefan Hajnoczi 
3230b025033SStefan Hajnoczi 			add_vsock_stat(sockets, NLMSG_DATA(h));
3240b025033SStefan Hajnoczi 
3250b025033SStefan Hajnoczi 			h = NLMSG_NEXT(h, ret);
3260b025033SStefan Hajnoczi 		}
3270b025033SStefan Hajnoczi 	}
3280b025033SStefan Hajnoczi 
3290b025033SStefan Hajnoczi done:
3300b025033SStefan Hajnoczi 	close(fd);
3310b025033SStefan Hajnoczi }
3320b025033SStefan Hajnoczi 
3330b025033SStefan Hajnoczi static void free_sock_stat(struct list_head *sockets)
3340b025033SStefan Hajnoczi {
3350b025033SStefan Hajnoczi 	struct vsock_stat *st;
3360b025033SStefan Hajnoczi 	struct vsock_stat *next;
3370b025033SStefan Hajnoczi 
3380b025033SStefan Hajnoczi 	list_for_each_entry_safe(st, next, sockets, list)
3390b025033SStefan Hajnoczi 		free(st);
3400b025033SStefan Hajnoczi }
3410b025033SStefan Hajnoczi 
3420b025033SStefan Hajnoczi static void test_no_sockets(unsigned int peer_cid)
3430b025033SStefan Hajnoczi {
3440b025033SStefan Hajnoczi 	LIST_HEAD(sockets);
3450b025033SStefan Hajnoczi 
3460b025033SStefan Hajnoczi 	read_vsock_stat(&sockets);
3470b025033SStefan Hajnoczi 
3480b025033SStefan Hajnoczi 	check_no_sockets(&sockets);
3490b025033SStefan Hajnoczi 
3500b025033SStefan Hajnoczi 	free_sock_stat(&sockets);
3510b025033SStefan Hajnoczi }
3520b025033SStefan Hajnoczi 
3530b025033SStefan Hajnoczi static void test_listen_socket_server(unsigned int peer_cid)
3540b025033SStefan Hajnoczi {
3550b025033SStefan Hajnoczi 	union {
3560b025033SStefan Hajnoczi 		struct sockaddr sa;
3570b025033SStefan Hajnoczi 		struct sockaddr_vm svm;
3580b025033SStefan Hajnoczi 	} addr = {
3590b025033SStefan Hajnoczi 		.svm = {
3600b025033SStefan Hajnoczi 			.svm_family = AF_VSOCK,
3610b025033SStefan Hajnoczi 			.svm_port = 1234,
3620b025033SStefan Hajnoczi 			.svm_cid = VMADDR_CID_ANY,
3630b025033SStefan Hajnoczi 		},
3640b025033SStefan Hajnoczi 	};
3650b025033SStefan Hajnoczi 	LIST_HEAD(sockets);
3660b025033SStefan Hajnoczi 	struct vsock_stat *st;
3670b025033SStefan Hajnoczi 	int fd;
3680b025033SStefan Hajnoczi 
3690b025033SStefan Hajnoczi 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
3700b025033SStefan Hajnoczi 
3710b025033SStefan Hajnoczi 	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
3720b025033SStefan Hajnoczi 		perror("bind");
3730b025033SStefan Hajnoczi 		exit(EXIT_FAILURE);
3740b025033SStefan Hajnoczi 	}
3750b025033SStefan Hajnoczi 
3760b025033SStefan Hajnoczi 	if (listen(fd, 1) < 0) {
3770b025033SStefan Hajnoczi 		perror("listen");
3780b025033SStefan Hajnoczi 		exit(EXIT_FAILURE);
3790b025033SStefan Hajnoczi 	}
3800b025033SStefan Hajnoczi 
3810b025033SStefan Hajnoczi 	read_vsock_stat(&sockets);
3820b025033SStefan Hajnoczi 
3830b025033SStefan Hajnoczi 	check_num_sockets(&sockets, 1);
3840b025033SStefan Hajnoczi 	st = find_vsock_stat(&sockets, fd);
3850b025033SStefan Hajnoczi 	check_socket_state(st, TCP_LISTEN);
3860b025033SStefan Hajnoczi 
3870b025033SStefan Hajnoczi 	close(fd);
3880b025033SStefan Hajnoczi 	free_sock_stat(&sockets);
3890b025033SStefan Hajnoczi }
3900b025033SStefan Hajnoczi 
3910b025033SStefan Hajnoczi static void test_connect_client(unsigned int peer_cid)
3920b025033SStefan Hajnoczi {
3930b025033SStefan Hajnoczi 	union {
3940b025033SStefan Hajnoczi 		struct sockaddr sa;
3950b025033SStefan Hajnoczi 		struct sockaddr_vm svm;
3960b025033SStefan Hajnoczi 	} addr = {
3970b025033SStefan Hajnoczi 		.svm = {
3980b025033SStefan Hajnoczi 			.svm_family = AF_VSOCK,
3990b025033SStefan Hajnoczi 			.svm_port = 1234,
4000b025033SStefan Hajnoczi 			.svm_cid = peer_cid,
4010b025033SStefan Hajnoczi 		},
4020b025033SStefan Hajnoczi 	};
4030b025033SStefan Hajnoczi 	int fd;
4040b025033SStefan Hajnoczi 	int ret;
4050b025033SStefan Hajnoczi 	LIST_HEAD(sockets);
4060b025033SStefan Hajnoczi 	struct vsock_stat *st;
4070b025033SStefan Hajnoczi 
4080b025033SStefan Hajnoczi 	control_expectln("LISTENING");
4090b025033SStefan Hajnoczi 
4100b025033SStefan Hajnoczi 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
4110b025033SStefan Hajnoczi 
4120b025033SStefan Hajnoczi 	timeout_begin(TIMEOUT);
4130b025033SStefan Hajnoczi 	do {
4140b025033SStefan Hajnoczi 		ret = connect(fd, &addr.sa, sizeof(addr.svm));
4150b025033SStefan Hajnoczi 		timeout_check("connect");
4160b025033SStefan Hajnoczi 	} while (ret < 0 && errno == EINTR);
4170b025033SStefan Hajnoczi 	timeout_end();
4180b025033SStefan Hajnoczi 
4190b025033SStefan Hajnoczi 	if (ret < 0) {
4200b025033SStefan Hajnoczi 		perror("connect");
4210b025033SStefan Hajnoczi 		exit(EXIT_FAILURE);
4220b025033SStefan Hajnoczi 	}
4230b025033SStefan Hajnoczi 
4240b025033SStefan Hajnoczi 	read_vsock_stat(&sockets);
4250b025033SStefan Hajnoczi 
4260b025033SStefan Hajnoczi 	check_num_sockets(&sockets, 1);
4270b025033SStefan Hajnoczi 	st = find_vsock_stat(&sockets, fd);
4280b025033SStefan Hajnoczi 	check_socket_state(st, TCP_ESTABLISHED);
4290b025033SStefan Hajnoczi 
4300b025033SStefan Hajnoczi 	control_expectln("DONE");
4310b025033SStefan Hajnoczi 	control_writeln("DONE");
4320b025033SStefan Hajnoczi 
4330b025033SStefan Hajnoczi 	close(fd);
4340b025033SStefan Hajnoczi 	free_sock_stat(&sockets);
4350b025033SStefan Hajnoczi }
4360b025033SStefan Hajnoczi 
4370b025033SStefan Hajnoczi static void test_connect_server(unsigned int peer_cid)
4380b025033SStefan Hajnoczi {
4390b025033SStefan Hajnoczi 	union {
4400b025033SStefan Hajnoczi 		struct sockaddr sa;
4410b025033SStefan Hajnoczi 		struct sockaddr_vm svm;
4420b025033SStefan Hajnoczi 	} addr = {
4430b025033SStefan Hajnoczi 		.svm = {
4440b025033SStefan Hajnoczi 			.svm_family = AF_VSOCK,
4450b025033SStefan Hajnoczi 			.svm_port = 1234,
4460b025033SStefan Hajnoczi 			.svm_cid = VMADDR_CID_ANY,
4470b025033SStefan Hajnoczi 		},
4480b025033SStefan Hajnoczi 	};
4490b025033SStefan Hajnoczi 	union {
4500b025033SStefan Hajnoczi 		struct sockaddr sa;
4510b025033SStefan Hajnoczi 		struct sockaddr_vm svm;
4520b025033SStefan Hajnoczi 	} clientaddr;
4530b025033SStefan Hajnoczi 	socklen_t clientaddr_len = sizeof(clientaddr.svm);
4540b025033SStefan Hajnoczi 	LIST_HEAD(sockets);
4550b025033SStefan Hajnoczi 	struct vsock_stat *st;
4560b025033SStefan Hajnoczi 	int fd;
4570b025033SStefan Hajnoczi 	int client_fd;
4580b025033SStefan Hajnoczi 
4590b025033SStefan Hajnoczi 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
4600b025033SStefan Hajnoczi 
4610b025033SStefan Hajnoczi 	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
4620b025033SStefan Hajnoczi 		perror("bind");
4630b025033SStefan Hajnoczi 		exit(EXIT_FAILURE);
4640b025033SStefan Hajnoczi 	}
4650b025033SStefan Hajnoczi 
4660b025033SStefan Hajnoczi 	if (listen(fd, 1) < 0) {
4670b025033SStefan Hajnoczi 		perror("listen");
4680b025033SStefan Hajnoczi 		exit(EXIT_FAILURE);
4690b025033SStefan Hajnoczi 	}
4700b025033SStefan Hajnoczi 
4710b025033SStefan Hajnoczi 	control_writeln("LISTENING");
4720b025033SStefan Hajnoczi 
4730b025033SStefan Hajnoczi 	timeout_begin(TIMEOUT);
4740b025033SStefan Hajnoczi 	do {
4750b025033SStefan Hajnoczi 		client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
4760b025033SStefan Hajnoczi 		timeout_check("accept");
4770b025033SStefan Hajnoczi 	} while (client_fd < 0 && errno == EINTR);
4780b025033SStefan Hajnoczi 	timeout_end();
4790b025033SStefan Hajnoczi 
4800b025033SStefan Hajnoczi 	if (client_fd < 0) {
4810b025033SStefan Hajnoczi 		perror("accept");
4820b025033SStefan Hajnoczi 		exit(EXIT_FAILURE);
4830b025033SStefan Hajnoczi 	}
4840b025033SStefan Hajnoczi 	if (clientaddr.sa.sa_family != AF_VSOCK) {
4850b025033SStefan Hajnoczi 		fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
4860b025033SStefan Hajnoczi 			clientaddr.sa.sa_family);
4870b025033SStefan Hajnoczi 		exit(EXIT_FAILURE);
4880b025033SStefan Hajnoczi 	}
4890b025033SStefan Hajnoczi 	if (clientaddr.svm.svm_cid != peer_cid) {
4900b025033SStefan Hajnoczi 		fprintf(stderr, "expected peer CID %u from accept(2), got %u\n",
4910b025033SStefan Hajnoczi 			peer_cid, clientaddr.svm.svm_cid);
4920b025033SStefan Hajnoczi 		exit(EXIT_FAILURE);
4930b025033SStefan Hajnoczi 	}
4940b025033SStefan Hajnoczi 
4950b025033SStefan Hajnoczi 	read_vsock_stat(&sockets);
4960b025033SStefan Hajnoczi 
4970b025033SStefan Hajnoczi 	check_num_sockets(&sockets, 2);
4980b025033SStefan Hajnoczi 	find_vsock_stat(&sockets, fd);
4990b025033SStefan Hajnoczi 	st = find_vsock_stat(&sockets, client_fd);
5000b025033SStefan Hajnoczi 	check_socket_state(st, TCP_ESTABLISHED);
5010b025033SStefan Hajnoczi 
5020b025033SStefan Hajnoczi 	control_writeln("DONE");
5030b025033SStefan Hajnoczi 	control_expectln("DONE");
5040b025033SStefan Hajnoczi 
5050b025033SStefan Hajnoczi 	close(client_fd);
5060b025033SStefan Hajnoczi 	close(fd);
5070b025033SStefan Hajnoczi 	free_sock_stat(&sockets);
5080b025033SStefan Hajnoczi }
5090b025033SStefan Hajnoczi 
5100b025033SStefan Hajnoczi static struct {
5110b025033SStefan Hajnoczi 	const char *name;
5120b025033SStefan Hajnoczi 	void (*run_client)(unsigned int peer_cid);
5130b025033SStefan Hajnoczi 	void (*run_server)(unsigned int peer_cid);
5140b025033SStefan Hajnoczi } test_cases[] = {
5150b025033SStefan Hajnoczi 	{
5160b025033SStefan Hajnoczi 		.name = "No sockets",
5170b025033SStefan Hajnoczi 		.run_server = test_no_sockets,
5180b025033SStefan Hajnoczi 	},
5190b025033SStefan Hajnoczi 	{
5200b025033SStefan Hajnoczi 		.name = "Listen socket",
5210b025033SStefan Hajnoczi 		.run_server = test_listen_socket_server,
5220b025033SStefan Hajnoczi 	},
5230b025033SStefan Hajnoczi 	{
5240b025033SStefan Hajnoczi 		.name = "Connect",
5250b025033SStefan Hajnoczi 		.run_client = test_connect_client,
5260b025033SStefan Hajnoczi 		.run_server = test_connect_server,
5270b025033SStefan Hajnoczi 	},
5280b025033SStefan Hajnoczi 	{},
5290b025033SStefan Hajnoczi };
5300b025033SStefan Hajnoczi 
5310b025033SStefan Hajnoczi static void init_signals(void)
5320b025033SStefan Hajnoczi {
5330b025033SStefan Hajnoczi 	struct sigaction act = {
5340b025033SStefan Hajnoczi 		.sa_handler = sigalrm,
5350b025033SStefan Hajnoczi 	};
5360b025033SStefan Hajnoczi 
5370b025033SStefan Hajnoczi 	sigaction(SIGALRM, &act, NULL);
5380b025033SStefan Hajnoczi 	signal(SIGPIPE, SIG_IGN);
5390b025033SStefan Hajnoczi }
5400b025033SStefan Hajnoczi 
5410b025033SStefan Hajnoczi static unsigned int parse_cid(const char *str)
5420b025033SStefan Hajnoczi {
5430b025033SStefan Hajnoczi 	char *endptr = NULL;
5440b025033SStefan Hajnoczi 	unsigned long int n;
5450b025033SStefan Hajnoczi 
5460b025033SStefan Hajnoczi 	errno = 0;
5470b025033SStefan Hajnoczi 	n = strtoul(str, &endptr, 10);
5480b025033SStefan Hajnoczi 	if (errno || *endptr != '\0') {
5490b025033SStefan Hajnoczi 		fprintf(stderr, "malformed CID \"%s\"\n", str);
5500b025033SStefan Hajnoczi 		exit(EXIT_FAILURE);
5510b025033SStefan Hajnoczi 	}
5520b025033SStefan Hajnoczi 	return n;
5530b025033SStefan Hajnoczi }
5540b025033SStefan Hajnoczi 
5550b025033SStefan Hajnoczi static const char optstring[] = "";
5560b025033SStefan Hajnoczi static const struct option longopts[] = {
5570b025033SStefan Hajnoczi 	{
5580b025033SStefan Hajnoczi 		.name = "control-host",
5590b025033SStefan Hajnoczi 		.has_arg = required_argument,
5600b025033SStefan Hajnoczi 		.val = 'H',
5610b025033SStefan Hajnoczi 	},
5620b025033SStefan Hajnoczi 	{
5630b025033SStefan Hajnoczi 		.name = "control-port",
5640b025033SStefan Hajnoczi 		.has_arg = required_argument,
5650b025033SStefan Hajnoczi 		.val = 'P',
5660b025033SStefan Hajnoczi 	},
5670b025033SStefan Hajnoczi 	{
5680b025033SStefan Hajnoczi 		.name = "mode",
5690b025033SStefan Hajnoczi 		.has_arg = required_argument,
5700b025033SStefan Hajnoczi 		.val = 'm',
5710b025033SStefan Hajnoczi 	},
5720b025033SStefan Hajnoczi 	{
5730b025033SStefan Hajnoczi 		.name = "peer-cid",
5740b025033SStefan Hajnoczi 		.has_arg = required_argument,
5750b025033SStefan Hajnoczi 		.val = 'p',
5760b025033SStefan Hajnoczi 	},
5770b025033SStefan Hajnoczi 	{
5780b025033SStefan Hajnoczi 		.name = "help",
5790b025033SStefan Hajnoczi 		.has_arg = no_argument,
5800b025033SStefan Hajnoczi 		.val = '?',
5810b025033SStefan Hajnoczi 	},
5820b025033SStefan Hajnoczi 	{},
5830b025033SStefan Hajnoczi };
5840b025033SStefan Hajnoczi 
5850b025033SStefan Hajnoczi static void usage(void)
5860b025033SStefan Hajnoczi {
5870b025033SStefan Hajnoczi 	fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid>\n"
5880b025033SStefan Hajnoczi 		"\n"
5890b025033SStefan Hajnoczi 		"  Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n"
5900b025033SStefan Hajnoczi 		"  Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
5910b025033SStefan Hajnoczi 		"\n"
5920b025033SStefan Hajnoczi 		"Run vsock_diag.ko tests.  Must be launched in both\n"
5930b025033SStefan Hajnoczi 		"guest and host.  One side must use --mode=client and\n"
5940b025033SStefan Hajnoczi 		"the other side must use --mode=server.\n"
5950b025033SStefan Hajnoczi 		"\n"
5960b025033SStefan Hajnoczi 		"A TCP control socket connection is used to coordinate tests\n"
5970b025033SStefan Hajnoczi 		"between the client and the server.  The server requires a\n"
5980b025033SStefan Hajnoczi 		"listen address and the client requires an address to\n"
5990b025033SStefan Hajnoczi 		"connect to.\n"
6000b025033SStefan Hajnoczi 		"\n"
6010b025033SStefan Hajnoczi 		"The CID of the other side must be given with --peer-cid=<cid>.\n");
6020b025033SStefan Hajnoczi 	exit(EXIT_FAILURE);
6030b025033SStefan Hajnoczi }
6040b025033SStefan Hajnoczi 
6050b025033SStefan Hajnoczi int main(int argc, char **argv)
6060b025033SStefan Hajnoczi {
6070b025033SStefan Hajnoczi 	const char *control_host = NULL;
6080b025033SStefan Hajnoczi 	const char *control_port = NULL;
6090b025033SStefan Hajnoczi 	int mode = TEST_MODE_UNSET;
6100b025033SStefan Hajnoczi 	unsigned int peer_cid = VMADDR_CID_ANY;
6110b025033SStefan Hajnoczi 	int i;
6120b025033SStefan Hajnoczi 
6130b025033SStefan Hajnoczi 	init_signals();
6140b025033SStefan Hajnoczi 
6150b025033SStefan Hajnoczi 	for (;;) {
6160b025033SStefan Hajnoczi 		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
6170b025033SStefan Hajnoczi 
6180b025033SStefan Hajnoczi 		if (opt == -1)
6190b025033SStefan Hajnoczi 			break;
6200b025033SStefan Hajnoczi 
6210b025033SStefan Hajnoczi 		switch (opt) {
6220b025033SStefan Hajnoczi 		case 'H':
6230b025033SStefan Hajnoczi 			control_host = optarg;
6240b025033SStefan Hajnoczi 			break;
6250b025033SStefan Hajnoczi 		case 'm':
6260b025033SStefan Hajnoczi 			if (strcmp(optarg, "client") == 0)
6270b025033SStefan Hajnoczi 				mode = TEST_MODE_CLIENT;
6280b025033SStefan Hajnoczi 			else if (strcmp(optarg, "server") == 0)
6290b025033SStefan Hajnoczi 				mode = TEST_MODE_SERVER;
6300b025033SStefan Hajnoczi 			else {
6310b025033SStefan Hajnoczi 				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
6320b025033SStefan Hajnoczi 				return EXIT_FAILURE;
6330b025033SStefan Hajnoczi 			}
6340b025033SStefan Hajnoczi 			break;
6350b025033SStefan Hajnoczi 		case 'p':
6360b025033SStefan Hajnoczi 			peer_cid = parse_cid(optarg);
6370b025033SStefan Hajnoczi 			break;
6380b025033SStefan Hajnoczi 		case 'P':
6390b025033SStefan Hajnoczi 			control_port = optarg;
6400b025033SStefan Hajnoczi 			break;
6410b025033SStefan Hajnoczi 		case '?':
6420b025033SStefan Hajnoczi 		default:
6430b025033SStefan Hajnoczi 			usage();
6440b025033SStefan Hajnoczi 		}
6450b025033SStefan Hajnoczi 	}
6460b025033SStefan Hajnoczi 
6470b025033SStefan Hajnoczi 	if (!control_port)
6480b025033SStefan Hajnoczi 		usage();
6490b025033SStefan Hajnoczi 	if (mode == TEST_MODE_UNSET)
6500b025033SStefan Hajnoczi 		usage();
6510b025033SStefan Hajnoczi 	if (peer_cid == VMADDR_CID_ANY)
6520b025033SStefan Hajnoczi 		usage();
6530b025033SStefan Hajnoczi 
6540b025033SStefan Hajnoczi 	if (!control_host) {
6550b025033SStefan Hajnoczi 		if (mode != TEST_MODE_SERVER)
6560b025033SStefan Hajnoczi 			usage();
6570b025033SStefan Hajnoczi 		control_host = "0.0.0.0";
6580b025033SStefan Hajnoczi 	}
6590b025033SStefan Hajnoczi 
6600b025033SStefan Hajnoczi 	control_init(control_host, control_port, mode == TEST_MODE_SERVER);
6610b025033SStefan Hajnoczi 
6620b025033SStefan Hajnoczi 	for (i = 0; test_cases[i].name; i++) {
6630b025033SStefan Hajnoczi 		void (*run)(unsigned int peer_cid);
6640b025033SStefan Hajnoczi 
6650b025033SStefan Hajnoczi 		printf("%s...", test_cases[i].name);
6660b025033SStefan Hajnoczi 		fflush(stdout);
6670b025033SStefan Hajnoczi 
6680b025033SStefan Hajnoczi 		if (mode == TEST_MODE_CLIENT)
6690b025033SStefan Hajnoczi 			run = test_cases[i].run_client;
6700b025033SStefan Hajnoczi 		else
6710b025033SStefan Hajnoczi 			run = test_cases[i].run_server;
6720b025033SStefan Hajnoczi 
6730b025033SStefan Hajnoczi 		if (run)
6740b025033SStefan Hajnoczi 			run(peer_cid);
6750b025033SStefan Hajnoczi 
6760b025033SStefan Hajnoczi 		printf("ok\n");
6770b025033SStefan Hajnoczi 	}
6780b025033SStefan Hajnoczi 
6790b025033SStefan Hajnoczi 	control_cleanup();
6800b025033SStefan Hajnoczi 	return EXIT_SUCCESS;
6810b025033SStefan Hajnoczi }
682