1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * vsock test utilities 4 * 5 * Copyright (C) 2017 Red Hat, Inc. 6 * 7 * Author: Stefan Hajnoczi <stefanha@redhat.com> 8 */ 9 10 #include <errno.h> 11 #include <stdio.h> 12 #include <stdint.h> 13 #include <stdlib.h> 14 #include <signal.h> 15 #include <unistd.h> 16 17 #include "timeout.h" 18 #include "control.h" 19 #include "util.h" 20 21 /* Install signal handlers */ 22 void init_signals(void) 23 { 24 struct sigaction act = { 25 .sa_handler = sigalrm, 26 }; 27 28 sigaction(SIGALRM, &act, NULL); 29 signal(SIGPIPE, SIG_IGN); 30 } 31 32 /* Parse a CID in string representation */ 33 unsigned int parse_cid(const char *str) 34 { 35 char *endptr = NULL; 36 unsigned long n; 37 38 errno = 0; 39 n = strtoul(str, &endptr, 10); 40 if (errno || *endptr != '\0') { 41 fprintf(stderr, "malformed CID \"%s\"\n", str); 42 exit(EXIT_FAILURE); 43 } 44 return n; 45 } 46 47 /* Connect to <cid, port> and return the file descriptor. */ 48 int vsock_stream_connect(unsigned int cid, unsigned int port) 49 { 50 union { 51 struct sockaddr sa; 52 struct sockaddr_vm svm; 53 } addr = { 54 .svm = { 55 .svm_family = AF_VSOCK, 56 .svm_port = port, 57 .svm_cid = cid, 58 }, 59 }; 60 int ret; 61 int fd; 62 63 control_expectln("LISTENING"); 64 65 fd = socket(AF_VSOCK, SOCK_STREAM, 0); 66 67 timeout_begin(TIMEOUT); 68 do { 69 ret = connect(fd, &addr.sa, sizeof(addr.svm)); 70 timeout_check("connect"); 71 } while (ret < 0 && errno == EINTR); 72 timeout_end(); 73 74 if (ret < 0) { 75 int old_errno = errno; 76 77 close(fd); 78 fd = -1; 79 errno = old_errno; 80 } 81 return fd; 82 } 83 84 /* Listen on <cid, port> and return the first incoming connection. The remote 85 * address is stored to clientaddrp. clientaddrp may be NULL. 86 */ 87 int vsock_stream_accept(unsigned int cid, unsigned int port, 88 struct sockaddr_vm *clientaddrp) 89 { 90 union { 91 struct sockaddr sa; 92 struct sockaddr_vm svm; 93 } addr = { 94 .svm = { 95 .svm_family = AF_VSOCK, 96 .svm_port = port, 97 .svm_cid = cid, 98 }, 99 }; 100 union { 101 struct sockaddr sa; 102 struct sockaddr_vm svm; 103 } clientaddr; 104 socklen_t clientaddr_len = sizeof(clientaddr.svm); 105 int fd; 106 int client_fd; 107 int old_errno; 108 109 fd = socket(AF_VSOCK, SOCK_STREAM, 0); 110 111 if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) { 112 perror("bind"); 113 exit(EXIT_FAILURE); 114 } 115 116 if (listen(fd, 1) < 0) { 117 perror("listen"); 118 exit(EXIT_FAILURE); 119 } 120 121 control_writeln("LISTENING"); 122 123 timeout_begin(TIMEOUT); 124 do { 125 client_fd = accept(fd, &clientaddr.sa, &clientaddr_len); 126 timeout_check("accept"); 127 } while (client_fd < 0 && errno == EINTR); 128 timeout_end(); 129 130 old_errno = errno; 131 close(fd); 132 errno = old_errno; 133 134 if (client_fd < 0) 135 return client_fd; 136 137 if (clientaddr_len != sizeof(clientaddr.svm)) { 138 fprintf(stderr, "unexpected addrlen from accept(2), %zu\n", 139 (size_t)clientaddr_len); 140 exit(EXIT_FAILURE); 141 } 142 if (clientaddr.sa.sa_family != AF_VSOCK) { 143 fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n", 144 clientaddr.sa.sa_family); 145 exit(EXIT_FAILURE); 146 } 147 148 if (clientaddrp) 149 *clientaddrp = clientaddr.svm; 150 return client_fd; 151 } 152 153 /* Transmit one byte and check the return value. 154 * 155 * expected_ret: 156 * <0 Negative errno (for testing errors) 157 * 0 End-of-file 158 * 1 Success 159 */ 160 void send_byte(int fd, int expected_ret, int flags) 161 { 162 const uint8_t byte = 'A'; 163 ssize_t nwritten; 164 165 timeout_begin(TIMEOUT); 166 do { 167 nwritten = send(fd, &byte, sizeof(byte), flags); 168 timeout_check("write"); 169 } while (nwritten < 0 && errno == EINTR); 170 timeout_end(); 171 172 if (expected_ret < 0) { 173 if (nwritten != -1) { 174 fprintf(stderr, "bogus send(2) return value %zd\n", 175 nwritten); 176 exit(EXIT_FAILURE); 177 } 178 if (errno != -expected_ret) { 179 perror("write"); 180 exit(EXIT_FAILURE); 181 } 182 return; 183 } 184 185 if (nwritten < 0) { 186 perror("write"); 187 exit(EXIT_FAILURE); 188 } 189 if (nwritten == 0) { 190 if (expected_ret == 0) 191 return; 192 193 fprintf(stderr, "unexpected EOF while sending byte\n"); 194 exit(EXIT_FAILURE); 195 } 196 if (nwritten != sizeof(byte)) { 197 fprintf(stderr, "bogus send(2) return value %zd\n", nwritten); 198 exit(EXIT_FAILURE); 199 } 200 } 201 202 /* Receive one byte and check the return value. 203 * 204 * expected_ret: 205 * <0 Negative errno (for testing errors) 206 * 0 End-of-file 207 * 1 Success 208 */ 209 void recv_byte(int fd, int expected_ret, int flags) 210 { 211 uint8_t byte; 212 ssize_t nread; 213 214 timeout_begin(TIMEOUT); 215 do { 216 nread = recv(fd, &byte, sizeof(byte), flags); 217 timeout_check("read"); 218 } while (nread < 0 && errno == EINTR); 219 timeout_end(); 220 221 if (expected_ret < 0) { 222 if (nread != -1) { 223 fprintf(stderr, "bogus recv(2) return value %zd\n", 224 nread); 225 exit(EXIT_FAILURE); 226 } 227 if (errno != -expected_ret) { 228 perror("read"); 229 exit(EXIT_FAILURE); 230 } 231 return; 232 } 233 234 if (nread < 0) { 235 perror("read"); 236 exit(EXIT_FAILURE); 237 } 238 if (nread == 0) { 239 if (expected_ret == 0) 240 return; 241 242 fprintf(stderr, "unexpected EOF while receiving byte\n"); 243 exit(EXIT_FAILURE); 244 } 245 if (nread != sizeof(byte)) { 246 fprintf(stderr, "bogus recv(2) return value %zd\n", nread); 247 exit(EXIT_FAILURE); 248 } 249 if (byte != 'A') { 250 fprintf(stderr, "unexpected byte read %c\n", byte); 251 exit(EXIT_FAILURE); 252 } 253 } 254 255 /* Run test cases. The program terminates if a failure occurs. */ 256 void run_tests(const struct test_case *test_cases, 257 const struct test_opts *opts) 258 { 259 int i; 260 261 for (i = 0; test_cases[i].name; i++) { 262 void (*run)(const struct test_opts *opts); 263 264 printf("%s...", test_cases[i].name); 265 fflush(stdout); 266 267 if (opts->mode == TEST_MODE_CLIENT) { 268 /* Full barrier before executing the next test. This 269 * ensures that client and server are executing the 270 * same test case. In particular, it means whoever is 271 * faster will not see the peer still executing the 272 * last test. This is important because port numbers 273 * can be used by multiple test cases. 274 */ 275 control_expectln("NEXT"); 276 control_writeln("NEXT"); 277 278 run = test_cases[i].run_client; 279 } else { 280 control_writeln("NEXT"); 281 control_expectln("NEXT"); 282 283 run = test_cases[i].run_server; 284 } 285 286 if (run) 287 run(opts); 288 289 printf("ok\n"); 290 } 291 } 292