1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * vsock_diag_test - vsock_diag.ko test suite 4 * 5 * Copyright (C) 2017 Red Hat, Inc. 6 * 7 * Author: Stefan Hajnoczi <stefanha@redhat.com> 8 */ 9 10 #include <getopt.h> 11 #include <stdio.h> 12 #include <stdbool.h> 13 #include <stdlib.h> 14 #include <string.h> 15 #include <errno.h> 16 #include <unistd.h> 17 #include <signal.h> 18 #include <sys/socket.h> 19 #include <sys/stat.h> 20 #include <sys/types.h> 21 #include <linux/list.h> 22 #include <linux/net.h> 23 #include <linux/netlink.h> 24 #include <linux/sock_diag.h> 25 #include <netinet/tcp.h> 26 27 #include "../../../include/uapi/linux/vm_sockets.h" 28 #include "../../../include/uapi/linux/vm_sockets_diag.h" 29 30 #include "timeout.h" 31 #include "control.h" 32 33 enum test_mode { 34 TEST_MODE_UNSET, 35 TEST_MODE_CLIENT, 36 TEST_MODE_SERVER 37 }; 38 39 /* Per-socket status */ 40 struct vsock_stat { 41 struct list_head list; 42 struct vsock_diag_msg msg; 43 }; 44 45 static const char *sock_type_str(int type) 46 { 47 switch (type) { 48 case SOCK_DGRAM: 49 return "DGRAM"; 50 case SOCK_STREAM: 51 return "STREAM"; 52 default: 53 return "INVALID TYPE"; 54 } 55 } 56 57 static const char *sock_state_str(int state) 58 { 59 switch (state) { 60 case TCP_CLOSE: 61 return "UNCONNECTED"; 62 case TCP_SYN_SENT: 63 return "CONNECTING"; 64 case TCP_ESTABLISHED: 65 return "CONNECTED"; 66 case TCP_CLOSING: 67 return "DISCONNECTING"; 68 case TCP_LISTEN: 69 return "LISTEN"; 70 default: 71 return "INVALID STATE"; 72 } 73 } 74 75 static const char *sock_shutdown_str(int shutdown) 76 { 77 switch (shutdown) { 78 case 1: 79 return "RCV_SHUTDOWN"; 80 case 2: 81 return "SEND_SHUTDOWN"; 82 case 3: 83 return "RCV_SHUTDOWN | SEND_SHUTDOWN"; 84 default: 85 return "0"; 86 } 87 } 88 89 static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port) 90 { 91 if (cid == VMADDR_CID_ANY) 92 fprintf(fp, "*:"); 93 else 94 fprintf(fp, "%u:", cid); 95 96 if (port == VMADDR_PORT_ANY) 97 fprintf(fp, "*"); 98 else 99 fprintf(fp, "%u", port); 100 } 101 102 static void print_vsock_stat(FILE *fp, struct vsock_stat *st) 103 { 104 print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port); 105 fprintf(fp, " "); 106 print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port); 107 fprintf(fp, " %s %s %s %u\n", 108 sock_type_str(st->msg.vdiag_type), 109 sock_state_str(st->msg.vdiag_state), 110 sock_shutdown_str(st->msg.vdiag_shutdown), 111 st->msg.vdiag_ino); 112 } 113 114 static void print_vsock_stats(FILE *fp, struct list_head *head) 115 { 116 struct vsock_stat *st; 117 118 list_for_each_entry(st, head, list) 119 print_vsock_stat(fp, st); 120 } 121 122 static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd) 123 { 124 struct vsock_stat *st; 125 struct stat stat; 126 127 if (fstat(fd, &stat) < 0) { 128 perror("fstat"); 129 exit(EXIT_FAILURE); 130 } 131 132 list_for_each_entry(st, head, list) 133 if (st->msg.vdiag_ino == stat.st_ino) 134 return st; 135 136 fprintf(stderr, "cannot find fd %d\n", fd); 137 exit(EXIT_FAILURE); 138 } 139 140 static void check_no_sockets(struct list_head *head) 141 { 142 if (!list_empty(head)) { 143 fprintf(stderr, "expected no sockets\n"); 144 print_vsock_stats(stderr, head); 145 exit(1); 146 } 147 } 148 149 static void check_num_sockets(struct list_head *head, int expected) 150 { 151 struct list_head *node; 152 int n = 0; 153 154 list_for_each(node, head) 155 n++; 156 157 if (n != expected) { 158 fprintf(stderr, "expected %d sockets, found %d\n", 159 expected, n); 160 print_vsock_stats(stderr, head); 161 exit(EXIT_FAILURE); 162 } 163 } 164 165 static void check_socket_state(struct vsock_stat *st, __u8 state) 166 { 167 if (st->msg.vdiag_state != state) { 168 fprintf(stderr, "expected socket state %#x, got %#x\n", 169 state, st->msg.vdiag_state); 170 exit(EXIT_FAILURE); 171 } 172 } 173 174 static void send_req(int fd) 175 { 176 struct sockaddr_nl nladdr = { 177 .nl_family = AF_NETLINK, 178 }; 179 struct { 180 struct nlmsghdr nlh; 181 struct vsock_diag_req vreq; 182 } req = { 183 .nlh = { 184 .nlmsg_len = sizeof(req), 185 .nlmsg_type = SOCK_DIAG_BY_FAMILY, 186 .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP, 187 }, 188 .vreq = { 189 .sdiag_family = AF_VSOCK, 190 .vdiag_states = ~(__u32)0, 191 }, 192 }; 193 struct iovec iov = { 194 .iov_base = &req, 195 .iov_len = sizeof(req), 196 }; 197 struct msghdr msg = { 198 .msg_name = &nladdr, 199 .msg_namelen = sizeof(nladdr), 200 .msg_iov = &iov, 201 .msg_iovlen = 1, 202 }; 203 204 for (;;) { 205 if (sendmsg(fd, &msg, 0) < 0) { 206 if (errno == EINTR) 207 continue; 208 209 perror("sendmsg"); 210 exit(EXIT_FAILURE); 211 } 212 213 return; 214 } 215 } 216 217 static ssize_t recv_resp(int fd, void *buf, size_t len) 218 { 219 struct sockaddr_nl nladdr = { 220 .nl_family = AF_NETLINK, 221 }; 222 struct iovec iov = { 223 .iov_base = buf, 224 .iov_len = len, 225 }; 226 struct msghdr msg = { 227 .msg_name = &nladdr, 228 .msg_namelen = sizeof(nladdr), 229 .msg_iov = &iov, 230 .msg_iovlen = 1, 231 }; 232 ssize_t ret; 233 234 do { 235 ret = recvmsg(fd, &msg, 0); 236 } while (ret < 0 && errno == EINTR); 237 238 if (ret < 0) { 239 perror("recvmsg"); 240 exit(EXIT_FAILURE); 241 } 242 243 return ret; 244 } 245 246 static void add_vsock_stat(struct list_head *sockets, 247 const struct vsock_diag_msg *resp) 248 { 249 struct vsock_stat *st; 250 251 st = malloc(sizeof(*st)); 252 if (!st) { 253 perror("malloc"); 254 exit(EXIT_FAILURE); 255 } 256 257 st->msg = *resp; 258 list_add_tail(&st->list, sockets); 259 } 260 261 /* 262 * Read vsock stats into a list. 263 */ 264 static void read_vsock_stat(struct list_head *sockets) 265 { 266 long buf[8192 / sizeof(long)]; 267 int fd; 268 269 fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG); 270 if (fd < 0) { 271 perror("socket"); 272 exit(EXIT_FAILURE); 273 } 274 275 send_req(fd); 276 277 for (;;) { 278 const struct nlmsghdr *h; 279 ssize_t ret; 280 281 ret = recv_resp(fd, buf, sizeof(buf)); 282 if (ret == 0) 283 goto done; 284 if (ret < sizeof(*h)) { 285 fprintf(stderr, "short read of %zd bytes\n", ret); 286 exit(EXIT_FAILURE); 287 } 288 289 h = (struct nlmsghdr *)buf; 290 291 while (NLMSG_OK(h, ret)) { 292 if (h->nlmsg_type == NLMSG_DONE) 293 goto done; 294 295 if (h->nlmsg_type == NLMSG_ERROR) { 296 const struct nlmsgerr *err = NLMSG_DATA(h); 297 298 if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err))) 299 fprintf(stderr, "NLMSG_ERROR\n"); 300 else { 301 errno = -err->error; 302 perror("NLMSG_ERROR"); 303 } 304 305 exit(EXIT_FAILURE); 306 } 307 308 if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) { 309 fprintf(stderr, "unexpected nlmsg_type %#x\n", 310 h->nlmsg_type); 311 exit(EXIT_FAILURE); 312 } 313 if (h->nlmsg_len < 314 NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) { 315 fprintf(stderr, "short vsock_diag_msg\n"); 316 exit(EXIT_FAILURE); 317 } 318 319 add_vsock_stat(sockets, NLMSG_DATA(h)); 320 321 h = NLMSG_NEXT(h, ret); 322 } 323 } 324 325 done: 326 close(fd); 327 } 328 329 static void free_sock_stat(struct list_head *sockets) 330 { 331 struct vsock_stat *st; 332 struct vsock_stat *next; 333 334 list_for_each_entry_safe(st, next, sockets, list) 335 free(st); 336 } 337 338 static void test_no_sockets(unsigned int peer_cid) 339 { 340 LIST_HEAD(sockets); 341 342 read_vsock_stat(&sockets); 343 344 check_no_sockets(&sockets); 345 346 free_sock_stat(&sockets); 347 } 348 349 static void test_listen_socket_server(unsigned int peer_cid) 350 { 351 union { 352 struct sockaddr sa; 353 struct sockaddr_vm svm; 354 } addr = { 355 .svm = { 356 .svm_family = AF_VSOCK, 357 .svm_port = 1234, 358 .svm_cid = VMADDR_CID_ANY, 359 }, 360 }; 361 LIST_HEAD(sockets); 362 struct vsock_stat *st; 363 int fd; 364 365 fd = socket(AF_VSOCK, SOCK_STREAM, 0); 366 367 if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) { 368 perror("bind"); 369 exit(EXIT_FAILURE); 370 } 371 372 if (listen(fd, 1) < 0) { 373 perror("listen"); 374 exit(EXIT_FAILURE); 375 } 376 377 read_vsock_stat(&sockets); 378 379 check_num_sockets(&sockets, 1); 380 st = find_vsock_stat(&sockets, fd); 381 check_socket_state(st, TCP_LISTEN); 382 383 close(fd); 384 free_sock_stat(&sockets); 385 } 386 387 static void test_connect_client(unsigned int peer_cid) 388 { 389 union { 390 struct sockaddr sa; 391 struct sockaddr_vm svm; 392 } addr = { 393 .svm = { 394 .svm_family = AF_VSOCK, 395 .svm_port = 1234, 396 .svm_cid = peer_cid, 397 }, 398 }; 399 int fd; 400 int ret; 401 LIST_HEAD(sockets); 402 struct vsock_stat *st; 403 404 control_expectln("LISTENING"); 405 406 fd = socket(AF_VSOCK, SOCK_STREAM, 0); 407 408 timeout_begin(TIMEOUT); 409 do { 410 ret = connect(fd, &addr.sa, sizeof(addr.svm)); 411 timeout_check("connect"); 412 } while (ret < 0 && errno == EINTR); 413 timeout_end(); 414 415 if (ret < 0) { 416 perror("connect"); 417 exit(EXIT_FAILURE); 418 } 419 420 read_vsock_stat(&sockets); 421 422 check_num_sockets(&sockets, 1); 423 st = find_vsock_stat(&sockets, fd); 424 check_socket_state(st, TCP_ESTABLISHED); 425 426 control_expectln("DONE"); 427 control_writeln("DONE"); 428 429 close(fd); 430 free_sock_stat(&sockets); 431 } 432 433 static void test_connect_server(unsigned int peer_cid) 434 { 435 union { 436 struct sockaddr sa; 437 struct sockaddr_vm svm; 438 } addr = { 439 .svm = { 440 .svm_family = AF_VSOCK, 441 .svm_port = 1234, 442 .svm_cid = VMADDR_CID_ANY, 443 }, 444 }; 445 union { 446 struct sockaddr sa; 447 struct sockaddr_vm svm; 448 } clientaddr; 449 socklen_t clientaddr_len = sizeof(clientaddr.svm); 450 LIST_HEAD(sockets); 451 struct vsock_stat *st; 452 int fd; 453 int client_fd; 454 455 fd = socket(AF_VSOCK, SOCK_STREAM, 0); 456 457 if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) { 458 perror("bind"); 459 exit(EXIT_FAILURE); 460 } 461 462 if (listen(fd, 1) < 0) { 463 perror("listen"); 464 exit(EXIT_FAILURE); 465 } 466 467 control_writeln("LISTENING"); 468 469 timeout_begin(TIMEOUT); 470 do { 471 client_fd = accept(fd, &clientaddr.sa, &clientaddr_len); 472 timeout_check("accept"); 473 } while (client_fd < 0 && errno == EINTR); 474 timeout_end(); 475 476 if (client_fd < 0) { 477 perror("accept"); 478 exit(EXIT_FAILURE); 479 } 480 if (clientaddr.sa.sa_family != AF_VSOCK) { 481 fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n", 482 clientaddr.sa.sa_family); 483 exit(EXIT_FAILURE); 484 } 485 if (clientaddr.svm.svm_cid != peer_cid) { 486 fprintf(stderr, "expected peer CID %u from accept(2), got %u\n", 487 peer_cid, clientaddr.svm.svm_cid); 488 exit(EXIT_FAILURE); 489 } 490 491 read_vsock_stat(&sockets); 492 493 check_num_sockets(&sockets, 2); 494 find_vsock_stat(&sockets, fd); 495 st = find_vsock_stat(&sockets, client_fd); 496 check_socket_state(st, TCP_ESTABLISHED); 497 498 control_writeln("DONE"); 499 control_expectln("DONE"); 500 501 close(client_fd); 502 close(fd); 503 free_sock_stat(&sockets); 504 } 505 506 static struct { 507 const char *name; 508 void (*run_client)(unsigned int peer_cid); 509 void (*run_server)(unsigned int peer_cid); 510 } test_cases[] = { 511 { 512 .name = "No sockets", 513 .run_server = test_no_sockets, 514 }, 515 { 516 .name = "Listen socket", 517 .run_server = test_listen_socket_server, 518 }, 519 { 520 .name = "Connect", 521 .run_client = test_connect_client, 522 .run_server = test_connect_server, 523 }, 524 {}, 525 }; 526 527 static void init_signals(void) 528 { 529 struct sigaction act = { 530 .sa_handler = sigalrm, 531 }; 532 533 sigaction(SIGALRM, &act, NULL); 534 signal(SIGPIPE, SIG_IGN); 535 } 536 537 static unsigned int parse_cid(const char *str) 538 { 539 char *endptr = NULL; 540 unsigned long int n; 541 542 errno = 0; 543 n = strtoul(str, &endptr, 10); 544 if (errno || *endptr != '\0') { 545 fprintf(stderr, "malformed CID \"%s\"\n", str); 546 exit(EXIT_FAILURE); 547 } 548 return n; 549 } 550 551 static const char optstring[] = ""; 552 static const struct option longopts[] = { 553 { 554 .name = "control-host", 555 .has_arg = required_argument, 556 .val = 'H', 557 }, 558 { 559 .name = "control-port", 560 .has_arg = required_argument, 561 .val = 'P', 562 }, 563 { 564 .name = "mode", 565 .has_arg = required_argument, 566 .val = 'm', 567 }, 568 { 569 .name = "peer-cid", 570 .has_arg = required_argument, 571 .val = 'p', 572 }, 573 { 574 .name = "help", 575 .has_arg = no_argument, 576 .val = '?', 577 }, 578 {}, 579 }; 580 581 static void usage(void) 582 { 583 fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid>\n" 584 "\n" 585 " Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n" 586 " Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n" 587 "\n" 588 "Run vsock_diag.ko tests. Must be launched in both\n" 589 "guest and host. One side must use --mode=client and\n" 590 "the other side must use --mode=server.\n" 591 "\n" 592 "A TCP control socket connection is used to coordinate tests\n" 593 "between the client and the server. The server requires a\n" 594 "listen address and the client requires an address to\n" 595 "connect to.\n" 596 "\n" 597 "The CID of the other side must be given with --peer-cid=<cid>.\n"); 598 exit(EXIT_FAILURE); 599 } 600 601 int main(int argc, char **argv) 602 { 603 const char *control_host = NULL; 604 const char *control_port = NULL; 605 int mode = TEST_MODE_UNSET; 606 unsigned int peer_cid = VMADDR_CID_ANY; 607 int i; 608 609 init_signals(); 610 611 for (;;) { 612 int opt = getopt_long(argc, argv, optstring, longopts, NULL); 613 614 if (opt == -1) 615 break; 616 617 switch (opt) { 618 case 'H': 619 control_host = optarg; 620 break; 621 case 'm': 622 if (strcmp(optarg, "client") == 0) 623 mode = TEST_MODE_CLIENT; 624 else if (strcmp(optarg, "server") == 0) 625 mode = TEST_MODE_SERVER; 626 else { 627 fprintf(stderr, "--mode must be \"client\" or \"server\"\n"); 628 return EXIT_FAILURE; 629 } 630 break; 631 case 'p': 632 peer_cid = parse_cid(optarg); 633 break; 634 case 'P': 635 control_port = optarg; 636 break; 637 case '?': 638 default: 639 usage(); 640 } 641 } 642 643 if (!control_port) 644 usage(); 645 if (mode == TEST_MODE_UNSET) 646 usage(); 647 if (peer_cid == VMADDR_CID_ANY) 648 usage(); 649 650 if (!control_host) { 651 if (mode != TEST_MODE_SERVER) 652 usage(); 653 control_host = "0.0.0.0"; 654 } 655 656 control_init(control_host, control_port, mode == TEST_MODE_SERVER); 657 658 for (i = 0; test_cases[i].name; i++) { 659 void (*run)(unsigned int peer_cid); 660 661 printf("%s...", test_cases[i].name); 662 fflush(stdout); 663 664 if (mode == TEST_MODE_CLIENT) 665 run = test_cases[i].run_client; 666 else 667 run = test_cases[i].run_server; 668 669 if (run) 670 run(peer_cid); 671 672 printf("ok\n"); 673 } 674 675 control_cleanup(); 676 return EXIT_SUCCESS; 677 } 678