xref: /openbmc/linux/tools/testing/vsock/vsock_test.c (revision 22b6e7f3)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * vsock_test - vsock.ko test suite
4  *
5  * Copyright (C) 2017 Red Hat, Inc.
6  *
7  * Author: Stefan Hajnoczi <stefanha@redhat.com>
8  */
9 
10 #include <getopt.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <string.h>
14 #include <errno.h>
15 #include <unistd.h>
16 #include <linux/kernel.h>
17 #include <sys/types.h>
18 #include <sys/socket.h>
19 #include <time.h>
20 #include <sys/mman.h>
21 #include <poll.h>
22 
23 #include "timeout.h"
24 #include "control.h"
25 #include "util.h"
26 
27 static void test_stream_connection_reset(const struct test_opts *opts)
28 {
29 	union {
30 		struct sockaddr sa;
31 		struct sockaddr_vm svm;
32 	} addr = {
33 		.svm = {
34 			.svm_family = AF_VSOCK,
35 			.svm_port = 1234,
36 			.svm_cid = opts->peer_cid,
37 		},
38 	};
39 	int ret;
40 	int fd;
41 
42 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
43 
44 	timeout_begin(TIMEOUT);
45 	do {
46 		ret = connect(fd, &addr.sa, sizeof(addr.svm));
47 		timeout_check("connect");
48 	} while (ret < 0 && errno == EINTR);
49 	timeout_end();
50 
51 	if (ret != -1) {
52 		fprintf(stderr, "expected connect(2) failure, got %d\n", ret);
53 		exit(EXIT_FAILURE);
54 	}
55 	if (errno != ECONNRESET) {
56 		fprintf(stderr, "unexpected connect(2) errno %d\n", errno);
57 		exit(EXIT_FAILURE);
58 	}
59 
60 	close(fd);
61 }
62 
63 static void test_stream_bind_only_client(const struct test_opts *opts)
64 {
65 	union {
66 		struct sockaddr sa;
67 		struct sockaddr_vm svm;
68 	} addr = {
69 		.svm = {
70 			.svm_family = AF_VSOCK,
71 			.svm_port = 1234,
72 			.svm_cid = opts->peer_cid,
73 		},
74 	};
75 	int ret;
76 	int fd;
77 
78 	/* Wait for the server to be ready */
79 	control_expectln("BIND");
80 
81 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
82 
83 	timeout_begin(TIMEOUT);
84 	do {
85 		ret = connect(fd, &addr.sa, sizeof(addr.svm));
86 		timeout_check("connect");
87 	} while (ret < 0 && errno == EINTR);
88 	timeout_end();
89 
90 	if (ret != -1) {
91 		fprintf(stderr, "expected connect(2) failure, got %d\n", ret);
92 		exit(EXIT_FAILURE);
93 	}
94 	if (errno != ECONNRESET) {
95 		fprintf(stderr, "unexpected connect(2) errno %d\n", errno);
96 		exit(EXIT_FAILURE);
97 	}
98 
99 	/* Notify the server that the client has finished */
100 	control_writeln("DONE");
101 
102 	close(fd);
103 }
104 
105 static void test_stream_bind_only_server(const struct test_opts *opts)
106 {
107 	union {
108 		struct sockaddr sa;
109 		struct sockaddr_vm svm;
110 	} addr = {
111 		.svm = {
112 			.svm_family = AF_VSOCK,
113 			.svm_port = 1234,
114 			.svm_cid = VMADDR_CID_ANY,
115 		},
116 	};
117 	int fd;
118 
119 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
120 
121 	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
122 		perror("bind");
123 		exit(EXIT_FAILURE);
124 	}
125 
126 	/* Notify the client that the server is ready */
127 	control_writeln("BIND");
128 
129 	/* Wait for the client to finish */
130 	control_expectln("DONE");
131 
132 	close(fd);
133 }
134 
135 static void test_stream_client_close_client(const struct test_opts *opts)
136 {
137 	int fd;
138 
139 	fd = vsock_stream_connect(opts->peer_cid, 1234);
140 	if (fd < 0) {
141 		perror("connect");
142 		exit(EXIT_FAILURE);
143 	}
144 
145 	send_byte(fd, 1, 0);
146 	close(fd);
147 }
148 
149 static void test_stream_client_close_server(const struct test_opts *opts)
150 {
151 	int fd;
152 
153 	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
154 	if (fd < 0) {
155 		perror("accept");
156 		exit(EXIT_FAILURE);
157 	}
158 
159 	/* Wait for the remote to close the connection, before check
160 	 * -EPIPE error on send.
161 	 */
162 	vsock_wait_remote_close(fd);
163 
164 	send_byte(fd, -EPIPE, 0);
165 	recv_byte(fd, 1, 0);
166 	recv_byte(fd, 0, 0);
167 	close(fd);
168 }
169 
170 static void test_stream_server_close_client(const struct test_opts *opts)
171 {
172 	int fd;
173 
174 	fd = vsock_stream_connect(opts->peer_cid, 1234);
175 	if (fd < 0) {
176 		perror("connect");
177 		exit(EXIT_FAILURE);
178 	}
179 
180 	/* Wait for the remote to close the connection, before check
181 	 * -EPIPE error on send.
182 	 */
183 	vsock_wait_remote_close(fd);
184 
185 	send_byte(fd, -EPIPE, 0);
186 	recv_byte(fd, 1, 0);
187 	recv_byte(fd, 0, 0);
188 	close(fd);
189 }
190 
191 static void test_stream_server_close_server(const struct test_opts *opts)
192 {
193 	int fd;
194 
195 	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
196 	if (fd < 0) {
197 		perror("accept");
198 		exit(EXIT_FAILURE);
199 	}
200 
201 	send_byte(fd, 1, 0);
202 	close(fd);
203 }
204 
205 /* With the standard socket sizes, VMCI is able to support about 100
206  * concurrent stream connections.
207  */
208 #define MULTICONN_NFDS 100
209 
210 static void test_stream_multiconn_client(const struct test_opts *opts)
211 {
212 	int fds[MULTICONN_NFDS];
213 	int i;
214 
215 	for (i = 0; i < MULTICONN_NFDS; i++) {
216 		fds[i] = vsock_stream_connect(opts->peer_cid, 1234);
217 		if (fds[i] < 0) {
218 			perror("connect");
219 			exit(EXIT_FAILURE);
220 		}
221 	}
222 
223 	for (i = 0; i < MULTICONN_NFDS; i++) {
224 		if (i % 2)
225 			recv_byte(fds[i], 1, 0);
226 		else
227 			send_byte(fds[i], 1, 0);
228 	}
229 
230 	for (i = 0; i < MULTICONN_NFDS; i++)
231 		close(fds[i]);
232 }
233 
234 static void test_stream_multiconn_server(const struct test_opts *opts)
235 {
236 	int fds[MULTICONN_NFDS];
237 	int i;
238 
239 	for (i = 0; i < MULTICONN_NFDS; i++) {
240 		fds[i] = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
241 		if (fds[i] < 0) {
242 			perror("accept");
243 			exit(EXIT_FAILURE);
244 		}
245 	}
246 
247 	for (i = 0; i < MULTICONN_NFDS; i++) {
248 		if (i % 2)
249 			send_byte(fds[i], 1, 0);
250 		else
251 			recv_byte(fds[i], 1, 0);
252 	}
253 
254 	for (i = 0; i < MULTICONN_NFDS; i++)
255 		close(fds[i]);
256 }
257 
258 #define MSG_PEEK_BUF_LEN 64
259 
260 static void test_msg_peek_client(const struct test_opts *opts,
261 				 bool seqpacket)
262 {
263 	unsigned char buf[MSG_PEEK_BUF_LEN];
264 	ssize_t send_size;
265 	int fd;
266 	int i;
267 
268 	if (seqpacket)
269 		fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
270 	else
271 		fd = vsock_stream_connect(opts->peer_cid, 1234);
272 
273 	if (fd < 0) {
274 		perror("connect");
275 		exit(EXIT_FAILURE);
276 	}
277 
278 	for (i = 0; i < sizeof(buf); i++)
279 		buf[i] = rand() & 0xFF;
280 
281 	control_expectln("SRVREADY");
282 
283 	send_size = send(fd, buf, sizeof(buf), 0);
284 
285 	if (send_size < 0) {
286 		perror("send");
287 		exit(EXIT_FAILURE);
288 	}
289 
290 	if (send_size != sizeof(buf)) {
291 		fprintf(stderr, "Invalid send size %zi\n", send_size);
292 		exit(EXIT_FAILURE);
293 	}
294 
295 	close(fd);
296 }
297 
298 static void test_msg_peek_server(const struct test_opts *opts,
299 				 bool seqpacket)
300 {
301 	unsigned char buf_half[MSG_PEEK_BUF_LEN / 2];
302 	unsigned char buf_normal[MSG_PEEK_BUF_LEN];
303 	unsigned char buf_peek[MSG_PEEK_BUF_LEN];
304 	ssize_t res;
305 	int fd;
306 
307 	if (seqpacket)
308 		fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
309 	else
310 		fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
311 
312 	if (fd < 0) {
313 		perror("accept");
314 		exit(EXIT_FAILURE);
315 	}
316 
317 	/* Peek from empty socket. */
318 	res = recv(fd, buf_peek, sizeof(buf_peek), MSG_PEEK | MSG_DONTWAIT);
319 	if (res != -1) {
320 		fprintf(stderr, "expected recv(2) failure, got %zi\n", res);
321 		exit(EXIT_FAILURE);
322 	}
323 
324 	if (errno != EAGAIN) {
325 		perror("EAGAIN expected");
326 		exit(EXIT_FAILURE);
327 	}
328 
329 	control_writeln("SRVREADY");
330 
331 	/* Peek part of data. */
332 	res = recv(fd, buf_half, sizeof(buf_half), MSG_PEEK);
333 	if (res != sizeof(buf_half)) {
334 		fprintf(stderr, "recv(2) + MSG_PEEK, expected %zu, got %zi\n",
335 			sizeof(buf_half), res);
336 		exit(EXIT_FAILURE);
337 	}
338 
339 	/* Peek whole data. */
340 	res = recv(fd, buf_peek, sizeof(buf_peek), MSG_PEEK);
341 	if (res != sizeof(buf_peek)) {
342 		fprintf(stderr, "recv(2) + MSG_PEEK, expected %zu, got %zi\n",
343 			sizeof(buf_peek), res);
344 		exit(EXIT_FAILURE);
345 	}
346 
347 	/* Compare partial and full peek. */
348 	if (memcmp(buf_half, buf_peek, sizeof(buf_half))) {
349 		fprintf(stderr, "Partial peek data mismatch\n");
350 		exit(EXIT_FAILURE);
351 	}
352 
353 	if (seqpacket) {
354 		/* This type of socket supports MSG_TRUNC flag,
355 		 * so check it with MSG_PEEK. We must get length
356 		 * of the message.
357 		 */
358 		res = recv(fd, buf_half, sizeof(buf_half), MSG_PEEK |
359 			   MSG_TRUNC);
360 		if (res != sizeof(buf_peek)) {
361 			fprintf(stderr,
362 				"recv(2) + MSG_PEEK | MSG_TRUNC, exp %zu, got %zi\n",
363 				sizeof(buf_half), res);
364 			exit(EXIT_FAILURE);
365 		}
366 	}
367 
368 	res = recv(fd, buf_normal, sizeof(buf_normal), 0);
369 	if (res != sizeof(buf_normal)) {
370 		fprintf(stderr, "recv(2), expected %zu, got %zi\n",
371 			sizeof(buf_normal), res);
372 		exit(EXIT_FAILURE);
373 	}
374 
375 	/* Compare full peek and normal read. */
376 	if (memcmp(buf_peek, buf_normal, sizeof(buf_peek))) {
377 		fprintf(stderr, "Full peek data mismatch\n");
378 		exit(EXIT_FAILURE);
379 	}
380 
381 	close(fd);
382 }
383 
/* SOCK_STREAM flavour of the MSG_PEEK test, client side. */
static void test_stream_msg_peek_client(const struct test_opts *opts)
{
	/* Plain call: returning a void expression from a void function is
	 * not valid ISO C.
	 */
	test_msg_peek_client(opts, false);
}

/* SOCK_STREAM flavour of the MSG_PEEK test, server side. */
static void test_stream_msg_peek_server(const struct test_opts *opts)
{
	test_msg_peek_server(opts, false);
}
393 
394 #define SOCK_BUF_SIZE (2 * 1024 * 1024)
395 #define MAX_MSG_SIZE (32 * 1024)
396 
397 static void test_seqpacket_msg_bounds_client(const struct test_opts *opts)
398 {
399 	unsigned long curr_hash;
400 	int page_size;
401 	int msg_count;
402 	int fd;
403 
404 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
405 	if (fd < 0) {
406 		perror("connect");
407 		exit(EXIT_FAILURE);
408 	}
409 
410 	/* Wait, until receiver sets buffer size. */
411 	control_expectln("SRVREADY");
412 
413 	curr_hash = 0;
414 	page_size = getpagesize();
415 	msg_count = SOCK_BUF_SIZE / MAX_MSG_SIZE;
416 
417 	for (int i = 0; i < msg_count; i++) {
418 		ssize_t send_size;
419 		size_t buf_size;
420 		int flags;
421 		void *buf;
422 
423 		/* Use "small" buffers and "big" buffers. */
424 		if (i & 1)
425 			buf_size = page_size +
426 					(rand() % (MAX_MSG_SIZE - page_size));
427 		else
428 			buf_size = 1 + (rand() % page_size);
429 
430 		buf = malloc(buf_size);
431 
432 		if (!buf) {
433 			perror("malloc");
434 			exit(EXIT_FAILURE);
435 		}
436 
437 		memset(buf, rand() & 0xff, buf_size);
438 		/* Set at least one MSG_EOR + some random. */
439 		if (i == (msg_count / 2) || (rand() & 1)) {
440 			flags = MSG_EOR;
441 			curr_hash++;
442 		} else {
443 			flags = 0;
444 		}
445 
446 		send_size = send(fd, buf, buf_size, flags);
447 
448 		if (send_size < 0) {
449 			perror("send");
450 			exit(EXIT_FAILURE);
451 		}
452 
453 		if (send_size != buf_size) {
454 			fprintf(stderr, "Invalid send size\n");
455 			exit(EXIT_FAILURE);
456 		}
457 
458 		/*
459 		 * Hash sum is computed at both client and server in
460 		 * the same way:
461 		 * H += hash('message data')
462 		 * Such hash "controls" both data integrity and message
463 		 * bounds. After data exchange, both sums are compared
464 		 * using control socket, and if message bounds wasn't
465 		 * broken - two values must be equal.
466 		 */
467 		curr_hash += hash_djb2(buf, buf_size);
468 		free(buf);
469 	}
470 
471 	control_writeln("SENDDONE");
472 	control_writeulong(curr_hash);
473 	close(fd);
474 }
475 
476 static void test_seqpacket_msg_bounds_server(const struct test_opts *opts)
477 {
478 	unsigned long sock_buf_size;
479 	unsigned long remote_hash;
480 	unsigned long curr_hash;
481 	int fd;
482 	char buf[MAX_MSG_SIZE];
483 	struct msghdr msg = {0};
484 	struct iovec iov = {0};
485 
486 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
487 	if (fd < 0) {
488 		perror("accept");
489 		exit(EXIT_FAILURE);
490 	}
491 
492 	sock_buf_size = SOCK_BUF_SIZE;
493 
494 	if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_MAX_SIZE,
495 		       &sock_buf_size, sizeof(sock_buf_size))) {
496 		perror("setsockopt(SO_VM_SOCKETS_BUFFER_MAX_SIZE)");
497 		exit(EXIT_FAILURE);
498 	}
499 
500 	if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
501 		       &sock_buf_size, sizeof(sock_buf_size))) {
502 		perror("setsockopt(SO_VM_SOCKETS_BUFFER_SIZE)");
503 		exit(EXIT_FAILURE);
504 	}
505 
506 	/* Ready to receive data. */
507 	control_writeln("SRVREADY");
508 	/* Wait, until peer sends whole data. */
509 	control_expectln("SENDDONE");
510 	iov.iov_base = buf;
511 	iov.iov_len = sizeof(buf);
512 	msg.msg_iov = &iov;
513 	msg.msg_iovlen = 1;
514 
515 	curr_hash = 0;
516 
517 	while (1) {
518 		ssize_t recv_size;
519 
520 		recv_size = recvmsg(fd, &msg, 0);
521 
522 		if (!recv_size)
523 			break;
524 
525 		if (recv_size < 0) {
526 			perror("recvmsg");
527 			exit(EXIT_FAILURE);
528 		}
529 
530 		if (msg.msg_flags & MSG_EOR)
531 			curr_hash++;
532 
533 		curr_hash += hash_djb2(msg.msg_iov[0].iov_base, recv_size);
534 	}
535 
536 	close(fd);
537 	remote_hash = control_readulong();
538 
539 	if (curr_hash != remote_hash) {
540 		fprintf(stderr, "Message bounds broken\n");
541 		exit(EXIT_FAILURE);
542 	}
543 }
544 
545 #define MESSAGE_TRUNC_SZ 32
546 static void test_seqpacket_msg_trunc_client(const struct test_opts *opts)
547 {
548 	int fd;
549 	char buf[MESSAGE_TRUNC_SZ];
550 
551 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
552 	if (fd < 0) {
553 		perror("connect");
554 		exit(EXIT_FAILURE);
555 	}
556 
557 	if (send(fd, buf, sizeof(buf), 0) != sizeof(buf)) {
558 		perror("send failed");
559 		exit(EXIT_FAILURE);
560 	}
561 
562 	control_writeln("SENDDONE");
563 	close(fd);
564 }
565 
566 static void test_seqpacket_msg_trunc_server(const struct test_opts *opts)
567 {
568 	int fd;
569 	char buf[MESSAGE_TRUNC_SZ / 2];
570 	struct msghdr msg = {0};
571 	struct iovec iov = {0};
572 
573 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
574 	if (fd < 0) {
575 		perror("accept");
576 		exit(EXIT_FAILURE);
577 	}
578 
579 	control_expectln("SENDDONE");
580 	iov.iov_base = buf;
581 	iov.iov_len = sizeof(buf);
582 	msg.msg_iov = &iov;
583 	msg.msg_iovlen = 1;
584 
585 	ssize_t ret = recvmsg(fd, &msg, MSG_TRUNC);
586 
587 	if (ret != MESSAGE_TRUNC_SZ) {
588 		printf("%zi\n", ret);
589 		perror("MSG_TRUNC doesn't work");
590 		exit(EXIT_FAILURE);
591 	}
592 
593 	if (!(msg.msg_flags & MSG_TRUNC)) {
594 		fprintf(stderr, "MSG_TRUNC expected\n");
595 		exit(EXIT_FAILURE);
596 	}
597 
598 	close(fd);
599 }
600 
601 static time_t current_nsec(void)
602 {
603 	struct timespec ts;
604 
605 	if (clock_gettime(CLOCK_REALTIME, &ts)) {
606 		perror("clock_gettime(3) failed");
607 		exit(EXIT_FAILURE);
608 	}
609 
610 	return (ts.tv_sec * 1000000000ULL) + ts.tv_nsec;
611 }
612 
613 #define RCVTIMEO_TIMEOUT_SEC 1
614 #define READ_OVERHEAD_NSEC 250000000 /* 0.25 sec */
615 
616 static void test_seqpacket_timeout_client(const struct test_opts *opts)
617 {
618 	int fd;
619 	struct timeval tv;
620 	char dummy;
621 	time_t read_enter_ns;
622 	time_t read_overhead_ns;
623 
624 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
625 	if (fd < 0) {
626 		perror("connect");
627 		exit(EXIT_FAILURE);
628 	}
629 
630 	tv.tv_sec = RCVTIMEO_TIMEOUT_SEC;
631 	tv.tv_usec = 0;
632 
633 	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, (void *)&tv, sizeof(tv)) == -1) {
634 		perror("setsockopt(SO_RCVTIMEO)");
635 		exit(EXIT_FAILURE);
636 	}
637 
638 	read_enter_ns = current_nsec();
639 
640 	if (read(fd, &dummy, sizeof(dummy)) != -1) {
641 		fprintf(stderr,
642 			"expected 'dummy' read(2) failure\n");
643 		exit(EXIT_FAILURE);
644 	}
645 
646 	if (errno != EAGAIN) {
647 		perror("EAGAIN expected");
648 		exit(EXIT_FAILURE);
649 	}
650 
651 	read_overhead_ns = current_nsec() - read_enter_ns -
652 			1000000000ULL * RCVTIMEO_TIMEOUT_SEC;
653 
654 	if (read_overhead_ns > READ_OVERHEAD_NSEC) {
655 		fprintf(stderr,
656 			"too much time in read(2), %lu > %i ns\n",
657 			read_overhead_ns, READ_OVERHEAD_NSEC);
658 		exit(EXIT_FAILURE);
659 	}
660 
661 	control_writeln("WAITDONE");
662 	close(fd);
663 }
664 
665 static void test_seqpacket_timeout_server(const struct test_opts *opts)
666 {
667 	int fd;
668 
669 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
670 	if (fd < 0) {
671 		perror("accept");
672 		exit(EXIT_FAILURE);
673 	}
674 
675 	control_expectln("WAITDONE");
676 	close(fd);
677 }
678 
679 static void test_seqpacket_bigmsg_client(const struct test_opts *opts)
680 {
681 	unsigned long sock_buf_size;
682 	ssize_t send_size;
683 	socklen_t len;
684 	void *data;
685 	int fd;
686 
687 	len = sizeof(sock_buf_size);
688 
689 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
690 	if (fd < 0) {
691 		perror("connect");
692 		exit(EXIT_FAILURE);
693 	}
694 
695 	if (getsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
696 		       &sock_buf_size, &len)) {
697 		perror("getsockopt");
698 		exit(EXIT_FAILURE);
699 	}
700 
701 	sock_buf_size++;
702 
703 	data = malloc(sock_buf_size);
704 	if (!data) {
705 		perror("malloc");
706 		exit(EXIT_FAILURE);
707 	}
708 
709 	send_size = send(fd, data, sock_buf_size, 0);
710 	if (send_size != -1) {
711 		fprintf(stderr, "expected 'send(2)' failure, got %zi\n",
712 			send_size);
713 		exit(EXIT_FAILURE);
714 	}
715 
716 	if (errno != EMSGSIZE) {
717 		fprintf(stderr, "expected EMSGSIZE in 'errno', got %i\n",
718 			errno);
719 		exit(EXIT_FAILURE);
720 	}
721 
722 	control_writeln("CLISENT");
723 
724 	free(data);
725 	close(fd);
726 }
727 
728 static void test_seqpacket_bigmsg_server(const struct test_opts *opts)
729 {
730 	int fd;
731 
732 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
733 	if (fd < 0) {
734 		perror("accept");
735 		exit(EXIT_FAILURE);
736 	}
737 
738 	control_expectln("CLISENT");
739 
740 	close(fd);
741 }
742 
743 #define BUF_PATTERN_1 'a'
744 #define BUF_PATTERN_2 'b'
745 
746 static void test_seqpacket_invalid_rec_buffer_client(const struct test_opts *opts)
747 {
748 	int fd;
749 	unsigned char *buf1;
750 	unsigned char *buf2;
751 	int buf_size = getpagesize() * 3;
752 
753 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
754 	if (fd < 0) {
755 		perror("connect");
756 		exit(EXIT_FAILURE);
757 	}
758 
759 	buf1 = malloc(buf_size);
760 	if (!buf1) {
761 		perror("'malloc()' for 'buf1'");
762 		exit(EXIT_FAILURE);
763 	}
764 
765 	buf2 = malloc(buf_size);
766 	if (!buf2) {
767 		perror("'malloc()' for 'buf2'");
768 		exit(EXIT_FAILURE);
769 	}
770 
771 	memset(buf1, BUF_PATTERN_1, buf_size);
772 	memset(buf2, BUF_PATTERN_2, buf_size);
773 
774 	if (send(fd, buf1, buf_size, 0) != buf_size) {
775 		perror("send failed");
776 		exit(EXIT_FAILURE);
777 	}
778 
779 	if (send(fd, buf2, buf_size, 0) != buf_size) {
780 		perror("send failed");
781 		exit(EXIT_FAILURE);
782 	}
783 
784 	close(fd);
785 }
786 
787 static void test_seqpacket_invalid_rec_buffer_server(const struct test_opts *opts)
788 {
789 	int fd;
790 	unsigned char *broken_buf;
791 	unsigned char *valid_buf;
792 	int page_size = getpagesize();
793 	int buf_size = page_size * 3;
794 	ssize_t res;
795 	int prot = PROT_READ | PROT_WRITE;
796 	int flags = MAP_PRIVATE | MAP_ANONYMOUS;
797 	int i;
798 
799 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
800 	if (fd < 0) {
801 		perror("accept");
802 		exit(EXIT_FAILURE);
803 	}
804 
805 	/* Setup first buffer. */
806 	broken_buf = mmap(NULL, buf_size, prot, flags, -1, 0);
807 	if (broken_buf == MAP_FAILED) {
808 		perror("mmap for 'broken_buf'");
809 		exit(EXIT_FAILURE);
810 	}
811 
812 	/* Unmap "hole" in buffer. */
813 	if (munmap(broken_buf + page_size, page_size)) {
814 		perror("'broken_buf' setup");
815 		exit(EXIT_FAILURE);
816 	}
817 
818 	valid_buf = mmap(NULL, buf_size, prot, flags, -1, 0);
819 	if (valid_buf == MAP_FAILED) {
820 		perror("mmap for 'valid_buf'");
821 		exit(EXIT_FAILURE);
822 	}
823 
824 	/* Try to fill buffer with unmapped middle. */
825 	res = read(fd, broken_buf, buf_size);
826 	if (res != -1) {
827 		fprintf(stderr,
828 			"expected 'broken_buf' read(2) failure, got %zi\n",
829 			res);
830 		exit(EXIT_FAILURE);
831 	}
832 
833 	if (errno != EFAULT) {
834 		perror("unexpected errno of 'broken_buf'");
835 		exit(EXIT_FAILURE);
836 	}
837 
838 	/* Try to fill valid buffer. */
839 	res = read(fd, valid_buf, buf_size);
840 	if (res < 0) {
841 		perror("unexpected 'valid_buf' read(2) failure");
842 		exit(EXIT_FAILURE);
843 	}
844 
845 	if (res != buf_size) {
846 		fprintf(stderr,
847 			"invalid 'valid_buf' read(2), expected %i, got %zi\n",
848 			buf_size, res);
849 		exit(EXIT_FAILURE);
850 	}
851 
852 	for (i = 0; i < buf_size; i++) {
853 		if (valid_buf[i] != BUF_PATTERN_2) {
854 			fprintf(stderr,
855 				"invalid pattern for 'valid_buf' at %i, expected %hhX, got %hhX\n",
856 				i, BUF_PATTERN_2, valid_buf[i]);
857 			exit(EXIT_FAILURE);
858 		}
859 	}
860 
861 	/* Unmap buffers. */
862 	munmap(broken_buf, page_size);
863 	munmap(broken_buf + page_size * 2, page_size);
864 	munmap(valid_buf, buf_size);
865 	close(fd);
866 }
867 
868 #define RCVLOWAT_BUF_SIZE 128
869 
870 static void test_stream_poll_rcvlowat_server(const struct test_opts *opts)
871 {
872 	int fd;
873 	int i;
874 
875 	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
876 	if (fd < 0) {
877 		perror("accept");
878 		exit(EXIT_FAILURE);
879 	}
880 
881 	/* Send 1 byte. */
882 	send_byte(fd, 1, 0);
883 
884 	control_writeln("SRVSENT");
885 
886 	/* Wait until client is ready to receive rest of data. */
887 	control_expectln("CLNSENT");
888 
889 	for (i = 0; i < RCVLOWAT_BUF_SIZE - 1; i++)
890 		send_byte(fd, 1, 0);
891 
892 	/* Keep socket in active state. */
893 	control_expectln("POLLDONE");
894 
895 	close(fd);
896 }
897 
898 static void test_stream_poll_rcvlowat_client(const struct test_opts *opts)
899 {
900 	unsigned long lowat_val = RCVLOWAT_BUF_SIZE;
901 	char buf[RCVLOWAT_BUF_SIZE];
902 	struct pollfd fds;
903 	ssize_t read_res;
904 	short poll_flags;
905 	int fd;
906 
907 	fd = vsock_stream_connect(opts->peer_cid, 1234);
908 	if (fd < 0) {
909 		perror("connect");
910 		exit(EXIT_FAILURE);
911 	}
912 
913 	if (setsockopt(fd, SOL_SOCKET, SO_RCVLOWAT,
914 		       &lowat_val, sizeof(lowat_val))) {
915 		perror("setsockopt(SO_RCVLOWAT)");
916 		exit(EXIT_FAILURE);
917 	}
918 
919 	control_expectln("SRVSENT");
920 
921 	/* At this point, server sent 1 byte. */
922 	fds.fd = fd;
923 	poll_flags = POLLIN | POLLRDNORM;
924 	fds.events = poll_flags;
925 
926 	/* Try to wait for 1 sec. */
927 	if (poll(&fds, 1, 1000) < 0) {
928 		perror("poll");
929 		exit(EXIT_FAILURE);
930 	}
931 
932 	/* poll() must return nothing. */
933 	if (fds.revents) {
934 		fprintf(stderr, "Unexpected poll result %hx\n",
935 			fds.revents);
936 		exit(EXIT_FAILURE);
937 	}
938 
939 	/* Tell server to send rest of data. */
940 	control_writeln("CLNSENT");
941 
942 	/* Poll for data. */
943 	if (poll(&fds, 1, 10000) < 0) {
944 		perror("poll");
945 		exit(EXIT_FAILURE);
946 	}
947 
948 	/* Only these two bits are expected. */
949 	if (fds.revents != poll_flags) {
950 		fprintf(stderr, "Unexpected poll result %hx\n",
951 			fds.revents);
952 		exit(EXIT_FAILURE);
953 	}
954 
955 	/* Use MSG_DONTWAIT, if call is going to wait, EAGAIN
956 	 * will be returned.
957 	 */
958 	read_res = recv(fd, buf, sizeof(buf), MSG_DONTWAIT);
959 	if (read_res != RCVLOWAT_BUF_SIZE) {
960 		fprintf(stderr, "Unexpected recv result %zi\n",
961 			read_res);
962 		exit(EXIT_FAILURE);
963 	}
964 
965 	control_writeln("POLLDONE");
966 
967 	close(fd);
968 }
969 
970 #define INV_BUF_TEST_DATA_LEN 512
971 
972 static void test_inv_buf_client(const struct test_opts *opts, bool stream)
973 {
974 	unsigned char data[INV_BUF_TEST_DATA_LEN] = {0};
975 	ssize_t ret;
976 	int fd;
977 
978 	if (stream)
979 		fd = vsock_stream_connect(opts->peer_cid, 1234);
980 	else
981 		fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
982 
983 	if (fd < 0) {
984 		perror("connect");
985 		exit(EXIT_FAILURE);
986 	}
987 
988 	control_expectln("SENDDONE");
989 
990 	/* Use invalid buffer here. */
991 	ret = recv(fd, NULL, sizeof(data), 0);
992 	if (ret != -1) {
993 		fprintf(stderr, "expected recv(2) failure, got %zi\n", ret);
994 		exit(EXIT_FAILURE);
995 	}
996 
997 	if (errno != EFAULT) {
998 		fprintf(stderr, "unexpected recv(2) errno %d\n", errno);
999 		exit(EXIT_FAILURE);
1000 	}
1001 
1002 	ret = recv(fd, data, sizeof(data), MSG_DONTWAIT);
1003 
1004 	if (stream) {
1005 		/* For SOCK_STREAM we must continue reading. */
1006 		if (ret != sizeof(data)) {
1007 			fprintf(stderr, "expected recv(2) success, got %zi\n", ret);
1008 			exit(EXIT_FAILURE);
1009 		}
1010 		/* Don't check errno in case of success. */
1011 	} else {
1012 		/* For SOCK_SEQPACKET socket's queue must be empty. */
1013 		if (ret != -1) {
1014 			fprintf(stderr, "expected recv(2) failure, got %zi\n", ret);
1015 			exit(EXIT_FAILURE);
1016 		}
1017 
1018 		if (errno != EAGAIN) {
1019 			fprintf(stderr, "unexpected recv(2) errno %d\n", errno);
1020 			exit(EXIT_FAILURE);
1021 		}
1022 	}
1023 
1024 	control_writeln("DONE");
1025 
1026 	close(fd);
1027 }
1028 
1029 static void test_inv_buf_server(const struct test_opts *opts, bool stream)
1030 {
1031 	unsigned char data[INV_BUF_TEST_DATA_LEN] = {0};
1032 	ssize_t res;
1033 	int fd;
1034 
1035 	if (stream)
1036 		fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
1037 	else
1038 		fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
1039 
1040 	if (fd < 0) {
1041 		perror("accept");
1042 		exit(EXIT_FAILURE);
1043 	}
1044 
1045 	res = send(fd, data, sizeof(data), 0);
1046 	if (res != sizeof(data)) {
1047 		fprintf(stderr, "unexpected send(2) result %zi\n", res);
1048 		exit(EXIT_FAILURE);
1049 	}
1050 
1051 	control_writeln("SENDDONE");
1052 
1053 	control_expectln("DONE");
1054 
1055 	close(fd);
1056 }
1057 
/* SOCK_STREAM variant of the invalid-buffer test, client side. */
static void test_stream_inv_buf_client(const struct test_opts *opts)
{
	const bool stream = true;

	test_inv_buf_client(opts, stream);
}

/* SOCK_STREAM variant of the invalid-buffer test, server side. */
static void test_stream_inv_buf_server(const struct test_opts *opts)
{
	const bool stream = true;

	test_inv_buf_server(opts, stream);
}

/* SOCK_SEQPACKET variant of the invalid-buffer test, client side. */
static void test_seqpacket_inv_buf_client(const struct test_opts *opts)
{
	const bool stream = false;

	test_inv_buf_client(opts, stream);
}

/* SOCK_SEQPACKET variant of the invalid-buffer test, server side. */
static void test_seqpacket_inv_buf_server(const struct test_opts *opts)
{
	const bool stream = false;

	test_inv_buf_server(opts, stream);
}
1077 
1078 #define HELLO_STR "HELLO"
1079 #define WORLD_STR "WORLD"
1080 
1081 static void test_stream_virtio_skb_merge_client(const struct test_opts *opts)
1082 {
1083 	ssize_t res;
1084 	int fd;
1085 
1086 	fd = vsock_stream_connect(opts->peer_cid, 1234);
1087 	if (fd < 0) {
1088 		perror("connect");
1089 		exit(EXIT_FAILURE);
1090 	}
1091 
1092 	/* Send first skbuff. */
1093 	res = send(fd, HELLO_STR, strlen(HELLO_STR), 0);
1094 	if (res != strlen(HELLO_STR)) {
1095 		fprintf(stderr, "unexpected send(2) result %zi\n", res);
1096 		exit(EXIT_FAILURE);
1097 	}
1098 
1099 	control_writeln("SEND0");
1100 	/* Peer reads part of first skbuff. */
1101 	control_expectln("REPLY0");
1102 
1103 	/* Send second skbuff, it will be appended to the first. */
1104 	res = send(fd, WORLD_STR, strlen(WORLD_STR), 0);
1105 	if (res != strlen(WORLD_STR)) {
1106 		fprintf(stderr, "unexpected send(2) result %zi\n", res);
1107 		exit(EXIT_FAILURE);
1108 	}
1109 
1110 	control_writeln("SEND1");
1111 	/* Peer reads merged skbuff packet. */
1112 	control_expectln("REPLY1");
1113 
1114 	close(fd);
1115 }
1116 
1117 static void test_stream_virtio_skb_merge_server(const struct test_opts *opts)
1118 {
1119 	unsigned char buf[64];
1120 	ssize_t res;
1121 	int fd;
1122 
1123 	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
1124 	if (fd < 0) {
1125 		perror("accept");
1126 		exit(EXIT_FAILURE);
1127 	}
1128 
1129 	control_expectln("SEND0");
1130 
1131 	/* Read skbuff partially. */
1132 	res = recv(fd, buf, 2, 0);
1133 	if (res != 2) {
1134 		fprintf(stderr, "expected recv(2) returns 2 bytes, got %zi\n", res);
1135 		exit(EXIT_FAILURE);
1136 	}
1137 
1138 	control_writeln("REPLY0");
1139 	control_expectln("SEND1");
1140 
1141 	res = recv(fd, buf + 2, sizeof(buf) - 2, 0);
1142 	if (res != 8) {
1143 		fprintf(stderr, "expected recv(2) returns 8 bytes, got %zi\n", res);
1144 		exit(EXIT_FAILURE);
1145 	}
1146 
1147 	res = recv(fd, buf, sizeof(buf) - 8 - 2, MSG_DONTWAIT);
1148 	if (res != -1) {
1149 		fprintf(stderr, "expected recv(2) failure, got %zi\n", res);
1150 		exit(EXIT_FAILURE);
1151 	}
1152 
1153 	if (memcmp(buf, HELLO_STR WORLD_STR, strlen(HELLO_STR WORLD_STR))) {
1154 		fprintf(stderr, "pattern mismatch\n");
1155 		exit(EXIT_FAILURE);
1156 	}
1157 
1158 	control_writeln("REPLY1");
1159 
1160 	close(fd);
1161 }
1162 
/* SOCK_SEQPACKET flavour of the MSG_PEEK test, client side. */
static void test_seqpacket_msg_peek_client(const struct test_opts *opts)
{
	/* Plain call: returning a void expression from a void function is
	 * not valid ISO C.
	 */
	test_msg_peek_client(opts, true);
}

/* SOCK_SEQPACKET flavour of the MSG_PEEK test, server side. */
static void test_seqpacket_msg_peek_server(const struct test_opts *opts)
{
	test_msg_peek_server(opts, true);
}
1172 
1173 static struct test_case test_cases[] = {
1174 	{
1175 		.name = "SOCK_STREAM connection reset",
1176 		.run_client = test_stream_connection_reset,
1177 	},
1178 	{
1179 		.name = "SOCK_STREAM bind only",
1180 		.run_client = test_stream_bind_only_client,
1181 		.run_server = test_stream_bind_only_server,
1182 	},
1183 	{
1184 		.name = "SOCK_STREAM client close",
1185 		.run_client = test_stream_client_close_client,
1186 		.run_server = test_stream_client_close_server,
1187 	},
1188 	{
1189 		.name = "SOCK_STREAM server close",
1190 		.run_client = test_stream_server_close_client,
1191 		.run_server = test_stream_server_close_server,
1192 	},
1193 	{
1194 		.name = "SOCK_STREAM multiple connections",
1195 		.run_client = test_stream_multiconn_client,
1196 		.run_server = test_stream_multiconn_server,
1197 	},
1198 	{
1199 		.name = "SOCK_STREAM MSG_PEEK",
1200 		.run_client = test_stream_msg_peek_client,
1201 		.run_server = test_stream_msg_peek_server,
1202 	},
1203 	{
1204 		.name = "SOCK_SEQPACKET msg bounds",
1205 		.run_client = test_seqpacket_msg_bounds_client,
1206 		.run_server = test_seqpacket_msg_bounds_server,
1207 	},
1208 	{
1209 		.name = "SOCK_SEQPACKET MSG_TRUNC flag",
1210 		.run_client = test_seqpacket_msg_trunc_client,
1211 		.run_server = test_seqpacket_msg_trunc_server,
1212 	},
1213 	{
1214 		.name = "SOCK_SEQPACKET timeout",
1215 		.run_client = test_seqpacket_timeout_client,
1216 		.run_server = test_seqpacket_timeout_server,
1217 	},
1218 	{
1219 		.name = "SOCK_SEQPACKET invalid receive buffer",
1220 		.run_client = test_seqpacket_invalid_rec_buffer_client,
1221 		.run_server = test_seqpacket_invalid_rec_buffer_server,
1222 	},
1223 	{
1224 		.name = "SOCK_STREAM poll() + SO_RCVLOWAT",
1225 		.run_client = test_stream_poll_rcvlowat_client,
1226 		.run_server = test_stream_poll_rcvlowat_server,
1227 	},
1228 	{
1229 		.name = "SOCK_SEQPACKET big message",
1230 		.run_client = test_seqpacket_bigmsg_client,
1231 		.run_server = test_seqpacket_bigmsg_server,
1232 	},
1233 	{
1234 		.name = "SOCK_STREAM test invalid buffer",
1235 		.run_client = test_stream_inv_buf_client,
1236 		.run_server = test_stream_inv_buf_server,
1237 	},
1238 	{
1239 		.name = "SOCK_SEQPACKET test invalid buffer",
1240 		.run_client = test_seqpacket_inv_buf_client,
1241 		.run_server = test_seqpacket_inv_buf_server,
1242 	},
1243 	{
1244 		.name = "SOCK_STREAM virtio skb merge",
1245 		.run_client = test_stream_virtio_skb_merge_client,
1246 		.run_server = test_stream_virtio_skb_merge_server,
1247 	},
1248 	{
1249 		.name = "SOCK_SEQPACKET MSG_PEEK",
1250 		.run_client = test_seqpacket_msg_peek_client,
1251 		.run_server = test_seqpacket_msg_peek_server,
1252 	},
1253 	{},
1254 };
1255 
/* No short options are accepted; every option is long-only. */
static const char optstring[] = "";

/* getopt_long() table. Each .val is the value main() switches on. */
static const struct option longopts[] = {
	{ .name = "control-host", .has_arg = required_argument, .val = 'H' },
	{ .name = "control-port", .has_arg = required_argument, .val = 'P' },
	{ .name = "mode",         .has_arg = required_argument, .val = 'm' },
	{ .name = "peer-cid",     .has_arg = required_argument, .val = 'p' },
	{ .name = "list",         .has_arg = no_argument,       .val = 'l' },
	{ .name = "skip",         .has_arg = required_argument, .val = 's' },
	{ .name = "help",         .has_arg = no_argument,       .val = '?' },
	{},
};
1295 
/* Print the command-line help to stderr and exit with EXIT_FAILURE. */
static void usage(void)
{
	/* Contains no '%', so fputs() prints it exactly as fprintf() would. */
	static const char help_text[] =
		"Usage: vsock_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n"
		"\n"
		"  Server: vsock_test --control-port=1234 --mode=server --peer-cid=3\n"
		"  Client: vsock_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
		"\n"
		"Run vsock.ko tests.  Must be launched in both guest\n"
		"and host.  One side must use --mode=client and\n"
		"the other side must use --mode=server.\n"
		"\n"
		"A TCP control socket connection is used to coordinate tests\n"
		"between the client and the server.  The server requires a\n"
		"listen address and the client requires an address to\n"
		"connect to.\n"
		"\n"
		"The CID of the other side must be given with --peer-cid=<cid>.\n"
		"\n"
		"Options:\n"
		"  --help                 This help message\n"
		"  --control-host <host>  Server IP address to connect to\n"
		"  --control-port <port>  Server port to listen on/connect to\n"
		"  --mode client|server   Server or client mode\n"
		"  --peer-cid <cid>       CID of the other side\n"
		"  --list                 List of tests that will be executed\n"
		"  --skip <test_id>       Test ID to skip;\n"
		"                         use multiple --skip options to skip more tests\n";

	fputs(help_text, stderr);
	exit(EXIT_FAILURE);
}
1326 
1327 int main(int argc, char **argv)
1328 {
1329 	const char *control_host = NULL;
1330 	const char *control_port = NULL;
1331 	struct test_opts opts = {
1332 		.mode = TEST_MODE_UNSET,
1333 		.peer_cid = VMADDR_CID_ANY,
1334 	};
1335 
1336 	srand(time(NULL));
1337 	init_signals();
1338 
1339 	for (;;) {
1340 		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
1341 
1342 		if (opt == -1)
1343 			break;
1344 
1345 		switch (opt) {
1346 		case 'H':
1347 			control_host = optarg;
1348 			break;
1349 		case 'm':
1350 			if (strcmp(optarg, "client") == 0)
1351 				opts.mode = TEST_MODE_CLIENT;
1352 			else if (strcmp(optarg, "server") == 0)
1353 				opts.mode = TEST_MODE_SERVER;
1354 			else {
1355 				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
1356 				return EXIT_FAILURE;
1357 			}
1358 			break;
1359 		case 'p':
1360 			opts.peer_cid = parse_cid(optarg);
1361 			break;
1362 		case 'P':
1363 			control_port = optarg;
1364 			break;
1365 		case 'l':
1366 			list_tests(test_cases);
1367 			break;
1368 		case 's':
1369 			skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
1370 				  optarg);
1371 			break;
1372 		case '?':
1373 		default:
1374 			usage();
1375 		}
1376 	}
1377 
1378 	if (!control_port)
1379 		usage();
1380 	if (opts.mode == TEST_MODE_UNSET)
1381 		usage();
1382 	if (opts.peer_cid == VMADDR_CID_ANY)
1383 		usage();
1384 
1385 	if (!control_host) {
1386 		if (opts.mode != TEST_MODE_SERVER)
1387 			usage();
1388 		control_host = "0.0.0.0";
1389 	}
1390 
1391 	control_init(control_host, control_port,
1392 		     opts.mode == TEST_MODE_SERVER);
1393 
1394 	run_tests(test_cases, &opts);
1395 
1396 	control_cleanup();
1397 	return EXIT_SUCCESS;
1398 }
1399