xref: /openbmc/linux/tools/testing/vsock/vsock_test.c (revision 55b37d9c)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * vsock_test - vsock.ko test suite
4  *
5  * Copyright (C) 2017 Red Hat, Inc.
6  *
7  * Author: Stefan Hajnoczi <stefanha@redhat.com>
8  */
9 
10 #include <getopt.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <string.h>
14 #include <errno.h>
15 #include <unistd.h>
16 #include <linux/kernel.h>
17 #include <sys/types.h>
18 #include <sys/socket.h>
19 #include <time.h>
20 #include <sys/mman.h>
21 #include <poll.h>
22 
23 #include "timeout.h"
24 #include "control.h"
25 #include "util.h"
26 
27 static void test_stream_connection_reset(const struct test_opts *opts)
28 {
29 	union {
30 		struct sockaddr sa;
31 		struct sockaddr_vm svm;
32 	} addr = {
33 		.svm = {
34 			.svm_family = AF_VSOCK,
35 			.svm_port = 1234,
36 			.svm_cid = opts->peer_cid,
37 		},
38 	};
39 	int ret;
40 	int fd;
41 
42 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
43 
44 	timeout_begin(TIMEOUT);
45 	do {
46 		ret = connect(fd, &addr.sa, sizeof(addr.svm));
47 		timeout_check("connect");
48 	} while (ret < 0 && errno == EINTR);
49 	timeout_end();
50 
51 	if (ret != -1) {
52 		fprintf(stderr, "expected connect(2) failure, got %d\n", ret);
53 		exit(EXIT_FAILURE);
54 	}
55 	if (errno != ECONNRESET) {
56 		fprintf(stderr, "unexpected connect(2) errno %d\n", errno);
57 		exit(EXIT_FAILURE);
58 	}
59 
60 	close(fd);
61 }
62 
63 static void test_stream_bind_only_client(const struct test_opts *opts)
64 {
65 	union {
66 		struct sockaddr sa;
67 		struct sockaddr_vm svm;
68 	} addr = {
69 		.svm = {
70 			.svm_family = AF_VSOCK,
71 			.svm_port = 1234,
72 			.svm_cid = opts->peer_cid,
73 		},
74 	};
75 	int ret;
76 	int fd;
77 
78 	/* Wait for the server to be ready */
79 	control_expectln("BIND");
80 
81 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
82 
83 	timeout_begin(TIMEOUT);
84 	do {
85 		ret = connect(fd, &addr.sa, sizeof(addr.svm));
86 		timeout_check("connect");
87 	} while (ret < 0 && errno == EINTR);
88 	timeout_end();
89 
90 	if (ret != -1) {
91 		fprintf(stderr, "expected connect(2) failure, got %d\n", ret);
92 		exit(EXIT_FAILURE);
93 	}
94 	if (errno != ECONNRESET) {
95 		fprintf(stderr, "unexpected connect(2) errno %d\n", errno);
96 		exit(EXIT_FAILURE);
97 	}
98 
99 	/* Notify the server that the client has finished */
100 	control_writeln("DONE");
101 
102 	close(fd);
103 }
104 
/*
 * Server half of "SOCK_STREAM bind only": bind a stream socket to port 1234
 * without listening on it, announce "BIND" to the client, and keep the
 * socket open until the client reports "DONE".
 */
static void test_stream_bind_only_server(const struct test_opts *opts)
{
	union {
		struct sockaddr sa;
		struct sockaddr_vm svm;
	} addr = {
		.svm = {
			.svm_family = AF_VSOCK,
			.svm_port = 1234,
			.svm_cid = VMADDR_CID_ANY,
		},
	};
	int fd;

	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
	if (fd < 0) {
		perror("socket");
		exit(EXIT_FAILURE);
	}

	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
		perror("bind");
		exit(EXIT_FAILURE);
	}

	/* Notify the client that the server is ready */
	control_writeln("BIND");

	/* Wait for the client to finish */
	control_expectln("DONE");

	close(fd);
}
134 
135 static void test_stream_client_close_client(const struct test_opts *opts)
136 {
137 	int fd;
138 
139 	fd = vsock_stream_connect(opts->peer_cid, 1234);
140 	if (fd < 0) {
141 		perror("connect");
142 		exit(EXIT_FAILURE);
143 	}
144 
145 	send_byte(fd, 1, 0);
146 	close(fd);
147 }
148 
149 static void test_stream_client_close_server(const struct test_opts *opts)
150 {
151 	int fd;
152 
153 	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
154 	if (fd < 0) {
155 		perror("accept");
156 		exit(EXIT_FAILURE);
157 	}
158 
159 	/* Wait for the remote to close the connection, before check
160 	 * -EPIPE error on send.
161 	 */
162 	vsock_wait_remote_close(fd);
163 
164 	send_byte(fd, -EPIPE, 0);
165 	recv_byte(fd, 1, 0);
166 	recv_byte(fd, 0, 0);
167 	close(fd);
168 }
169 
170 static void test_stream_server_close_client(const struct test_opts *opts)
171 {
172 	int fd;
173 
174 	fd = vsock_stream_connect(opts->peer_cid, 1234);
175 	if (fd < 0) {
176 		perror("connect");
177 		exit(EXIT_FAILURE);
178 	}
179 
180 	/* Wait for the remote to close the connection, before check
181 	 * -EPIPE error on send.
182 	 */
183 	vsock_wait_remote_close(fd);
184 
185 	send_byte(fd, -EPIPE, 0);
186 	recv_byte(fd, 1, 0);
187 	recv_byte(fd, 0, 0);
188 	close(fd);
189 }
190 
191 static void test_stream_server_close_server(const struct test_opts *opts)
192 {
193 	int fd;
194 
195 	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
196 	if (fd < 0) {
197 		perror("accept");
198 		exit(EXIT_FAILURE);
199 	}
200 
201 	send_byte(fd, 1, 0);
202 	close(fd);
203 }
204 
205 /* With the standard socket sizes, VMCI is able to support about 100
206  * concurrent stream connections.
207  */
208 #define MULTICONN_NFDS 100
209 
210 static void test_stream_multiconn_client(const struct test_opts *opts)
211 {
212 	int fds[MULTICONN_NFDS];
213 	int i;
214 
215 	for (i = 0; i < MULTICONN_NFDS; i++) {
216 		fds[i] = vsock_stream_connect(opts->peer_cid, 1234);
217 		if (fds[i] < 0) {
218 			perror("connect");
219 			exit(EXIT_FAILURE);
220 		}
221 	}
222 
223 	for (i = 0; i < MULTICONN_NFDS; i++) {
224 		if (i % 2)
225 			recv_byte(fds[i], 1, 0);
226 		else
227 			send_byte(fds[i], 1, 0);
228 	}
229 
230 	for (i = 0; i < MULTICONN_NFDS; i++)
231 		close(fds[i]);
232 }
233 
234 static void test_stream_multiconn_server(const struct test_opts *opts)
235 {
236 	int fds[MULTICONN_NFDS];
237 	int i;
238 
239 	for (i = 0; i < MULTICONN_NFDS; i++) {
240 		fds[i] = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
241 		if (fds[i] < 0) {
242 			perror("accept");
243 			exit(EXIT_FAILURE);
244 		}
245 	}
246 
247 	for (i = 0; i < MULTICONN_NFDS; i++) {
248 		if (i % 2)
249 			send_byte(fds[i], 1, 0);
250 		else
251 			recv_byte(fds[i], 1, 0);
252 	}
253 
254 	for (i = 0; i < MULTICONN_NFDS; i++)
255 		close(fds[i]);
256 }
257 
258 static void test_stream_msg_peek_client(const struct test_opts *opts)
259 {
260 	int fd;
261 
262 	fd = vsock_stream_connect(opts->peer_cid, 1234);
263 	if (fd < 0) {
264 		perror("connect");
265 		exit(EXIT_FAILURE);
266 	}
267 
268 	send_byte(fd, 1, 0);
269 	close(fd);
270 }
271 
272 static void test_stream_msg_peek_server(const struct test_opts *opts)
273 {
274 	int fd;
275 
276 	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
277 	if (fd < 0) {
278 		perror("accept");
279 		exit(EXIT_FAILURE);
280 	}
281 
282 	recv_byte(fd, 1, MSG_PEEK);
283 	recv_byte(fd, 1, 0);
284 	close(fd);
285 }
286 
287 #define SOCK_BUF_SIZE (2 * 1024 * 1024)
288 #define MAX_MSG_SIZE (32 * 1024)
289 
290 static void test_seqpacket_msg_bounds_client(const struct test_opts *opts)
291 {
292 	unsigned long curr_hash;
293 	int page_size;
294 	int msg_count;
295 	int fd;
296 
297 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
298 	if (fd < 0) {
299 		perror("connect");
300 		exit(EXIT_FAILURE);
301 	}
302 
303 	/* Wait, until receiver sets buffer size. */
304 	control_expectln("SRVREADY");
305 
306 	curr_hash = 0;
307 	page_size = getpagesize();
308 	msg_count = SOCK_BUF_SIZE / MAX_MSG_SIZE;
309 
310 	for (int i = 0; i < msg_count; i++) {
311 		ssize_t send_size;
312 		size_t buf_size;
313 		int flags;
314 		void *buf;
315 
316 		/* Use "small" buffers and "big" buffers. */
317 		if (i & 1)
318 			buf_size = page_size +
319 					(rand() % (MAX_MSG_SIZE - page_size));
320 		else
321 			buf_size = 1 + (rand() % page_size);
322 
323 		buf = malloc(buf_size);
324 
325 		if (!buf) {
326 			perror("malloc");
327 			exit(EXIT_FAILURE);
328 		}
329 
330 		memset(buf, rand() & 0xff, buf_size);
331 		/* Set at least one MSG_EOR + some random. */
332 		if (i == (msg_count / 2) || (rand() & 1)) {
333 			flags = MSG_EOR;
334 			curr_hash++;
335 		} else {
336 			flags = 0;
337 		}
338 
339 		send_size = send(fd, buf, buf_size, flags);
340 
341 		if (send_size < 0) {
342 			perror("send");
343 			exit(EXIT_FAILURE);
344 		}
345 
346 		if (send_size != buf_size) {
347 			fprintf(stderr, "Invalid send size\n");
348 			exit(EXIT_FAILURE);
349 		}
350 
351 		/*
352 		 * Hash sum is computed at both client and server in
353 		 * the same way:
354 		 * H += hash('message data')
355 		 * Such hash "controls" both data integrity and message
356 		 * bounds. After data exchange, both sums are compared
357 		 * using control socket, and if message bounds wasn't
358 		 * broken - two values must be equal.
359 		 */
360 		curr_hash += hash_djb2(buf, buf_size);
361 		free(buf);
362 	}
363 
364 	control_writeln("SENDDONE");
365 	control_writeulong(curr_hash);
366 	close(fd);
367 }
368 
369 static void test_seqpacket_msg_bounds_server(const struct test_opts *opts)
370 {
371 	unsigned long sock_buf_size;
372 	unsigned long remote_hash;
373 	unsigned long curr_hash;
374 	int fd;
375 	char buf[MAX_MSG_SIZE];
376 	struct msghdr msg = {0};
377 	struct iovec iov = {0};
378 
379 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
380 	if (fd < 0) {
381 		perror("accept");
382 		exit(EXIT_FAILURE);
383 	}
384 
385 	sock_buf_size = SOCK_BUF_SIZE;
386 
387 	if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_MAX_SIZE,
388 		       &sock_buf_size, sizeof(sock_buf_size))) {
389 		perror("setsockopt(SO_VM_SOCKETS_BUFFER_MAX_SIZE)");
390 		exit(EXIT_FAILURE);
391 	}
392 
393 	if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
394 		       &sock_buf_size, sizeof(sock_buf_size))) {
395 		perror("setsockopt(SO_VM_SOCKETS_BUFFER_SIZE)");
396 		exit(EXIT_FAILURE);
397 	}
398 
399 	/* Ready to receive data. */
400 	control_writeln("SRVREADY");
401 	/* Wait, until peer sends whole data. */
402 	control_expectln("SENDDONE");
403 	iov.iov_base = buf;
404 	iov.iov_len = sizeof(buf);
405 	msg.msg_iov = &iov;
406 	msg.msg_iovlen = 1;
407 
408 	curr_hash = 0;
409 
410 	while (1) {
411 		ssize_t recv_size;
412 
413 		recv_size = recvmsg(fd, &msg, 0);
414 
415 		if (!recv_size)
416 			break;
417 
418 		if (recv_size < 0) {
419 			perror("recvmsg");
420 			exit(EXIT_FAILURE);
421 		}
422 
423 		if (msg.msg_flags & MSG_EOR)
424 			curr_hash++;
425 
426 		curr_hash += hash_djb2(msg.msg_iov[0].iov_base, recv_size);
427 	}
428 
429 	close(fd);
430 	remote_hash = control_readulong();
431 
432 	if (curr_hash != remote_hash) {
433 		fprintf(stderr, "Message bounds broken\n");
434 		exit(EXIT_FAILURE);
435 	}
436 }
437 
438 #define MESSAGE_TRUNC_SZ 32
439 static void test_seqpacket_msg_trunc_client(const struct test_opts *opts)
440 {
441 	int fd;
442 	char buf[MESSAGE_TRUNC_SZ];
443 
444 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
445 	if (fd < 0) {
446 		perror("connect");
447 		exit(EXIT_FAILURE);
448 	}
449 
450 	if (send(fd, buf, sizeof(buf), 0) != sizeof(buf)) {
451 		perror("send failed");
452 		exit(EXIT_FAILURE);
453 	}
454 
455 	control_writeln("SENDDONE");
456 	close(fd);
457 }
458 
459 static void test_seqpacket_msg_trunc_server(const struct test_opts *opts)
460 {
461 	int fd;
462 	char buf[MESSAGE_TRUNC_SZ / 2];
463 	struct msghdr msg = {0};
464 	struct iovec iov = {0};
465 
466 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
467 	if (fd < 0) {
468 		perror("accept");
469 		exit(EXIT_FAILURE);
470 	}
471 
472 	control_expectln("SENDDONE");
473 	iov.iov_base = buf;
474 	iov.iov_len = sizeof(buf);
475 	msg.msg_iov = &iov;
476 	msg.msg_iovlen = 1;
477 
478 	ssize_t ret = recvmsg(fd, &msg, MSG_TRUNC);
479 
480 	if (ret != MESSAGE_TRUNC_SZ) {
481 		printf("%zi\n", ret);
482 		perror("MSG_TRUNC doesn't work");
483 		exit(EXIT_FAILURE);
484 	}
485 
486 	if (!(msg.msg_flags & MSG_TRUNC)) {
487 		fprintf(stderr, "MSG_TRUNC expected\n");
488 		exit(EXIT_FAILURE);
489 	}
490 
491 	close(fd);
492 }
493 
/*
 * Return a CLOCK_MONOTONIC timestamp in nanoseconds, for measuring elapsed
 * time. The realtime clock is not used because it can jump (settimeofday,
 * NTP) in the middle of a measurement. Exits on clock_gettime() failure.
 *
 * NOTE(review): assumes a 64-bit time_t. With a 32-bit time_t the
 * nanosecond count overflows after about 2 seconds.
 */
static time_t current_nsec(void)
{
	struct timespec ts;

	if (clock_gettime(CLOCK_MONOTONIC, &ts)) {
		perror("clock_gettime(3) failed");
		exit(EXIT_FAILURE);
	}

	return (time_t)((long long)ts.tv_sec * 1000000000LL + ts.tv_nsec);
}
505 
506 #define RCVTIMEO_TIMEOUT_SEC 1
507 #define READ_OVERHEAD_NSEC 250000000 /* 0.25 sec */
508 
509 static void test_seqpacket_timeout_client(const struct test_opts *opts)
510 {
511 	int fd;
512 	struct timeval tv;
513 	char dummy;
514 	time_t read_enter_ns;
515 	time_t read_overhead_ns;
516 
517 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
518 	if (fd < 0) {
519 		perror("connect");
520 		exit(EXIT_FAILURE);
521 	}
522 
523 	tv.tv_sec = RCVTIMEO_TIMEOUT_SEC;
524 	tv.tv_usec = 0;
525 
526 	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, (void *)&tv, sizeof(tv)) == -1) {
527 		perror("setsockopt(SO_RCVTIMEO)");
528 		exit(EXIT_FAILURE);
529 	}
530 
531 	read_enter_ns = current_nsec();
532 
533 	if (read(fd, &dummy, sizeof(dummy)) != -1) {
534 		fprintf(stderr,
535 			"expected 'dummy' read(2) failure\n");
536 		exit(EXIT_FAILURE);
537 	}
538 
539 	if (errno != EAGAIN) {
540 		perror("EAGAIN expected");
541 		exit(EXIT_FAILURE);
542 	}
543 
544 	read_overhead_ns = current_nsec() - read_enter_ns -
545 			1000000000ULL * RCVTIMEO_TIMEOUT_SEC;
546 
547 	if (read_overhead_ns > READ_OVERHEAD_NSEC) {
548 		fprintf(stderr,
549 			"too much time in read(2), %lu > %i ns\n",
550 			read_overhead_ns, READ_OVERHEAD_NSEC);
551 		exit(EXIT_FAILURE);
552 	}
553 
554 	control_writeln("WAITDONE");
555 	close(fd);
556 }
557 
558 static void test_seqpacket_timeout_server(const struct test_opts *opts)
559 {
560 	int fd;
561 
562 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
563 	if (fd < 0) {
564 		perror("accept");
565 		exit(EXIT_FAILURE);
566 	}
567 
568 	control_expectln("WAITDONE");
569 	close(fd);
570 }
571 
572 static void test_seqpacket_bigmsg_client(const struct test_opts *opts)
573 {
574 	unsigned long sock_buf_size;
575 	ssize_t send_size;
576 	socklen_t len;
577 	void *data;
578 	int fd;
579 
580 	len = sizeof(sock_buf_size);
581 
582 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
583 	if (fd < 0) {
584 		perror("connect");
585 		exit(EXIT_FAILURE);
586 	}
587 
588 	if (getsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
589 		       &sock_buf_size, &len)) {
590 		perror("getsockopt");
591 		exit(EXIT_FAILURE);
592 	}
593 
594 	sock_buf_size++;
595 
596 	data = malloc(sock_buf_size);
597 	if (!data) {
598 		perror("malloc");
599 		exit(EXIT_FAILURE);
600 	}
601 
602 	send_size = send(fd, data, sock_buf_size, 0);
603 	if (send_size != -1) {
604 		fprintf(stderr, "expected 'send(2)' failure, got %zi\n",
605 			send_size);
606 		exit(EXIT_FAILURE);
607 	}
608 
609 	if (errno != EMSGSIZE) {
610 		fprintf(stderr, "expected EMSGSIZE in 'errno', got %i\n",
611 			errno);
612 		exit(EXIT_FAILURE);
613 	}
614 
615 	control_writeln("CLISENT");
616 
617 	free(data);
618 	close(fd);
619 }
620 
621 static void test_seqpacket_bigmsg_server(const struct test_opts *opts)
622 {
623 	int fd;
624 
625 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
626 	if (fd < 0) {
627 		perror("accept");
628 		exit(EXIT_FAILURE);
629 	}
630 
631 	control_expectln("CLISENT");
632 
633 	close(fd);
634 }
635 
636 #define BUF_PATTERN_1 'a'
637 #define BUF_PATTERN_2 'b'
638 
639 static void test_seqpacket_invalid_rec_buffer_client(const struct test_opts *opts)
640 {
641 	int fd;
642 	unsigned char *buf1;
643 	unsigned char *buf2;
644 	int buf_size = getpagesize() * 3;
645 
646 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
647 	if (fd < 0) {
648 		perror("connect");
649 		exit(EXIT_FAILURE);
650 	}
651 
652 	buf1 = malloc(buf_size);
653 	if (!buf1) {
654 		perror("'malloc()' for 'buf1'");
655 		exit(EXIT_FAILURE);
656 	}
657 
658 	buf2 = malloc(buf_size);
659 	if (!buf2) {
660 		perror("'malloc()' for 'buf2'");
661 		exit(EXIT_FAILURE);
662 	}
663 
664 	memset(buf1, BUF_PATTERN_1, buf_size);
665 	memset(buf2, BUF_PATTERN_2, buf_size);
666 
667 	if (send(fd, buf1, buf_size, 0) != buf_size) {
668 		perror("send failed");
669 		exit(EXIT_FAILURE);
670 	}
671 
672 	if (send(fd, buf2, buf_size, 0) != buf_size) {
673 		perror("send failed");
674 		exit(EXIT_FAILURE);
675 	}
676 
677 	close(fd);
678 }
679 
680 static void test_seqpacket_invalid_rec_buffer_server(const struct test_opts *opts)
681 {
682 	int fd;
683 	unsigned char *broken_buf;
684 	unsigned char *valid_buf;
685 	int page_size = getpagesize();
686 	int buf_size = page_size * 3;
687 	ssize_t res;
688 	int prot = PROT_READ | PROT_WRITE;
689 	int flags = MAP_PRIVATE | MAP_ANONYMOUS;
690 	int i;
691 
692 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
693 	if (fd < 0) {
694 		perror("accept");
695 		exit(EXIT_FAILURE);
696 	}
697 
698 	/* Setup first buffer. */
699 	broken_buf = mmap(NULL, buf_size, prot, flags, -1, 0);
700 	if (broken_buf == MAP_FAILED) {
701 		perror("mmap for 'broken_buf'");
702 		exit(EXIT_FAILURE);
703 	}
704 
705 	/* Unmap "hole" in buffer. */
706 	if (munmap(broken_buf + page_size, page_size)) {
707 		perror("'broken_buf' setup");
708 		exit(EXIT_FAILURE);
709 	}
710 
711 	valid_buf = mmap(NULL, buf_size, prot, flags, -1, 0);
712 	if (valid_buf == MAP_FAILED) {
713 		perror("mmap for 'valid_buf'");
714 		exit(EXIT_FAILURE);
715 	}
716 
717 	/* Try to fill buffer with unmapped middle. */
718 	res = read(fd, broken_buf, buf_size);
719 	if (res != -1) {
720 		fprintf(stderr,
721 			"expected 'broken_buf' read(2) failure, got %zi\n",
722 			res);
723 		exit(EXIT_FAILURE);
724 	}
725 
726 	if (errno != EFAULT) {
727 		perror("unexpected errno of 'broken_buf'");
728 		exit(EXIT_FAILURE);
729 	}
730 
731 	/* Try to fill valid buffer. */
732 	res = read(fd, valid_buf, buf_size);
733 	if (res < 0) {
734 		perror("unexpected 'valid_buf' read(2) failure");
735 		exit(EXIT_FAILURE);
736 	}
737 
738 	if (res != buf_size) {
739 		fprintf(stderr,
740 			"invalid 'valid_buf' read(2), expected %i, got %zi\n",
741 			buf_size, res);
742 		exit(EXIT_FAILURE);
743 	}
744 
745 	for (i = 0; i < buf_size; i++) {
746 		if (valid_buf[i] != BUF_PATTERN_2) {
747 			fprintf(stderr,
748 				"invalid pattern for 'valid_buf' at %i, expected %hhX, got %hhX\n",
749 				i, BUF_PATTERN_2, valid_buf[i]);
750 			exit(EXIT_FAILURE);
751 		}
752 	}
753 
754 	/* Unmap buffers. */
755 	munmap(broken_buf, page_size);
756 	munmap(broken_buf + page_size * 2, page_size);
757 	munmap(valid_buf, buf_size);
758 	close(fd);
759 }
760 
761 #define RCVLOWAT_BUF_SIZE 128
762 
763 static void test_stream_poll_rcvlowat_server(const struct test_opts *opts)
764 {
765 	int fd;
766 	int i;
767 
768 	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
769 	if (fd < 0) {
770 		perror("accept");
771 		exit(EXIT_FAILURE);
772 	}
773 
774 	/* Send 1 byte. */
775 	send_byte(fd, 1, 0);
776 
777 	control_writeln("SRVSENT");
778 
779 	/* Wait until client is ready to receive rest of data. */
780 	control_expectln("CLNSENT");
781 
782 	for (i = 0; i < RCVLOWAT_BUF_SIZE - 1; i++)
783 		send_byte(fd, 1, 0);
784 
785 	/* Keep socket in active state. */
786 	control_expectln("POLLDONE");
787 
788 	close(fd);
789 }
790 
791 static void test_stream_poll_rcvlowat_client(const struct test_opts *opts)
792 {
793 	unsigned long lowat_val = RCVLOWAT_BUF_SIZE;
794 	char buf[RCVLOWAT_BUF_SIZE];
795 	struct pollfd fds;
796 	ssize_t read_res;
797 	short poll_flags;
798 	int fd;
799 
800 	fd = vsock_stream_connect(opts->peer_cid, 1234);
801 	if (fd < 0) {
802 		perror("connect");
803 		exit(EXIT_FAILURE);
804 	}
805 
806 	if (setsockopt(fd, SOL_SOCKET, SO_RCVLOWAT,
807 		       &lowat_val, sizeof(lowat_val))) {
808 		perror("setsockopt(SO_RCVLOWAT)");
809 		exit(EXIT_FAILURE);
810 	}
811 
812 	control_expectln("SRVSENT");
813 
814 	/* At this point, server sent 1 byte. */
815 	fds.fd = fd;
816 	poll_flags = POLLIN | POLLRDNORM;
817 	fds.events = poll_flags;
818 
819 	/* Try to wait for 1 sec. */
820 	if (poll(&fds, 1, 1000) < 0) {
821 		perror("poll");
822 		exit(EXIT_FAILURE);
823 	}
824 
825 	/* poll() must return nothing. */
826 	if (fds.revents) {
827 		fprintf(stderr, "Unexpected poll result %hx\n",
828 			fds.revents);
829 		exit(EXIT_FAILURE);
830 	}
831 
832 	/* Tell server to send rest of data. */
833 	control_writeln("CLNSENT");
834 
835 	/* Poll for data. */
836 	if (poll(&fds, 1, 10000) < 0) {
837 		perror("poll");
838 		exit(EXIT_FAILURE);
839 	}
840 
841 	/* Only these two bits are expected. */
842 	if (fds.revents != poll_flags) {
843 		fprintf(stderr, "Unexpected poll result %hx\n",
844 			fds.revents);
845 		exit(EXIT_FAILURE);
846 	}
847 
848 	/* Use MSG_DONTWAIT, if call is going to wait, EAGAIN
849 	 * will be returned.
850 	 */
851 	read_res = recv(fd, buf, sizeof(buf), MSG_DONTWAIT);
852 	if (read_res != RCVLOWAT_BUF_SIZE) {
853 		fprintf(stderr, "Unexpected recv result %zi\n",
854 			read_res);
855 		exit(EXIT_FAILURE);
856 	}
857 
858 	control_writeln("POLLDONE");
859 
860 	close(fd);
861 }
862 
863 #define INV_BUF_TEST_DATA_LEN 512
864 
865 static void test_inv_buf_client(const struct test_opts *opts, bool stream)
866 {
867 	unsigned char data[INV_BUF_TEST_DATA_LEN] = {0};
868 	ssize_t ret;
869 	int fd;
870 
871 	if (stream)
872 		fd = vsock_stream_connect(opts->peer_cid, 1234);
873 	else
874 		fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
875 
876 	if (fd < 0) {
877 		perror("connect");
878 		exit(EXIT_FAILURE);
879 	}
880 
881 	control_expectln("SENDDONE");
882 
883 	/* Use invalid buffer here. */
884 	ret = recv(fd, NULL, sizeof(data), 0);
885 	if (ret != -1) {
886 		fprintf(stderr, "expected recv(2) failure, got %zi\n", ret);
887 		exit(EXIT_FAILURE);
888 	}
889 
890 	if (errno != EFAULT) {
891 		fprintf(stderr, "unexpected recv(2) errno %d\n", errno);
892 		exit(EXIT_FAILURE);
893 	}
894 
895 	ret = recv(fd, data, sizeof(data), MSG_DONTWAIT);
896 
897 	if (stream) {
898 		/* For SOCK_STREAM we must continue reading. */
899 		if (ret != sizeof(data)) {
900 			fprintf(stderr, "expected recv(2) success, got %zi\n", ret);
901 			exit(EXIT_FAILURE);
902 		}
903 		/* Don't check errno in case of success. */
904 	} else {
905 		/* For SOCK_SEQPACKET socket's queue must be empty. */
906 		if (ret != -1) {
907 			fprintf(stderr, "expected recv(2) failure, got %zi\n", ret);
908 			exit(EXIT_FAILURE);
909 		}
910 
911 		if (errno != EAGAIN) {
912 			fprintf(stderr, "unexpected recv(2) errno %d\n", errno);
913 			exit(EXIT_FAILURE);
914 		}
915 	}
916 
917 	control_writeln("DONE");
918 
919 	close(fd);
920 }
921 
922 static void test_inv_buf_server(const struct test_opts *opts, bool stream)
923 {
924 	unsigned char data[INV_BUF_TEST_DATA_LEN] = {0};
925 	ssize_t res;
926 	int fd;
927 
928 	if (stream)
929 		fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
930 	else
931 		fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
932 
933 	if (fd < 0) {
934 		perror("accept");
935 		exit(EXIT_FAILURE);
936 	}
937 
938 	res = send(fd, data, sizeof(data), 0);
939 	if (res != sizeof(data)) {
940 		fprintf(stderr, "unexpected send(2) result %zi\n", res);
941 		exit(EXIT_FAILURE);
942 	}
943 
944 	control_writeln("SENDDONE");
945 
946 	control_expectln("DONE");
947 
948 	close(fd);
949 }
950 
/* "SOCK_STREAM test invalid buffer", client half. */
static void test_stream_inv_buf_client(const struct test_opts *opts)
{
	const bool stream = true;

	test_inv_buf_client(opts, stream);
}
955 
/* "SOCK_STREAM test invalid buffer", server half. */
static void test_stream_inv_buf_server(const struct test_opts *opts)
{
	const bool stream = true;

	test_inv_buf_server(opts, stream);
}
960 
/* "SOCK_SEQPACKET test invalid buffer", client half. */
static void test_seqpacket_inv_buf_client(const struct test_opts *opts)
{
	const bool stream = false;

	test_inv_buf_client(opts, stream);
}
965 
/* "SOCK_SEQPACKET test invalid buffer", server half. */
static void test_seqpacket_inv_buf_server(const struct test_opts *opts)
{
	const bool stream = false;

	test_inv_buf_server(opts, stream);
}
970 
971 #define HELLO_STR "HELLO"
972 #define WORLD_STR "WORLD"
973 
974 static void test_stream_virtio_skb_merge_client(const struct test_opts *opts)
975 {
976 	ssize_t res;
977 	int fd;
978 
979 	fd = vsock_stream_connect(opts->peer_cid, 1234);
980 	if (fd < 0) {
981 		perror("connect");
982 		exit(EXIT_FAILURE);
983 	}
984 
985 	/* Send first skbuff. */
986 	res = send(fd, HELLO_STR, strlen(HELLO_STR), 0);
987 	if (res != strlen(HELLO_STR)) {
988 		fprintf(stderr, "unexpected send(2) result %zi\n", res);
989 		exit(EXIT_FAILURE);
990 	}
991 
992 	control_writeln("SEND0");
993 	/* Peer reads part of first skbuff. */
994 	control_expectln("REPLY0");
995 
996 	/* Send second skbuff, it will be appended to the first. */
997 	res = send(fd, WORLD_STR, strlen(WORLD_STR), 0);
998 	if (res != strlen(WORLD_STR)) {
999 		fprintf(stderr, "unexpected send(2) result %zi\n", res);
1000 		exit(EXIT_FAILURE);
1001 	}
1002 
1003 	control_writeln("SEND1");
1004 	/* Peer reads merged skbuff packet. */
1005 	control_expectln("REPLY1");
1006 
1007 	close(fd);
1008 }
1009 
1010 static void test_stream_virtio_skb_merge_server(const struct test_opts *opts)
1011 {
1012 	unsigned char buf[64];
1013 	ssize_t res;
1014 	int fd;
1015 
1016 	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
1017 	if (fd < 0) {
1018 		perror("accept");
1019 		exit(EXIT_FAILURE);
1020 	}
1021 
1022 	control_expectln("SEND0");
1023 
1024 	/* Read skbuff partially. */
1025 	res = recv(fd, buf, 2, 0);
1026 	if (res != 2) {
1027 		fprintf(stderr, "expected recv(2) returns 2 bytes, got %zi\n", res);
1028 		exit(EXIT_FAILURE);
1029 	}
1030 
1031 	control_writeln("REPLY0");
1032 	control_expectln("SEND1");
1033 
1034 	res = recv(fd, buf + 2, sizeof(buf) - 2, 0);
1035 	if (res != 8) {
1036 		fprintf(stderr, "expected recv(2) returns 8 bytes, got %zi\n", res);
1037 		exit(EXIT_FAILURE);
1038 	}
1039 
1040 	res = recv(fd, buf, sizeof(buf) - 8 - 2, MSG_DONTWAIT);
1041 	if (res != -1) {
1042 		fprintf(stderr, "expected recv(2) failure, got %zi\n", res);
1043 		exit(EXIT_FAILURE);
1044 	}
1045 
1046 	if (memcmp(buf, HELLO_STR WORLD_STR, strlen(HELLO_STR WORLD_STR))) {
1047 		fprintf(stderr, "pattern mismatch\n");
1048 		exit(EXIT_FAILURE);
1049 	}
1050 
1051 	control_writeln("REPLY1");
1052 
1053 	close(fd);
1054 }
1055 
1056 static struct test_case test_cases[] = {
1057 	{
1058 		.name = "SOCK_STREAM connection reset",
1059 		.run_client = test_stream_connection_reset,
1060 	},
1061 	{
1062 		.name = "SOCK_STREAM bind only",
1063 		.run_client = test_stream_bind_only_client,
1064 		.run_server = test_stream_bind_only_server,
1065 	},
1066 	{
1067 		.name = "SOCK_STREAM client close",
1068 		.run_client = test_stream_client_close_client,
1069 		.run_server = test_stream_client_close_server,
1070 	},
1071 	{
1072 		.name = "SOCK_STREAM server close",
1073 		.run_client = test_stream_server_close_client,
1074 		.run_server = test_stream_server_close_server,
1075 	},
1076 	{
1077 		.name = "SOCK_STREAM multiple connections",
1078 		.run_client = test_stream_multiconn_client,
1079 		.run_server = test_stream_multiconn_server,
1080 	},
1081 	{
1082 		.name = "SOCK_STREAM MSG_PEEK",
1083 		.run_client = test_stream_msg_peek_client,
1084 		.run_server = test_stream_msg_peek_server,
1085 	},
1086 	{
1087 		.name = "SOCK_SEQPACKET msg bounds",
1088 		.run_client = test_seqpacket_msg_bounds_client,
1089 		.run_server = test_seqpacket_msg_bounds_server,
1090 	},
1091 	{
1092 		.name = "SOCK_SEQPACKET MSG_TRUNC flag",
1093 		.run_client = test_seqpacket_msg_trunc_client,
1094 		.run_server = test_seqpacket_msg_trunc_server,
1095 	},
1096 	{
1097 		.name = "SOCK_SEQPACKET timeout",
1098 		.run_client = test_seqpacket_timeout_client,
1099 		.run_server = test_seqpacket_timeout_server,
1100 	},
1101 	{
1102 		.name = "SOCK_SEQPACKET invalid receive buffer",
1103 		.run_client = test_seqpacket_invalid_rec_buffer_client,
1104 		.run_server = test_seqpacket_invalid_rec_buffer_server,
1105 	},
1106 	{
1107 		.name = "SOCK_STREAM poll() + SO_RCVLOWAT",
1108 		.run_client = test_stream_poll_rcvlowat_client,
1109 		.run_server = test_stream_poll_rcvlowat_server,
1110 	},
1111 	{
1112 		.name = "SOCK_SEQPACKET big message",
1113 		.run_client = test_seqpacket_bigmsg_client,
1114 		.run_server = test_seqpacket_bigmsg_server,
1115 	},
1116 	{
1117 		.name = "SOCK_STREAM test invalid buffer",
1118 		.run_client = test_stream_inv_buf_client,
1119 		.run_server = test_stream_inv_buf_server,
1120 	},
1121 	{
1122 		.name = "SOCK_SEQPACKET test invalid buffer",
1123 		.run_client = test_seqpacket_inv_buf_client,
1124 		.run_server = test_seqpacket_inv_buf_server,
1125 	},
1126 	{
1127 		.name = "SOCK_STREAM virtio skb merge",
1128 		.run_client = test_stream_virtio_skb_merge_client,
1129 		.run_server = test_stream_virtio_skb_merge_server,
1130 	},
1131 	{},
1132 };
1133 
/* No short options: everything is given as a long option. */
static const char optstring[] = "";

/*
 * Long options understood by main(). With .flag NULL, getopt_long()
 * returns each entry's .val.
 *
 * "--help" maps to '?', the same value getopt_long() returns for an
 * unrecognized option, so both end up in usage().
 */
static const struct option longopts[] = {
	{ "control-host", required_argument, NULL, 'H' },
	{ "control-port", required_argument, NULL, 'P' },
	{ "mode",         required_argument, NULL, 'm' },
	{ "peer-cid",     required_argument, NULL, 'p' },
	{ "list",         no_argument,       NULL, 'l' },
	{ "skip",         required_argument, NULL, 's' },
	{ "help",         no_argument,       NULL, '?' },
	{ NULL,           0,                 NULL, 0 },
};
1173 
/*
 * Print the command-line help to stderr and exit with EXIT_FAILURE.
 * Used for --help as well as for any missing or invalid option.
 */
static void usage(void)
{
	static const char help_text[] =
		"Usage: vsock_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n"
		"\n"
		"  Server: vsock_test --control-port=1234 --mode=server --peer-cid=3\n"
		"  Client: vsock_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
		"\n"
		"Run vsock.ko tests.  Must be launched in both guest\n"
		"and host.  One side must use --mode=client and\n"
		"the other side must use --mode=server.\n"
		"\n"
		"A TCP control socket connection is used to coordinate tests\n"
		"between the client and the server.  The server requires a\n"
		"listen address and the client requires an address to\n"
		"connect to.\n"
		"\n"
		"The CID of the other side must be given with --peer-cid=<cid>.\n"
		"\n"
		"Options:\n"
		"  --help                 This help message\n"
		"  --control-host <host>  Server IP address to connect to\n"
		"  --control-port <port>  Server port to listen on/connect to\n"
		"  --mode client|server   Server or client mode\n"
		"  --peer-cid <cid>       CID of the other side\n"
		"  --list                 List of tests that will be executed\n"
		"  --skip <test_id>       Test ID to skip;\n"
		"                         use multiple --skip options to skip more tests\n";

	/* The text has no conversion specifiers, so fputs() is equivalent. */
	fputs(help_text, stderr);
	exit(EXIT_FAILURE);
}
1204 
1205 int main(int argc, char **argv)
1206 {
1207 	const char *control_host = NULL;
1208 	const char *control_port = NULL;
1209 	struct test_opts opts = {
1210 		.mode = TEST_MODE_UNSET,
1211 		.peer_cid = VMADDR_CID_ANY,
1212 	};
1213 
1214 	srand(time(NULL));
1215 	init_signals();
1216 
1217 	for (;;) {
1218 		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
1219 
1220 		if (opt == -1)
1221 			break;
1222 
1223 		switch (opt) {
1224 		case 'H':
1225 			control_host = optarg;
1226 			break;
1227 		case 'm':
1228 			if (strcmp(optarg, "client") == 0)
1229 				opts.mode = TEST_MODE_CLIENT;
1230 			else if (strcmp(optarg, "server") == 0)
1231 				opts.mode = TEST_MODE_SERVER;
1232 			else {
1233 				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
1234 				return EXIT_FAILURE;
1235 			}
1236 			break;
1237 		case 'p':
1238 			opts.peer_cid = parse_cid(optarg);
1239 			break;
1240 		case 'P':
1241 			control_port = optarg;
1242 			break;
1243 		case 'l':
1244 			list_tests(test_cases);
1245 			break;
1246 		case 's':
1247 			skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
1248 				  optarg);
1249 			break;
1250 		case '?':
1251 		default:
1252 			usage();
1253 		}
1254 	}
1255 
1256 	if (!control_port)
1257 		usage();
1258 	if (opts.mode == TEST_MODE_UNSET)
1259 		usage();
1260 	if (opts.peer_cid == VMADDR_CID_ANY)
1261 		usage();
1262 
1263 	if (!control_host) {
1264 		if (opts.mode != TEST_MODE_SERVER)
1265 			usage();
1266 		control_host = "0.0.0.0";
1267 	}
1268 
1269 	control_init(control_host, control_port,
1270 		     opts.mode == TEST_MODE_SERVER);
1271 
1272 	run_tests(test_cases, &opts);
1273 
1274 	control_cleanup();
1275 	return EXIT_SUCCESS;
1276 }
1277