1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * vsock_diag_test - vsock_diag.ko test suite 4 * 5 * Copyright (C) 2017 Red Hat, Inc. 6 * 7 * Author: Stefan Hajnoczi <stefanha@redhat.com> 8 */ 9 10 #include <getopt.h> 11 #include <stdio.h> 12 #include <stdlib.h> 13 #include <string.h> 14 #include <errno.h> 15 #include <unistd.h> 16 #include <sys/stat.h> 17 #include <sys/types.h> 18 #include <linux/list.h> 19 #include <linux/net.h> 20 #include <linux/netlink.h> 21 #include <linux/sock_diag.h> 22 #include <linux/vm_sockets_diag.h> 23 #include <netinet/tcp.h> 24 25 #include "timeout.h" 26 #include "control.h" 27 #include "util.h" 28 29 /* Per-socket status */ 30 struct vsock_stat { 31 struct list_head list; 32 struct vsock_diag_msg msg; 33 }; 34 35 static const char *sock_type_str(int type) 36 { 37 switch (type) { 38 case SOCK_DGRAM: 39 return "DGRAM"; 40 case SOCK_STREAM: 41 return "STREAM"; 42 default: 43 return "INVALID TYPE"; 44 } 45 } 46 47 static const char *sock_state_str(int state) 48 { 49 switch (state) { 50 case TCP_CLOSE: 51 return "UNCONNECTED"; 52 case TCP_SYN_SENT: 53 return "CONNECTING"; 54 case TCP_ESTABLISHED: 55 return "CONNECTED"; 56 case TCP_CLOSING: 57 return "DISCONNECTING"; 58 case TCP_LISTEN: 59 return "LISTEN"; 60 default: 61 return "INVALID STATE"; 62 } 63 } 64 65 static const char *sock_shutdown_str(int shutdown) 66 { 67 switch (shutdown) { 68 case 1: 69 return "RCV_SHUTDOWN"; 70 case 2: 71 return "SEND_SHUTDOWN"; 72 case 3: 73 return "RCV_SHUTDOWN | SEND_SHUTDOWN"; 74 default: 75 return "0"; 76 } 77 } 78 79 static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port) 80 { 81 if (cid == VMADDR_CID_ANY) 82 fprintf(fp, "*:"); 83 else 84 fprintf(fp, "%u:", cid); 85 86 if (port == VMADDR_PORT_ANY) 87 fprintf(fp, "*"); 88 else 89 fprintf(fp, "%u", port); 90 } 91 92 static void print_vsock_stat(FILE *fp, struct vsock_stat *st) 93 { 94 print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port); 95 fprintf(fp, " "); 96 print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port); 97 fprintf(fp, " %s %s %s %u\n", 98 sock_type_str(st->msg.vdiag_type), 99 sock_state_str(st->msg.vdiag_state), 100 sock_shutdown_str(st->msg.vdiag_shutdown), 101 st->msg.vdiag_ino); 102 } 103 104 static void print_vsock_stats(FILE *fp, struct list_head *head) 105 { 106 struct vsock_stat *st; 107 108 list_for_each_entry(st, head, list) 109 print_vsock_stat(fp, st); 110 } 111 112 static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd) 113 { 114 struct vsock_stat *st; 115 struct stat stat; 116 117 if (fstat(fd, &stat) < 0) { 118 perror("fstat"); 119 exit(EXIT_FAILURE); 120 } 121 122 list_for_each_entry(st, head, list) 123 if (st->msg.vdiag_ino == stat.st_ino) 124 return st; 125 126 fprintf(stderr, "cannot find fd %d\n", fd); 127 exit(EXIT_FAILURE); 128 } 129 130 static void check_no_sockets(struct list_head *head) 131 { 132 if (!list_empty(head)) { 133 fprintf(stderr, "expected no sockets\n"); 134 print_vsock_stats(stderr, head); 135 exit(1); 136 } 137 } 138 139 static void check_num_sockets(struct list_head *head, int expected) 140 { 141 struct list_head *node; 142 int n = 0; 143 144 list_for_each(node, head) 145 n++; 146 147 if (n != expected) { 148 fprintf(stderr, "expected %d sockets, found %d\n", 149 expected, n); 150 print_vsock_stats(stderr, head); 151 exit(EXIT_FAILURE); 152 } 153 } 154 155 static void check_socket_state(struct vsock_stat *st, __u8 state) 156 { 157 if (st->msg.vdiag_state != state) { 158 fprintf(stderr, "expected socket state %#x, got %#x\n", 159 state, st->msg.vdiag_state); 160 exit(EXIT_FAILURE); 161 } 162 } 163 164 static void send_req(int fd) 165 { 166 struct sockaddr_nl nladdr = { 167 .nl_family = AF_NETLINK, 168 }; 169 struct { 170 struct nlmsghdr nlh; 171 struct vsock_diag_req vreq; 172 } req = { 173 .nlh = { 174 .nlmsg_len = sizeof(req), 175 .nlmsg_type = SOCK_DIAG_BY_FAMILY, 176 .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP, 177 }, 178 .vreq = { 179 .sdiag_family = AF_VSOCK, 180 .vdiag_states = ~(__u32)0, 181 }, 182 }; 183 struct iovec iov = { 184 .iov_base = &req, 185 .iov_len = sizeof(req), 186 }; 187 struct msghdr msg = { 188 .msg_name = &nladdr, 189 .msg_namelen = sizeof(nladdr), 190 .msg_iov = &iov, 191 .msg_iovlen = 1, 192 }; 193 194 for (;;) { 195 if (sendmsg(fd, &msg, 0) < 0) { 196 if (errno == EINTR) 197 continue; 198 199 perror("sendmsg"); 200 exit(EXIT_FAILURE); 201 } 202 203 return; 204 } 205 } 206 207 static ssize_t recv_resp(int fd, void *buf, size_t len) 208 { 209 struct sockaddr_nl nladdr = { 210 .nl_family = AF_NETLINK, 211 }; 212 struct iovec iov = { 213 .iov_base = buf, 214 .iov_len = len, 215 }; 216 struct msghdr msg = { 217 .msg_name = &nladdr, 218 .msg_namelen = sizeof(nladdr), 219 .msg_iov = &iov, 220 .msg_iovlen = 1, 221 }; 222 ssize_t ret; 223 224 do { 225 ret = recvmsg(fd, &msg, 0); 226 } while (ret < 0 && errno == EINTR); 227 228 if (ret < 0) { 229 perror("recvmsg"); 230 exit(EXIT_FAILURE); 231 } 232 233 return ret; 234 } 235 236 static void add_vsock_stat(struct list_head *sockets, 237 const struct vsock_diag_msg *resp) 238 { 239 struct vsock_stat *st; 240 241 st = malloc(sizeof(*st)); 242 if (!st) { 243 perror("malloc"); 244 exit(EXIT_FAILURE); 245 } 246 247 st->msg = *resp; 248 list_add_tail(&st->list, sockets); 249 } 250 251 /* 252 * Read vsock stats into a list. 253 */ 254 static void read_vsock_stat(struct list_head *sockets) 255 { 256 long buf[8192 / sizeof(long)]; 257 int fd; 258 259 fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG); 260 if (fd < 0) { 261 perror("socket"); 262 exit(EXIT_FAILURE); 263 } 264 265 send_req(fd); 266 267 for (;;) { 268 const struct nlmsghdr *h; 269 ssize_t ret; 270 271 ret = recv_resp(fd, buf, sizeof(buf)); 272 if (ret == 0) 273 goto done; 274 if (ret < sizeof(*h)) { 275 fprintf(stderr, "short read of %zd bytes\n", ret); 276 exit(EXIT_FAILURE); 277 } 278 279 h = (struct nlmsghdr *)buf; 280 281 while (NLMSG_OK(h, ret)) { 282 if (h->nlmsg_type == NLMSG_DONE) 283 goto done; 284 285 if (h->nlmsg_type == NLMSG_ERROR) { 286 const struct nlmsgerr *err = NLMSG_DATA(h); 287 288 if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err))) 289 fprintf(stderr, "NLMSG_ERROR\n"); 290 else { 291 errno = -err->error; 292 perror("NLMSG_ERROR"); 293 } 294 295 exit(EXIT_FAILURE); 296 } 297 298 if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) { 299 fprintf(stderr, "unexpected nlmsg_type %#x\n", 300 h->nlmsg_type); 301 exit(EXIT_FAILURE); 302 } 303 if (h->nlmsg_len < 304 NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) { 305 fprintf(stderr, "short vsock_diag_msg\n"); 306 exit(EXIT_FAILURE); 307 } 308 309 add_vsock_stat(sockets, NLMSG_DATA(h)); 310 311 h = NLMSG_NEXT(h, ret); 312 } 313 } 314 315 done: 316 close(fd); 317 } 318 319 static void free_sock_stat(struct list_head *sockets) 320 { 321 struct vsock_stat *st; 322 struct vsock_stat *next; 323 324 list_for_each_entry_safe(st, next, sockets, list) 325 free(st); 326 } 327 328 static void test_no_sockets(const struct test_opts *opts) 329 { 330 LIST_HEAD(sockets); 331 332 read_vsock_stat(&sockets); 333 334 check_no_sockets(&sockets); 335 } 336 337 static void test_listen_socket_server(const struct test_opts *opts) 338 { 339 union { 340 struct sockaddr sa; 341 struct sockaddr_vm svm; 342 } addr = { 343 .svm = { 344 .svm_family = AF_VSOCK, 345 .svm_port = 1234, 346 .svm_cid = VMADDR_CID_ANY, 347 }, 348 }; 349 LIST_HEAD(sockets); 350 struct vsock_stat *st; 351 int fd; 352 353 fd = socket(AF_VSOCK, SOCK_STREAM, 0); 354 355 if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) { 356 perror("bind"); 357 exit(EXIT_FAILURE); 358 } 359 360 if (listen(fd, 1) < 0) { 361 perror("listen"); 362 exit(EXIT_FAILURE); 363 } 364 365 read_vsock_stat(&sockets); 366 367 check_num_sockets(&sockets, 1); 368 st = find_vsock_stat(&sockets, fd); 369 check_socket_state(st, TCP_LISTEN); 370 371 close(fd); 372 free_sock_stat(&sockets); 373 } 374 375 static void test_connect_client(const struct test_opts *opts) 376 { 377 int fd; 378 LIST_HEAD(sockets); 379 struct vsock_stat *st; 380 381 fd = vsock_stream_connect(opts->peer_cid, 1234); 382 if (fd < 0) { 383 perror("connect"); 384 exit(EXIT_FAILURE); 385 } 386 387 read_vsock_stat(&sockets); 388 389 check_num_sockets(&sockets, 1); 390 st = find_vsock_stat(&sockets, fd); 391 check_socket_state(st, TCP_ESTABLISHED); 392 393 control_expectln("DONE"); 394 control_writeln("DONE"); 395 396 close(fd); 397 free_sock_stat(&sockets); 398 } 399 400 static void test_connect_server(const struct test_opts *opts) 401 { 402 struct vsock_stat *st; 403 LIST_HEAD(sockets); 404 int client_fd; 405 406 client_fd = vsock_stream_accept(VMADDR_CID_ANY, 1234, NULL); 407 if (client_fd < 0) { 408 perror("accept"); 409 exit(EXIT_FAILURE); 410 } 411 412 read_vsock_stat(&sockets); 413 414 check_num_sockets(&sockets, 1); 415 st = find_vsock_stat(&sockets, client_fd); 416 check_socket_state(st, TCP_ESTABLISHED); 417 418 control_writeln("DONE"); 419 control_expectln("DONE"); 420 421 close(client_fd); 422 free_sock_stat(&sockets); 423 } 424 425 static struct test_case test_cases[] = { 426 { 427 .name = "No sockets", 428 .run_server = test_no_sockets, 429 }, 430 { 431 .name = "Listen socket", 432 .run_server = test_listen_socket_server, 433 }, 434 { 435 .name = "Connect", 436 .run_client = test_connect_client, 437 .run_server = test_connect_server, 438 }, 439 {}, 440 }; 441 442 static const char optstring[] = ""; 443 static const struct option longopts[] = { 444 { 445 .name = "control-host", 446 .has_arg = required_argument, 447 .val = 'H', 448 }, 449 { 450 .name = "control-port", 451 .has_arg = required_argument, 452 .val = 'P', 453 }, 454 { 455 .name = "mode", 456 .has_arg = required_argument, 457 .val = 'm', 458 }, 459 { 460 .name = "peer-cid", 461 .has_arg = required_argument, 462 .val = 'p', 463 }, 464 { 465 .name = "list", 466 .has_arg = no_argument, 467 .val = 'l', 468 }, 469 { 470 .name = "skip", 471 .has_arg = required_argument, 472 .val = 's', 473 }, 474 { 475 .name = "help", 476 .has_arg = no_argument, 477 .val = '?', 478 }, 479 {}, 480 }; 481 482 static void usage(void) 483 { 484 fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid> [--list] [--skip=<test_id>]\n" 485 "\n" 486 " Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n" 487 " Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n" 488 "\n" 489 "Run vsock_diag.ko tests. Must be launched in both\n" 490 "guest and host. One side must use --mode=client and\n" 491 "the other side must use --mode=server.\n" 492 "\n" 493 "A TCP control socket connection is used to coordinate tests\n" 494 "between the client and the server. The server requires a\n" 495 "listen address and the client requires an address to\n" 496 "connect to.\n" 497 "\n" 498 "The CID of the other side must be given with --peer-cid=<cid>.\n" 499 "\n" 500 "Options:\n" 501 " --help This help message\n" 502 " --control-host <host> Server IP address to connect to\n" 503 " --control-port <port> Server port to listen on/connect to\n" 504 " --mode client|server Server or client mode\n" 505 " --peer-cid <cid> CID of the other side\n" 506 " --list List of tests that will be executed\n" 507 " --skip <test_id> Test ID to skip;\n" 508 " use multiple --skip options to skip more tests\n" 509 ); 510 exit(EXIT_FAILURE); 511 } 512 513 int main(int argc, char **argv) 514 { 515 const char *control_host = NULL; 516 const char *control_port = NULL; 517 struct test_opts opts = { 518 .mode = TEST_MODE_UNSET, 519 .peer_cid = VMADDR_CID_ANY, 520 }; 521 522 init_signals(); 523 524 for (;;) { 525 int opt = getopt_long(argc, argv, optstring, longopts, NULL); 526 527 if (opt == -1) 528 break; 529 530 switch (opt) { 531 case 'H': 532 control_host = optarg; 533 break; 534 case 'm': 535 if (strcmp(optarg, "client") == 0) 536 opts.mode = TEST_MODE_CLIENT; 537 else if (strcmp(optarg, "server") == 0) 538 opts.mode = TEST_MODE_SERVER; 539 else { 540 fprintf(stderr, "--mode must be \"client\" or \"server\"\n"); 541 return EXIT_FAILURE; 542 } 543 break; 544 case 'p': 545 opts.peer_cid = parse_cid(optarg); 546 break; 547 case 'P': 548 control_port = optarg; 549 break; 550 case 'l': 551 list_tests(test_cases); 552 break; 553 case 's': 554 skip_test(test_cases, ARRAY_SIZE(test_cases) - 1, 555 optarg); 556 break; 557 case '?': 558 default: 559 usage(); 560 } 561 } 562 563 if (!control_port) 564 usage(); 565 if (opts.mode == TEST_MODE_UNSET) 566 usage(); 567 if (opts.peer_cid == VMADDR_CID_ANY) 568 usage(); 569 570 if (!control_host) { 571 if (opts.mode != TEST_MODE_SERVER) 572 usage(); 573 control_host = "0.0.0.0"; 574 } 575 576 control_init(control_host, control_port, 577 opts.mode == TEST_MODE_SERVER); 578 579 run_tests(test_cases, &opts); 580 581 control_cleanup(); 582 return EXIT_SUCCESS; 583 } 584