xref: /openbmc/linux/tools/testing/vsock/util.c (revision 092f32ae)
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * vsock test utilities
4  *
5  * Copyright (C) 2017 Red Hat, Inc.
6  *
7  * Author: Stefan Hajnoczi <stefanha@redhat.com>
8  */
9 
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <signal.h>
#include <unistd.h>
16 
17 #include "timeout.h"
18 #include "control.h"
19 #include "util.h"
20 
21 /* Install signal handlers */
22 void init_signals(void)
23 {
24 	struct sigaction act = {
25 		.sa_handler = sigalrm,
26 	};
27 
28 	sigaction(SIGALRM, &act, NULL);
29 	signal(SIGPIPE, SIG_IGN);
30 }
31 
/* Parse a CID in string representation.
 *
 * Accepts a non-empty decimal number whose value fits in an unsigned int
 * (0 .. UINT_MAX).  Leading whitespace is skipped by strtoul(3); trailing
 * characters are rejected.  On malformed or out-of-range input a diagnostic
 * is printed to stderr and the process exits with EXIT_FAILURE.
 *
 * Returns the parsed CID.
 */
unsigned int parse_cid(const char *str)
{
	char *endptr = NULL;
	unsigned long n;

	errno = 0;
	n = strtoul(str, &endptr, 10);
	/*
	 * endptr == str: no digits were consumed (e.g. ""), and strtoul() would
	 * otherwise quietly yield 0.
	 * n > UINT_MAX: where unsigned long is wider than unsigned int, the
	 * value would be silently truncated by the return conversion.
	 */
	if (errno || endptr == str || *endptr != '\0' || n > UINT_MAX) {
		fprintf(stderr, "malformed CID \"%s\"\n", str);
		exit(EXIT_FAILURE);
	}
	return (unsigned int)n;
}
46 
47 /* Connect to <cid, port> and return the file descriptor. */
48 int vsock_stream_connect(unsigned int cid, unsigned int port)
49 {
50 	union {
51 		struct sockaddr sa;
52 		struct sockaddr_vm svm;
53 	} addr = {
54 		.svm = {
55 			.svm_family = AF_VSOCK,
56 			.svm_port = port,
57 			.svm_cid = cid,
58 		},
59 	};
60 	int ret;
61 	int fd;
62 
63 	control_expectln("LISTENING");
64 
65 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
66 
67 	timeout_begin(TIMEOUT);
68 	do {
69 		ret = connect(fd, &addr.sa, sizeof(addr.svm));
70 		timeout_check("connect");
71 	} while (ret < 0 && errno == EINTR);
72 	timeout_end();
73 
74 	if (ret < 0) {
75 		int old_errno = errno;
76 
77 		close(fd);
78 		fd = -1;
79 		errno = old_errno;
80 	}
81 	return fd;
82 }
83 
84 /* Listen on <cid, port> and return the first incoming connection.  The remote
85  * address is stored to clientaddrp.  clientaddrp may be NULL.
86  */
87 int vsock_stream_accept(unsigned int cid, unsigned int port,
88 			struct sockaddr_vm *clientaddrp)
89 {
90 	union {
91 		struct sockaddr sa;
92 		struct sockaddr_vm svm;
93 	} addr = {
94 		.svm = {
95 			.svm_family = AF_VSOCK,
96 			.svm_port = port,
97 			.svm_cid = cid,
98 		},
99 	};
100 	union {
101 		struct sockaddr sa;
102 		struct sockaddr_vm svm;
103 	} clientaddr;
104 	socklen_t clientaddr_len = sizeof(clientaddr.svm);
105 	int fd;
106 	int client_fd;
107 	int old_errno;
108 
109 	fd = socket(AF_VSOCK, SOCK_STREAM, 0);
110 
111 	if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) {
112 		perror("bind");
113 		exit(EXIT_FAILURE);
114 	}
115 
116 	if (listen(fd, 1) < 0) {
117 		perror("listen");
118 		exit(EXIT_FAILURE);
119 	}
120 
121 	control_writeln("LISTENING");
122 
123 	timeout_begin(TIMEOUT);
124 	do {
125 		client_fd = accept(fd, &clientaddr.sa, &clientaddr_len);
126 		timeout_check("accept");
127 	} while (client_fd < 0 && errno == EINTR);
128 	timeout_end();
129 
130 	old_errno = errno;
131 	close(fd);
132 	errno = old_errno;
133 
134 	if (client_fd < 0)
135 		return client_fd;
136 
137 	if (clientaddr_len != sizeof(clientaddr.svm)) {
138 		fprintf(stderr, "unexpected addrlen from accept(2), %zu\n",
139 			(size_t)clientaddr_len);
140 		exit(EXIT_FAILURE);
141 	}
142 	if (clientaddr.sa.sa_family != AF_VSOCK) {
143 		fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n",
144 			clientaddr.sa.sa_family);
145 		exit(EXIT_FAILURE);
146 	}
147 
148 	if (clientaddrp)
149 		*clientaddrp = clientaddr.svm;
150 	return client_fd;
151 }
152 
153 /* Transmit one byte and check the return value.
154  *
155  * expected_ret:
156  *  <0 Negative errno (for testing errors)
157  *   0 End-of-file
158  *   1 Success
159  */
160 void send_byte(int fd, int expected_ret, int flags)
161 {
162 	const uint8_t byte = 'A';
163 	ssize_t nwritten;
164 
165 	timeout_begin(TIMEOUT);
166 	do {
167 		nwritten = send(fd, &byte, sizeof(byte), flags);
168 		timeout_check("write");
169 	} while (nwritten < 0 && errno == EINTR);
170 	timeout_end();
171 
172 	if (expected_ret < 0) {
173 		if (nwritten != -1) {
174 			fprintf(stderr, "bogus send(2) return value %zd\n",
175 				nwritten);
176 			exit(EXIT_FAILURE);
177 		}
178 		if (errno != -expected_ret) {
179 			perror("write");
180 			exit(EXIT_FAILURE);
181 		}
182 		return;
183 	}
184 
185 	if (nwritten < 0) {
186 		perror("write");
187 		exit(EXIT_FAILURE);
188 	}
189 	if (nwritten == 0) {
190 		if (expected_ret == 0)
191 			return;
192 
193 		fprintf(stderr, "unexpected EOF while sending byte\n");
194 		exit(EXIT_FAILURE);
195 	}
196 	if (nwritten != sizeof(byte)) {
197 		fprintf(stderr, "bogus send(2) return value %zd\n", nwritten);
198 		exit(EXIT_FAILURE);
199 	}
200 }
201 
202 /* Receive one byte and check the return value.
203  *
204  * expected_ret:
205  *  <0 Negative errno (for testing errors)
206  *   0 End-of-file
207  *   1 Success
208  */
209 void recv_byte(int fd, int expected_ret, int flags)
210 {
211 	uint8_t byte;
212 	ssize_t nread;
213 
214 	timeout_begin(TIMEOUT);
215 	do {
216 		nread = recv(fd, &byte, sizeof(byte), flags);
217 		timeout_check("read");
218 	} while (nread < 0 && errno == EINTR);
219 	timeout_end();
220 
221 	if (expected_ret < 0) {
222 		if (nread != -1) {
223 			fprintf(stderr, "bogus recv(2) return value %zd\n",
224 				nread);
225 			exit(EXIT_FAILURE);
226 		}
227 		if (errno != -expected_ret) {
228 			perror("read");
229 			exit(EXIT_FAILURE);
230 		}
231 		return;
232 	}
233 
234 	if (nread < 0) {
235 		perror("read");
236 		exit(EXIT_FAILURE);
237 	}
238 	if (nread == 0) {
239 		if (expected_ret == 0)
240 			return;
241 
242 		fprintf(stderr, "unexpected EOF while receiving byte\n");
243 		exit(EXIT_FAILURE);
244 	}
245 	if (nread != sizeof(byte)) {
246 		fprintf(stderr, "bogus recv(2) return value %zd\n", nread);
247 		exit(EXIT_FAILURE);
248 	}
249 	if (byte != 'A') {
250 		fprintf(stderr, "unexpected byte read %c\n", byte);
251 		exit(EXIT_FAILURE);
252 	}
253 }
254 
255 /* Run test cases.  The program terminates if a failure occurs. */
256 void run_tests(const struct test_case *test_cases,
257 	       const struct test_opts *opts)
258 {
259 	int i;
260 
261 	for (i = 0; test_cases[i].name; i++) {
262 		void (*run)(const struct test_opts *opts);
263 
264 		printf("%s...", test_cases[i].name);
265 		fflush(stdout);
266 
267 		if (opts->mode == TEST_MODE_CLIENT) {
268 			/* Full barrier before executing the next test.  This
269 			 * ensures that client and server are executing the
270 			 * same test case.  In particular, it means whoever is
271 			 * faster will not see the peer still executing the
272 			 * last test.  This is important because port numbers
273 			 * can be used by multiple test cases.
274 			 */
275 			control_expectln("NEXT");
276 			control_writeln("NEXT");
277 
278 			run = test_cases[i].run_client;
279 		} else {
280 			control_writeln("NEXT");
281 			control_expectln("NEXT");
282 
283 			run = test_cases[i].run_server;
284 		}
285 
286 		if (run)
287 			run(opts);
288 
289 		printf("ok\n");
290 	}
291 }
292