xref: /openbmc/linux/tools/testing/vsock/vsock_test.c (revision f5fb5ac7cee29cea9156e734fd652a66417d32fc)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * vsock_test - vsock.ko test suite
4  *
5  * Copyright (C) 2017 Red Hat, Inc.
6  *
7  * Author: Stefan Hajnoczi <stefanha@redhat.com>
8  */
9 
10 #include <getopt.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <string.h>
14 #include <errno.h>
15 #include <unistd.h>
16 #include <linux/kernel.h>
17 #include <sys/types.h>
18 #include <sys/socket.h>
19 #include <time.h>
20 #include <sys/mman.h>
21 #include <poll.h>
22 
23 #include "timeout.h"
24 #include "control.h"
25 #include "util.h"
26 
27 static void test_stream_connection_reset(const struct test_opts *opts)
28 {
29 	union {
30 		struct sockaddr sa;
31 		struct sockaddr_vm svm;
32 	} addr = {
33 		.svm = {
34 			.svm_family = AF_VSOCK,
35 			.svm_port = 1234,
36 			.svm_cid = opts->peer_cid,
37 		},
38 	};
39 	int ret;
40 	int fd;
41 
42 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
43 
44 	timeout_begin(TIMEOUT);
45 	do {
46 		ret = connect(fd, &addr.sa, sizeof(addr.svm));
47 		timeout_check("connect");
48 	} while (ret < 0 && errno == EINTR);
49 	timeout_end();
50 
51 	if (ret != -1) {
52 		fprintf(stderr, "expected connect(2) failure, got %d\n", ret);
53 		exit(EXIT_FAILURE);
54 	}
55 	if (errno != ECONNRESET) {
56 		fprintf(stderr, "unexpected connect(2) errno %d\n", errno);
57 		exit(EXIT_FAILURE);
58 	}
59 
60 	close(fd);
61 }
62 
63 static void test_stream_bind_only_client(const struct test_opts *opts)
64 {
65 	union {
66 		struct sockaddr sa;
67 		struct sockaddr_vm svm;
68 	} addr = {
69 		.svm = {
70 			.svm_family = AF_VSOCK,
71 			.svm_port = 1234,
72 			.svm_cid = opts->peer_cid,
73 		},
74 	};
75 	int ret;
76 	int fd;
77 
78 	/* Wait for the server to be ready */
79 	control_expectln("BIND");
80 
81 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
82 
83 	timeout_begin(TIMEOUT);
84 	do {
85 		ret = connect(fd, &addr.sa, sizeof(addr.svm));
86 		timeout_check("connect");
87 	} while (ret < 0 && errno == EINTR);
88 	timeout_end();
89 
90 	if (ret != -1) {
91 		fprintf(stderr, "expected connect(2) failure, got %d\n", ret);
92 		exit(EXIT_FAILURE);
93 	}
94 	if (errno != ECONNRESET) {
95 		fprintf(stderr, "unexpected connect(2) errno %d\n", errno);
96 		exit(EXIT_FAILURE);
97 	}
98 
99 	/* Notify the server that the client has finished */
100 	control_writeln("DONE");
101 
102 	close(fd);
103 }
104 
105 static void test_stream_bind_only_server(const struct test_opts *opts)
106 {
107 	union {
108 		struct sockaddr sa;
109 		struct sockaddr_vm svm;
110 	} addr = {
111 		.svm = {
112 			.svm_family = AF_VSOCK,
113 			.svm_port = 1234,
114 			.svm_cid = VMADDR_CID_ANY,
115 		},
116 	};
117 	int fd;
118 
119 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
120 
121 	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
122 		perror("bind");
123 		exit(EXIT_FAILURE);
124 	}
125 
126 	/* Notify the client that the server is ready */
127 	control_writeln("BIND");
128 
129 	/* Wait for the client to finish */
130 	control_expectln("DONE");
131 
132 	close(fd);
133 }
134 
135 static void test_stream_client_close_client(const struct test_opts *opts)
136 {
137 	int fd;
138 
139 	fd = vsock_stream_connect(opts->peer_cid, 1234);
140 	if (fd < 0) {
141 		perror("connect");
142 		exit(EXIT_FAILURE);
143 	}
144 
145 	send_byte(fd, 1, 0);
146 	close(fd);
147 }
148 
149 static void test_stream_client_close_server(const struct test_opts *opts)
150 {
151 	int fd;
152 
153 	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
154 	if (fd < 0) {
155 		perror("accept");
156 		exit(EXIT_FAILURE);
157 	}
158 
159 	/* Wait for the remote to close the connection, before check
160 	 * -EPIPE error on send.
161 	 */
162 	vsock_wait_remote_close(fd);
163 
164 	send_byte(fd, -EPIPE, 0);
165 	recv_byte(fd, 1, 0);
166 	recv_byte(fd, 0, 0);
167 	close(fd);
168 }
169 
170 static void test_stream_server_close_client(const struct test_opts *opts)
171 {
172 	int fd;
173 
174 	fd = vsock_stream_connect(opts->peer_cid, 1234);
175 	if (fd < 0) {
176 		perror("connect");
177 		exit(EXIT_FAILURE);
178 	}
179 
180 	/* Wait for the remote to close the connection, before check
181 	 * -EPIPE error on send.
182 	 */
183 	vsock_wait_remote_close(fd);
184 
185 	send_byte(fd, -EPIPE, 0);
186 	recv_byte(fd, 1, 0);
187 	recv_byte(fd, 0, 0);
188 	close(fd);
189 }
190 
191 static void test_stream_server_close_server(const struct test_opts *opts)
192 {
193 	int fd;
194 
195 	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
196 	if (fd < 0) {
197 		perror("accept");
198 		exit(EXIT_FAILURE);
199 	}
200 
201 	send_byte(fd, 1, 0);
202 	close(fd);
203 }
204 
205 /* With the standard socket sizes, VMCI is able to support about 100
206  * concurrent stream connections.
207  */
208 #define MULTICONN_NFDS 100
209 
210 static void test_stream_multiconn_client(const struct test_opts *opts)
211 {
212 	int fds[MULTICONN_NFDS];
213 	int i;
214 
215 	for (i = 0; i < MULTICONN_NFDS; i++) {
216 		fds[i] = vsock_stream_connect(opts->peer_cid, 1234);
217 		if (fds[i] < 0) {
218 			perror("connect");
219 			exit(EXIT_FAILURE);
220 		}
221 	}
222 
223 	for (i = 0; i < MULTICONN_NFDS; i++) {
224 		if (i % 2)
225 			recv_byte(fds[i], 1, 0);
226 		else
227 			send_byte(fds[i], 1, 0);
228 	}
229 
230 	for (i = 0; i < MULTICONN_NFDS; i++)
231 		close(fds[i]);
232 }
233 
234 static void test_stream_multiconn_server(const struct test_opts *opts)
235 {
236 	int fds[MULTICONN_NFDS];
237 	int i;
238 
239 	for (i = 0; i < MULTICONN_NFDS; i++) {
240 		fds[i] = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
241 		if (fds[i] < 0) {
242 			perror("accept");
243 			exit(EXIT_FAILURE);
244 		}
245 	}
246 
247 	for (i = 0; i < MULTICONN_NFDS; i++) {
248 		if (i % 2)
249 			send_byte(fds[i], 1, 0);
250 		else
251 			recv_byte(fds[i], 1, 0);
252 	}
253 
254 	for (i = 0; i < MULTICONN_NFDS; i++)
255 		close(fds[i]);
256 }
257 
258 #define MSG_PEEK_BUF_LEN 64
259 
260 static void test_msg_peek_client(const struct test_opts *opts,
261 				 bool seqpacket)
262 {
263 	unsigned char buf[MSG_PEEK_BUF_LEN];
264 	ssize_t send_size;
265 	int fd;
266 	int i;
267 
268 	if (seqpacket)
269 		fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
270 	else
271 		fd = vsock_stream_connect(opts->peer_cid, 1234);
272 
273 	if (fd < 0) {
274 		perror("connect");
275 		exit(EXIT_FAILURE);
276 	}
277 
278 	for (i = 0; i < sizeof(buf); i++)
279 		buf[i] = rand() & 0xFF;
280 
281 	control_expectln("SRVREADY");
282 
283 	send_size = send(fd, buf, sizeof(buf), 0);
284 
285 	if (send_size < 0) {
286 		perror("send");
287 		exit(EXIT_FAILURE);
288 	}
289 
290 	if (send_size != sizeof(buf)) {
291 		fprintf(stderr, "Invalid send size %zi\n", send_size);
292 		exit(EXIT_FAILURE);
293 	}
294 
295 	close(fd);
296 }
297 
298 static void test_msg_peek_server(const struct test_opts *opts,
299 				 bool seqpacket)
300 {
301 	unsigned char buf_half[MSG_PEEK_BUF_LEN / 2];
302 	unsigned char buf_normal[MSG_PEEK_BUF_LEN];
303 	unsigned char buf_peek[MSG_PEEK_BUF_LEN];
304 	ssize_t res;
305 	int fd;
306 
307 	if (seqpacket)
308 		fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
309 	else
310 		fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
311 
312 	if (fd < 0) {
313 		perror("accept");
314 		exit(EXIT_FAILURE);
315 	}
316 
317 	/* Peek from empty socket. */
318 	res = recv(fd, buf_peek, sizeof(buf_peek), MSG_PEEK | MSG_DONTWAIT);
319 	if (res != -1) {
320 		fprintf(stderr, "expected recv(2) failure, got %zi\n", res);
321 		exit(EXIT_FAILURE);
322 	}
323 
324 	if (errno != EAGAIN) {
325 		perror("EAGAIN expected");
326 		exit(EXIT_FAILURE);
327 	}
328 
329 	control_writeln("SRVREADY");
330 
331 	/* Peek part of data. */
332 	res = recv(fd, buf_half, sizeof(buf_half), MSG_PEEK);
333 	if (res != sizeof(buf_half)) {
334 		fprintf(stderr, "recv(2) + MSG_PEEK, expected %zu, got %zi\n",
335 			sizeof(buf_half), res);
336 		exit(EXIT_FAILURE);
337 	}
338 
339 	/* Peek whole data. */
340 	res = recv(fd, buf_peek, sizeof(buf_peek), MSG_PEEK);
341 	if (res != sizeof(buf_peek)) {
342 		fprintf(stderr, "recv(2) + MSG_PEEK, expected %zu, got %zi\n",
343 			sizeof(buf_peek), res);
344 		exit(EXIT_FAILURE);
345 	}
346 
347 	/* Compare partial and full peek. */
348 	if (memcmp(buf_half, buf_peek, sizeof(buf_half))) {
349 		fprintf(stderr, "Partial peek data mismatch\n");
350 		exit(EXIT_FAILURE);
351 	}
352 
353 	if (seqpacket) {
354 		/* This type of socket supports MSG_TRUNC flag,
355 		 * so check it with MSG_PEEK. We must get length
356 		 * of the message.
357 		 */
358 		res = recv(fd, buf_half, sizeof(buf_half), MSG_PEEK |
359 			   MSG_TRUNC);
360 		if (res != sizeof(buf_peek)) {
361 			fprintf(stderr,
362 				"recv(2) + MSG_PEEK | MSG_TRUNC, exp %zu, got %zi\n",
363 				sizeof(buf_half), res);
364 			exit(EXIT_FAILURE);
365 		}
366 	}
367 
368 	res = recv(fd, buf_normal, sizeof(buf_normal), 0);
369 	if (res != sizeof(buf_normal)) {
370 		fprintf(stderr, "recv(2), expected %zu, got %zi\n",
371 			sizeof(buf_normal), res);
372 		exit(EXIT_FAILURE);
373 	}
374 
375 	/* Compare full peek and normal read. */
376 	if (memcmp(buf_peek, buf_normal, sizeof(buf_peek))) {
377 		fprintf(stderr, "Full peek data mismatch\n");
378 		exit(EXIT_FAILURE);
379 	}
380 
381 	close(fd);
382 }
383 
/* SOCK_STREAM entry point for the MSG_PEEK client (seqpacket == false). */
static void test_stream_msg_peek_client(const struct test_opts *opts)
{
	/* No "return <expr>;" here: ISO C forbids it in a void function. */
	test_msg_peek_client(opts, false);
}
388 
/* SOCK_STREAM entry point for the MSG_PEEK server (seqpacket == false). */
static void test_stream_msg_peek_server(const struct test_opts *opts)
{
	/* No "return <expr>;" here: ISO C forbids it in a void function. */
	test_msg_peek_server(opts, false);
}
393 
394 #define SOCK_BUF_SIZE (2 * 1024 * 1024)
395 #define MAX_MSG_PAGES 4
396 
397 static void test_seqpacket_msg_bounds_client(const struct test_opts *opts)
398 {
399 	unsigned long curr_hash;
400 	size_t max_msg_size;
401 	int page_size;
402 	int msg_count;
403 	int fd;
404 
405 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
406 	if (fd < 0) {
407 		perror("connect");
408 		exit(EXIT_FAILURE);
409 	}
410 
411 	/* Wait, until receiver sets buffer size. */
412 	control_expectln("SRVREADY");
413 
414 	curr_hash = 0;
415 	page_size = getpagesize();
416 	max_msg_size = MAX_MSG_PAGES * page_size;
417 	msg_count = SOCK_BUF_SIZE / max_msg_size;
418 
419 	for (int i = 0; i < msg_count; i++) {
420 		ssize_t send_size;
421 		size_t buf_size;
422 		int flags;
423 		void *buf;
424 
425 		/* Use "small" buffers and "big" buffers. */
426 		if (i & 1)
427 			buf_size = page_size +
428 					(rand() % (max_msg_size - page_size));
429 		else
430 			buf_size = 1 + (rand() % page_size);
431 
432 		buf = malloc(buf_size);
433 
434 		if (!buf) {
435 			perror("malloc");
436 			exit(EXIT_FAILURE);
437 		}
438 
439 		memset(buf, rand() & 0xff, buf_size);
440 		/* Set at least one MSG_EOR + some random. */
441 		if (i == (msg_count / 2) || (rand() & 1)) {
442 			flags = MSG_EOR;
443 			curr_hash++;
444 		} else {
445 			flags = 0;
446 		}
447 
448 		send_size = send(fd, buf, buf_size, flags);
449 
450 		if (send_size < 0) {
451 			perror("send");
452 			exit(EXIT_FAILURE);
453 		}
454 
455 		if (send_size != buf_size) {
456 			fprintf(stderr, "Invalid send size\n");
457 			exit(EXIT_FAILURE);
458 		}
459 
460 		/*
461 		 * Hash sum is computed at both client and server in
462 		 * the same way:
463 		 * H += hash('message data')
464 		 * Such hash "controls" both data integrity and message
465 		 * bounds. After data exchange, both sums are compared
466 		 * using control socket, and if message bounds wasn't
467 		 * broken - two values must be equal.
468 		 */
469 		curr_hash += hash_djb2(buf, buf_size);
470 		free(buf);
471 	}
472 
473 	control_writeln("SENDDONE");
474 	control_writeulong(curr_hash);
475 	close(fd);
476 }
477 
478 static void test_seqpacket_msg_bounds_server(const struct test_opts *opts)
479 {
480 	unsigned long sock_buf_size;
481 	unsigned long remote_hash;
482 	unsigned long curr_hash;
483 	int fd;
484 	struct msghdr msg = {0};
485 	struct iovec iov = {0};
486 
487 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
488 	if (fd < 0) {
489 		perror("accept");
490 		exit(EXIT_FAILURE);
491 	}
492 
493 	sock_buf_size = SOCK_BUF_SIZE;
494 
495 	if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_MAX_SIZE,
496 		       &sock_buf_size, sizeof(sock_buf_size))) {
497 		perror("setsockopt(SO_VM_SOCKETS_BUFFER_MAX_SIZE)");
498 		exit(EXIT_FAILURE);
499 	}
500 
501 	if (setsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
502 		       &sock_buf_size, sizeof(sock_buf_size))) {
503 		perror("setsockopt(SO_VM_SOCKETS_BUFFER_SIZE)");
504 		exit(EXIT_FAILURE);
505 	}
506 
507 	/* Ready to receive data. */
508 	control_writeln("SRVREADY");
509 	/* Wait, until peer sends whole data. */
510 	control_expectln("SENDDONE");
511 	iov.iov_len = MAX_MSG_PAGES * getpagesize();
512 	iov.iov_base = malloc(iov.iov_len);
513 	if (!iov.iov_base) {
514 		perror("malloc");
515 		exit(EXIT_FAILURE);
516 	}
517 
518 	msg.msg_iov = &iov;
519 	msg.msg_iovlen = 1;
520 
521 	curr_hash = 0;
522 
523 	while (1) {
524 		ssize_t recv_size;
525 
526 		recv_size = recvmsg(fd, &msg, 0);
527 
528 		if (!recv_size)
529 			break;
530 
531 		if (recv_size < 0) {
532 			perror("recvmsg");
533 			exit(EXIT_FAILURE);
534 		}
535 
536 		if (msg.msg_flags & MSG_EOR)
537 			curr_hash++;
538 
539 		curr_hash += hash_djb2(msg.msg_iov[0].iov_base, recv_size);
540 	}
541 
542 	free(iov.iov_base);
543 	close(fd);
544 	remote_hash = control_readulong();
545 
546 	if (curr_hash != remote_hash) {
547 		fprintf(stderr, "Message bounds broken\n");
548 		exit(EXIT_FAILURE);
549 	}
550 }
551 
552 #define MESSAGE_TRUNC_SZ 32
553 static void test_seqpacket_msg_trunc_client(const struct test_opts *opts)
554 {
555 	int fd;
556 	char buf[MESSAGE_TRUNC_SZ];
557 
558 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
559 	if (fd < 0) {
560 		perror("connect");
561 		exit(EXIT_FAILURE);
562 	}
563 
564 	if (send(fd, buf, sizeof(buf), 0) != sizeof(buf)) {
565 		perror("send failed");
566 		exit(EXIT_FAILURE);
567 	}
568 
569 	control_writeln("SENDDONE");
570 	close(fd);
571 }
572 
573 static void test_seqpacket_msg_trunc_server(const struct test_opts *opts)
574 {
575 	int fd;
576 	char buf[MESSAGE_TRUNC_SZ / 2];
577 	struct msghdr msg = {0};
578 	struct iovec iov = {0};
579 
580 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
581 	if (fd < 0) {
582 		perror("accept");
583 		exit(EXIT_FAILURE);
584 	}
585 
586 	control_expectln("SENDDONE");
587 	iov.iov_base = buf;
588 	iov.iov_len = sizeof(buf);
589 	msg.msg_iov = &iov;
590 	msg.msg_iovlen = 1;
591 
592 	ssize_t ret = recvmsg(fd, &msg, MSG_TRUNC);
593 
594 	if (ret != MESSAGE_TRUNC_SZ) {
595 		printf("%zi\n", ret);
596 		perror("MSG_TRUNC doesn't work");
597 		exit(EXIT_FAILURE);
598 	}
599 
600 	if (!(msg.msg_flags & MSG_TRUNC)) {
601 		fprintf(stderr, "MSG_TRUNC expected\n");
602 		exit(EXIT_FAILURE);
603 	}
604 
605 	close(fd);
606 }
607 
/* Return a clock reading in nanoseconds for measuring elapsed time.
 *
 * CLOCK_MONOTONIC is used so that differences between two readings are not
 * disturbed by wall-clock adjustments; the absolute value has no meaning.
 * Exits with EXIT_FAILURE if the clock cannot be read.
 */
static time_t current_nsec(void)
{
	struct timespec ts;

	if (clock_gettime(CLOCK_MONOTONIC, &ts)) {
		perror("clock_gettime(3) failed");
		exit(EXIT_FAILURE);
	}

	return (ts.tv_sec * 1000000000ULL) + ts.tv_nsec;
}
619 
620 #define RCVTIMEO_TIMEOUT_SEC 1
621 #define READ_OVERHEAD_NSEC 250000000 /* 0.25 sec */
622 
623 static void test_seqpacket_timeout_client(const struct test_opts *opts)
624 {
625 	int fd;
626 	struct timeval tv;
627 	char dummy;
628 	time_t read_enter_ns;
629 	time_t read_overhead_ns;
630 
631 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
632 	if (fd < 0) {
633 		perror("connect");
634 		exit(EXIT_FAILURE);
635 	}
636 
637 	tv.tv_sec = RCVTIMEO_TIMEOUT_SEC;
638 	tv.tv_usec = 0;
639 
640 	if (setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, (void *)&tv, sizeof(tv)) == -1) {
641 		perror("setsockopt(SO_RCVTIMEO)");
642 		exit(EXIT_FAILURE);
643 	}
644 
645 	read_enter_ns = current_nsec();
646 
647 	if (read(fd, &dummy, sizeof(dummy)) != -1) {
648 		fprintf(stderr,
649 			"expected 'dummy' read(2) failure\n");
650 		exit(EXIT_FAILURE);
651 	}
652 
653 	if (errno != EAGAIN) {
654 		perror("EAGAIN expected");
655 		exit(EXIT_FAILURE);
656 	}
657 
658 	read_overhead_ns = current_nsec() - read_enter_ns -
659 			1000000000ULL * RCVTIMEO_TIMEOUT_SEC;
660 
661 	if (read_overhead_ns > READ_OVERHEAD_NSEC) {
662 		fprintf(stderr,
663 			"too much time in read(2), %lu > %i ns\n",
664 			read_overhead_ns, READ_OVERHEAD_NSEC);
665 		exit(EXIT_FAILURE);
666 	}
667 
668 	control_writeln("WAITDONE");
669 	close(fd);
670 }
671 
672 static void test_seqpacket_timeout_server(const struct test_opts *opts)
673 {
674 	int fd;
675 
676 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
677 	if (fd < 0) {
678 		perror("accept");
679 		exit(EXIT_FAILURE);
680 	}
681 
682 	control_expectln("WAITDONE");
683 	close(fd);
684 }
685 
686 static void test_seqpacket_bigmsg_client(const struct test_opts *opts)
687 {
688 	unsigned long sock_buf_size;
689 	ssize_t send_size;
690 	socklen_t len;
691 	void *data;
692 	int fd;
693 
694 	len = sizeof(sock_buf_size);
695 
696 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
697 	if (fd < 0) {
698 		perror("connect");
699 		exit(EXIT_FAILURE);
700 	}
701 
702 	if (getsockopt(fd, AF_VSOCK, SO_VM_SOCKETS_BUFFER_SIZE,
703 		       &sock_buf_size, &len)) {
704 		perror("getsockopt");
705 		exit(EXIT_FAILURE);
706 	}
707 
708 	sock_buf_size++;
709 
710 	data = malloc(sock_buf_size);
711 	if (!data) {
712 		perror("malloc");
713 		exit(EXIT_FAILURE);
714 	}
715 
716 	send_size = send(fd, data, sock_buf_size, 0);
717 	if (send_size != -1) {
718 		fprintf(stderr, "expected 'send(2)' failure, got %zi\n",
719 			send_size);
720 		exit(EXIT_FAILURE);
721 	}
722 
723 	if (errno != EMSGSIZE) {
724 		fprintf(stderr, "expected EMSGSIZE in 'errno', got %i\n",
725 			errno);
726 		exit(EXIT_FAILURE);
727 	}
728 
729 	control_writeln("CLISENT");
730 
731 	free(data);
732 	close(fd);
733 }
734 
735 static void test_seqpacket_bigmsg_server(const struct test_opts *opts)
736 {
737 	int fd;
738 
739 	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
740 	if (fd < 0) {
741 		perror("accept");
742 		exit(EXIT_FAILURE);
743 	}
744 
745 	control_expectln("CLISENT");
746 
747 	close(fd);
748 }
749 
750 #define BUF_PATTERN_1 'a'
751 #define BUF_PATTERN_2 'b'
752 
753 static void test_seqpacket_invalid_rec_buffer_client(const struct test_opts *opts)
754 {
755 	int fd;
756 	unsigned char *buf1;
757 	unsigned char *buf2;
758 	int buf_size = getpagesize() * 3;
759 
760 	fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
761 	if (fd < 0) {
762 		perror("connect");
763 		exit(EXIT_FAILURE);
764 	}
765 
766 	buf1 = malloc(buf_size);
767 	if (!buf1) {
768 		perror("'malloc()' for 'buf1'");
769 		exit(EXIT_FAILURE);
770 	}
771 
772 	buf2 = malloc(buf_size);
773 	if (!buf2) {
774 		perror("'malloc()' for 'buf2'");
775 		exit(EXIT_FAILURE);
776 	}
777 
778 	memset(buf1, BUF_PATTERN_1, buf_size);
779 	memset(buf2, BUF_PATTERN_2, buf_size);
780 
781 	if (send(fd, buf1, buf_size, 0) != buf_size) {
782 		perror("send failed");
783 		exit(EXIT_FAILURE);
784 	}
785 
786 	if (send(fd, buf2, buf_size, 0) != buf_size) {
787 		perror("send failed");
788 		exit(EXIT_FAILURE);
789 	}
790 
791 	close(fd);
792 }
793 
/* Server half of "SOCK_SEQPACKET invalid receive buffer".
 *
 * Builds a three-page buffer whose middle page is unmapped and a second,
 * fully mapped three-page buffer.  The first read(2), into the broken
 * buffer, must fail with EFAULT.  The second read(2), into the valid buffer,
 * must return three pages filled with BUF_PATTERN_2, i.e. the pattern of the
 * second message the client sends.
 */
static void test_seqpacket_invalid_rec_buffer_server(const struct test_opts *opts)
{
	int fd;
	unsigned char *broken_buf;
	unsigned char *valid_buf;
	int page_size = getpagesize();
	int buf_size = page_size * 3;
	ssize_t res;
	int prot = PROT_READ | PROT_WRITE;
	int flags = MAP_PRIVATE | MAP_ANONYMOUS;
	int i;

	fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
	if (fd < 0) {
		perror("accept");
		exit(EXIT_FAILURE);
	}

	/* Setup first buffer. */
	broken_buf = mmap(NULL, buf_size, prot, flags, -1, 0);
	if (broken_buf == MAP_FAILED) {
		perror("mmap for 'broken_buf'");
		exit(EXIT_FAILURE);
	}

	/* Unmap "hole" in buffer. */
	if (munmap(broken_buf + page_size, page_size)) {
		perror("'broken_buf' setup");
		exit(EXIT_FAILURE);
	}

	/* Second buffer: same size, left fully mapped. */
	valid_buf = mmap(NULL, buf_size, prot, flags, -1, 0);
	if (valid_buf == MAP_FAILED) {
		perror("mmap for 'valid_buf'");
		exit(EXIT_FAILURE);
	}

	/* Try to fill buffer with unmapped middle. */
	res = read(fd, broken_buf, buf_size);
	if (res != -1) {
		fprintf(stderr,
			"expected 'broken_buf' read(2) failure, got %zi\n",
			res);
		exit(EXIT_FAILURE);
	}

	if (errno != EFAULT) {
		perror("unexpected errno of 'broken_buf'");
		exit(EXIT_FAILURE);
	}

	/* Try to fill valid buffer. */
	res = read(fd, valid_buf, buf_size);
	if (res < 0) {
		perror("unexpected 'valid_buf' read(2) failure");
		exit(EXIT_FAILURE);
	}

	if (res != buf_size) {
		fprintf(stderr,
			"invalid 'valid_buf' read(2), expected %i, got %zi\n",
			buf_size, res);
		exit(EXIT_FAILURE);
	}

	/* Every byte must carry the pattern of the client's second message. */
	for (i = 0; i < buf_size; i++) {
		if (valid_buf[i] != BUF_PATTERN_2) {
			fprintf(stderr,
				"invalid pattern for 'valid_buf' at %i, expected %hhX, got %hhX\n",
				i, BUF_PATTERN_2, valid_buf[i]);
			exit(EXIT_FAILURE);
		}
	}

	/* Unmap buffers. The middle page of 'broken_buf' was already unmapped
	 * above, so only its first and last pages are released here.
	 */
	munmap(broken_buf, page_size);
	munmap(broken_buf + page_size * 2, page_size);
	munmap(valid_buf, buf_size);
	close(fd);
}
874 
875 #define RCVLOWAT_BUF_SIZE 128
876 
877 static void test_stream_poll_rcvlowat_server(const struct test_opts *opts)
878 {
879 	int fd;
880 	int i;
881 
882 	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
883 	if (fd < 0) {
884 		perror("accept");
885 		exit(EXIT_FAILURE);
886 	}
887 
888 	/* Send 1 byte. */
889 	send_byte(fd, 1, 0);
890 
891 	control_writeln("SRVSENT");
892 
893 	/* Wait until client is ready to receive rest of data. */
894 	control_expectln("CLNSENT");
895 
896 	for (i = 0; i < RCVLOWAT_BUF_SIZE - 1; i++)
897 		send_byte(fd, 1, 0);
898 
899 	/* Keep socket in active state. */
900 	control_expectln("POLLDONE");
901 
902 	close(fd);
903 }
904 
905 static void test_stream_poll_rcvlowat_client(const struct test_opts *opts)
906 {
907 	unsigned long lowat_val = RCVLOWAT_BUF_SIZE;
908 	char buf[RCVLOWAT_BUF_SIZE];
909 	struct pollfd fds;
910 	ssize_t read_res;
911 	short poll_flags;
912 	int fd;
913 
914 	fd = vsock_stream_connect(opts->peer_cid, 1234);
915 	if (fd < 0) {
916 		perror("connect");
917 		exit(EXIT_FAILURE);
918 	}
919 
920 	if (setsockopt(fd, SOL_SOCKET, SO_RCVLOWAT,
921 		       &lowat_val, sizeof(lowat_val))) {
922 		perror("setsockopt(SO_RCVLOWAT)");
923 		exit(EXIT_FAILURE);
924 	}
925 
926 	control_expectln("SRVSENT");
927 
928 	/* At this point, server sent 1 byte. */
929 	fds.fd = fd;
930 	poll_flags = POLLIN | POLLRDNORM;
931 	fds.events = poll_flags;
932 
933 	/* Try to wait for 1 sec. */
934 	if (poll(&fds, 1, 1000) < 0) {
935 		perror("poll");
936 		exit(EXIT_FAILURE);
937 	}
938 
939 	/* poll() must return nothing. */
940 	if (fds.revents) {
941 		fprintf(stderr, "Unexpected poll result %hx\n",
942 			fds.revents);
943 		exit(EXIT_FAILURE);
944 	}
945 
946 	/* Tell server to send rest of data. */
947 	control_writeln("CLNSENT");
948 
949 	/* Poll for data. */
950 	if (poll(&fds, 1, 10000) < 0) {
951 		perror("poll");
952 		exit(EXIT_FAILURE);
953 	}
954 
955 	/* Only these two bits are expected. */
956 	if (fds.revents != poll_flags) {
957 		fprintf(stderr, "Unexpected poll result %hx\n",
958 			fds.revents);
959 		exit(EXIT_FAILURE);
960 	}
961 
962 	/* Use MSG_DONTWAIT, if call is going to wait, EAGAIN
963 	 * will be returned.
964 	 */
965 	read_res = recv(fd, buf, sizeof(buf), MSG_DONTWAIT);
966 	if (read_res != RCVLOWAT_BUF_SIZE) {
967 		fprintf(stderr, "Unexpected recv result %zi\n",
968 			read_res);
969 		exit(EXIT_FAILURE);
970 	}
971 
972 	control_writeln("POLLDONE");
973 
974 	close(fd);
975 }
976 
977 #define INV_BUF_TEST_DATA_LEN 512
978 
979 static void test_inv_buf_client(const struct test_opts *opts, bool stream)
980 {
981 	unsigned char data[INV_BUF_TEST_DATA_LEN] = {0};
982 	ssize_t ret;
983 	int fd;
984 
985 	if (stream)
986 		fd = vsock_stream_connect(opts->peer_cid, 1234);
987 	else
988 		fd = vsock_seqpacket_connect(opts->peer_cid, 1234);
989 
990 	if (fd < 0) {
991 		perror("connect");
992 		exit(EXIT_FAILURE);
993 	}
994 
995 	control_expectln("SENDDONE");
996 
997 	/* Use invalid buffer here. */
998 	ret = recv(fd, NULL, sizeof(data), 0);
999 	if (ret != -1) {
1000 		fprintf(stderr, "expected recv(2) failure, got %zi\n", ret);
1001 		exit(EXIT_FAILURE);
1002 	}
1003 
1004 	if (errno != EFAULT) {
1005 		fprintf(stderr, "unexpected recv(2) errno %d\n", errno);
1006 		exit(EXIT_FAILURE);
1007 	}
1008 
1009 	ret = recv(fd, data, sizeof(data), MSG_DONTWAIT);
1010 
1011 	if (stream) {
1012 		/* For SOCK_STREAM we must continue reading. */
1013 		if (ret != sizeof(data)) {
1014 			fprintf(stderr, "expected recv(2) success, got %zi\n", ret);
1015 			exit(EXIT_FAILURE);
1016 		}
1017 		/* Don't check errno in case of success. */
1018 	} else {
1019 		/* For SOCK_SEQPACKET socket's queue must be empty. */
1020 		if (ret != -1) {
1021 			fprintf(stderr, "expected recv(2) failure, got %zi\n", ret);
1022 			exit(EXIT_FAILURE);
1023 		}
1024 
1025 		if (errno != EAGAIN) {
1026 			fprintf(stderr, "unexpected recv(2) errno %d\n", errno);
1027 			exit(EXIT_FAILURE);
1028 		}
1029 	}
1030 
1031 	control_writeln("DONE");
1032 
1033 	close(fd);
1034 }
1035 
1036 static void test_inv_buf_server(const struct test_opts *opts, bool stream)
1037 {
1038 	unsigned char data[INV_BUF_TEST_DATA_LEN] = {0};
1039 	ssize_t res;
1040 	int fd;
1041 
1042 	if (stream)
1043 		fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
1044 	else
1045 		fd = vsock_seqpacket_accept(VMADDR_CID_ANY, 1234, NULL);
1046 
1047 	if (fd < 0) {
1048 		perror("accept");
1049 		exit(EXIT_FAILURE);
1050 	}
1051 
1052 	res = send(fd, data, sizeof(data), 0);
1053 	if (res != sizeof(data)) {
1054 		fprintf(stderr, "unexpected send(2) result %zi\n", res);
1055 		exit(EXIT_FAILURE);
1056 	}
1057 
1058 	control_writeln("SENDDONE");
1059 
1060 	control_expectln("DONE");
1061 
1062 	close(fd);
1063 }
1064 
/* SOCK_STREAM entry point for the invalid-buffer client. */
static void test_stream_inv_buf_client(const struct test_opts *opts)
{
	const bool stream = true;

	test_inv_buf_client(opts, stream);
}
1069 
/* SOCK_STREAM entry point for the invalid-buffer server. */
static void test_stream_inv_buf_server(const struct test_opts *opts)
{
	const bool stream = true;

	test_inv_buf_server(opts, stream);
}
1074 
/* SOCK_SEQPACKET entry point for the invalid-buffer client. */
static void test_seqpacket_inv_buf_client(const struct test_opts *opts)
{
	const bool stream = false;

	test_inv_buf_client(opts, stream);
}
1079 
/* SOCK_SEQPACKET entry point for the invalid-buffer server. */
static void test_seqpacket_inv_buf_server(const struct test_opts *opts)
{
	const bool stream = false;

	test_inv_buf_server(opts, stream);
}
1084 
1085 #define HELLO_STR "HELLO"
1086 #define WORLD_STR "WORLD"
1087 
1088 static void test_stream_virtio_skb_merge_client(const struct test_opts *opts)
1089 {
1090 	ssize_t res;
1091 	int fd;
1092 
1093 	fd = vsock_stream_connect(opts->peer_cid, 1234);
1094 	if (fd < 0) {
1095 		perror("connect");
1096 		exit(EXIT_FAILURE);
1097 	}
1098 
1099 	/* Send first skbuff. */
1100 	res = send(fd, HELLO_STR, strlen(HELLO_STR), 0);
1101 	if (res != strlen(HELLO_STR)) {
1102 		fprintf(stderr, "unexpected send(2) result %zi\n", res);
1103 		exit(EXIT_FAILURE);
1104 	}
1105 
1106 	control_writeln("SEND0");
1107 	/* Peer reads part of first skbuff. */
1108 	control_expectln("REPLY0");
1109 
1110 	/* Send second skbuff, it will be appended to the first. */
1111 	res = send(fd, WORLD_STR, strlen(WORLD_STR), 0);
1112 	if (res != strlen(WORLD_STR)) {
1113 		fprintf(stderr, "unexpected send(2) result %zi\n", res);
1114 		exit(EXIT_FAILURE);
1115 	}
1116 
1117 	control_writeln("SEND1");
1118 	/* Peer reads merged skbuff packet. */
1119 	control_expectln("REPLY1");
1120 
1121 	close(fd);
1122 }
1123 
1124 static void test_stream_virtio_skb_merge_server(const struct test_opts *opts)
1125 {
1126 	unsigned char buf[64];
1127 	ssize_t res;
1128 	int fd;
1129 
1130 	fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL);
1131 	if (fd < 0) {
1132 		perror("accept");
1133 		exit(EXIT_FAILURE);
1134 	}
1135 
1136 	control_expectln("SEND0");
1137 
1138 	/* Read skbuff partially. */
1139 	res = recv(fd, buf, 2, 0);
1140 	if (res != 2) {
1141 		fprintf(stderr, "expected recv(2) returns 2 bytes, got %zi\n", res);
1142 		exit(EXIT_FAILURE);
1143 	}
1144 
1145 	control_writeln("REPLY0");
1146 	control_expectln("SEND1");
1147 
1148 	res = recv(fd, buf + 2, sizeof(buf) - 2, 0);
1149 	if (res != 8) {
1150 		fprintf(stderr, "expected recv(2) returns 8 bytes, got %zi\n", res);
1151 		exit(EXIT_FAILURE);
1152 	}
1153 
1154 	res = recv(fd, buf, sizeof(buf) - 8 - 2, MSG_DONTWAIT);
1155 	if (res != -1) {
1156 		fprintf(stderr, "expected recv(2) failure, got %zi\n", res);
1157 		exit(EXIT_FAILURE);
1158 	}
1159 
1160 	if (memcmp(buf, HELLO_STR WORLD_STR, strlen(HELLO_STR WORLD_STR))) {
1161 		fprintf(stderr, "pattern mismatch\n");
1162 		exit(EXIT_FAILURE);
1163 	}
1164 
1165 	control_writeln("REPLY1");
1166 
1167 	close(fd);
1168 }
1169 
/* SOCK_SEQPACKET entry point for the MSG_PEEK client (seqpacket == true). */
static void test_seqpacket_msg_peek_client(const struct test_opts *opts)
{
	/* No "return <expr>;" here: ISO C forbids it in a void function. */
	test_msg_peek_client(opts, true);
}
1174 
/* SOCK_SEQPACKET entry point for the MSG_PEEK server (seqpacket == true). */
static void test_seqpacket_msg_peek_server(const struct test_opts *opts)
{
	/* No "return <expr>;" here: ISO C forbids it in a void function. */
	test_msg_peek_server(opts, true);
}
1179 
/* Table of test cases.  Each entry pairs a human-readable name with the
 * callback to run on the client side and, where present, the callback to run
 * on the server side ("SOCK_STREAM connection reset" has a client callback
 * only).  The table is terminated by an empty entry.
 *
 * NOTE(review): the position of an entry presumably is its test id as used
 * by --skip, so new cases should be appended rather than inserted - confirm
 * against the code that walks this table.
 */
static struct test_case test_cases[] = {
	{
		.name = "SOCK_STREAM connection reset",
		.run_client = test_stream_connection_reset,
	},
	{
		.name = "SOCK_STREAM bind only",
		.run_client = test_stream_bind_only_client,
		.run_server = test_stream_bind_only_server,
	},
	{
		.name = "SOCK_STREAM client close",
		.run_client = test_stream_client_close_client,
		.run_server = test_stream_client_close_server,
	},
	{
		.name = "SOCK_STREAM server close",
		.run_client = test_stream_server_close_client,
		.run_server = test_stream_server_close_server,
	},
	{
		.name = "SOCK_STREAM multiple connections",
		.run_client = test_stream_multiconn_client,
		.run_server = test_stream_multiconn_server,
	},
	{
		.name = "SOCK_STREAM MSG_PEEK",
		.run_client = test_stream_msg_peek_client,
		.run_server = test_stream_msg_peek_server,
	},
	{
		.name = "SOCK_SEQPACKET msg bounds",
		.run_client = test_seqpacket_msg_bounds_client,
		.run_server = test_seqpacket_msg_bounds_server,
	},
	{
		.name = "SOCK_SEQPACKET MSG_TRUNC flag",
		.run_client = test_seqpacket_msg_trunc_client,
		.run_server = test_seqpacket_msg_trunc_server,
	},
	{
		.name = "SOCK_SEQPACKET timeout",
		.run_client = test_seqpacket_timeout_client,
		.run_server = test_seqpacket_timeout_server,
	},
	{
		.name = "SOCK_SEQPACKET invalid receive buffer",
		.run_client = test_seqpacket_invalid_rec_buffer_client,
		.run_server = test_seqpacket_invalid_rec_buffer_server,
	},
	{
		.name = "SOCK_STREAM poll() + SO_RCVLOWAT",
		.run_client = test_stream_poll_rcvlowat_client,
		.run_server = test_stream_poll_rcvlowat_server,
	},
	{
		.name = "SOCK_SEQPACKET big message",
		.run_client = test_seqpacket_bigmsg_client,
		.run_server = test_seqpacket_bigmsg_server,
	},
	{
		.name = "SOCK_STREAM test invalid buffer",
		.run_client = test_stream_inv_buf_client,
		.run_server = test_stream_inv_buf_server,
	},
	{
		.name = "SOCK_SEQPACKET test invalid buffer",
		.run_client = test_seqpacket_inv_buf_client,
		.run_server = test_seqpacket_inv_buf_server,
	},
	{
		.name = "SOCK_STREAM virtio skb merge",
		.run_client = test_stream_virtio_skb_merge_client,
		.run_server = test_stream_virtio_skb_merge_server,
	},
	{
		.name = "SOCK_SEQPACKET MSG_PEEK",
		.run_client = test_seqpacket_msg_peek_client,
		.run_server = test_seqpacket_msg_peek_server,
	},
	{},
};
1262 
/*
 * Command-line option tables for getopt_long().
 *
 * optstring is empty, so no short options are accepted.  Every long
 * option leaves .flag NULL, which makes getopt_long() return the .val
 * character listed in the last column; main() switches on that value.
 *
 * Columns follow the struct option field order: name, has_arg, flag, val.
 */
static const char optstring[] = "";
static const struct option longopts[] = {
	{ "control-host", required_argument, NULL, 'H' },
	{ "control-port", required_argument, NULL, 'P' },
	{ "mode",         required_argument, NULL, 'm' },
	{ "peer-cid",     required_argument, NULL, 'p' },
	{ "list",         no_argument,       NULL, 'l' },
	{ "skip",         required_argument, NULL, 's' },
	{ "help",         no_argument,       NULL, '?' },
	{ NULL,           0,                 NULL, 0   },	/* terminator */
};
1302 
/*
 * Print the command-line help text to stderr and terminate the process
 * with EXIT_FAILURE.  This function never returns to its caller.
 */
static void usage(void)
{
	static const char help_text[] =
		"Usage: vsock_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n"
		"\n"
		"  Server: vsock_test --control-port=1234 --mode=server --peer-cid=3\n"
		"  Client: vsock_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n"
		"\n"
		"Run vsock.ko tests.  Must be launched in both guest\n"
		"and host.  One side must use --mode=client and\n"
		"the other side must use --mode=server.\n"
		"\n"
		"A TCP control socket connection is used to coordinate tests\n"
		"between the client and the server.  The server requires a\n"
		"listen address and the client requires an address to\n"
		"connect to.\n"
		"\n"
		"The CID of the other side must be given with --peer-cid=<cid>.\n"
		"\n"
		"Options:\n"
		"  --help                 This help message\n"
		"  --control-host <host>  Server IP address to connect to\n"
		"  --control-port <port>  Server port to listen on/connect to\n"
		"  --mode client|server   Server or client mode\n"
		"  --peer-cid <cid>       CID of the other side\n"
		"  --list                 List of tests that will be executed\n"
		"  --skip <test_id>       Test ID to skip;\n"
		"                         use multiple --skip options to skip more tests\n";

	/* The text holds no conversion specifiers, so it is emitted verbatim. */
	fputs(help_text, stderr);
	exit(EXIT_FAILURE);
}
1333 
1334 int main(int argc, char **argv)
1335 {
1336 	const char *control_host = NULL;
1337 	const char *control_port = NULL;
1338 	struct test_opts opts = {
1339 		.mode = TEST_MODE_UNSET,
1340 		.peer_cid = VMADDR_CID_ANY,
1341 	};
1342 
1343 	srand(time(NULL));
1344 	init_signals();
1345 
1346 	for (;;) {
1347 		int opt = getopt_long(argc, argv, optstring, longopts, NULL);
1348 
1349 		if (opt == -1)
1350 			break;
1351 
1352 		switch (opt) {
1353 		case 'H':
1354 			control_host = optarg;
1355 			break;
1356 		case 'm':
1357 			if (strcmp(optarg, "client") == 0)
1358 				opts.mode = TEST_MODE_CLIENT;
1359 			else if (strcmp(optarg, "server") == 0)
1360 				opts.mode = TEST_MODE_SERVER;
1361 			else {
1362 				fprintf(stderr, "--mode must be \"client\" or \"server\"\n");
1363 				return EXIT_FAILURE;
1364 			}
1365 			break;
1366 		case 'p':
1367 			opts.peer_cid = parse_cid(optarg);
1368 			break;
1369 		case 'P':
1370 			control_port = optarg;
1371 			break;
1372 		case 'l':
1373 			list_tests(test_cases);
1374 			break;
1375 		case 's':
1376 			skip_test(test_cases, ARRAY_SIZE(test_cases) - 1,
1377 				  optarg);
1378 			break;
1379 		case '?':
1380 		default:
1381 			usage();
1382 		}
1383 	}
1384 
1385 	if (!control_port)
1386 		usage();
1387 	if (opts.mode == TEST_MODE_UNSET)
1388 		usage();
1389 	if (opts.peer_cid == VMADDR_CID_ANY)
1390 		usage();
1391 
1392 	if (!control_host) {
1393 		if (opts.mode != TEST_MODE_SERVER)
1394 			usage();
1395 		control_host = "0.0.0.0";
1396 	}
1397 
1398 	control_init(control_host, control_port,
1399 		     opts.mode == TEST_MODE_SERVER);
1400 
1401 	run_tests(test_cases, &opts);
1402 
1403 	control_cleanup();
1404 	return EXIT_SUCCESS;
1405 }
1406