1 /* 2 * vsock_diag_test - vsock_diag.ko test suite 3 * 4 * Copyright (C) 2017 Red Hat, Inc. 5 * 6 * Author: Stefan Hajnoczi <stefanha@redhat.com> 7 * 8 * This program is free software; you can redistribute it and/or 9 * modify it under the terms of the GNU General Public License 10 * as published by the Free Software Foundation; version 2 11 * of the License. 12 */ 13 14 #include <getopt.h> 15 #include <stdio.h> 16 #include <stdbool.h> 17 #include <stdlib.h> 18 #include <string.h> 19 #include <errno.h> 20 #include <unistd.h> 21 #include <signal.h> 22 #include <sys/socket.h> 23 #include <sys/stat.h> 24 #include <sys/types.h> 25 #include <linux/list.h> 26 #include <linux/net.h> 27 #include <linux/netlink.h> 28 #include <linux/sock_diag.h> 29 #include <netinet/tcp.h> 30 31 #include "../../../include/uapi/linux/vm_sockets.h" 32 #include "../../../include/uapi/linux/vm_sockets_diag.h" 33 34 #include "timeout.h" 35 #include "control.h" 36 37 enum test_mode { 38 TEST_MODE_UNSET, 39 TEST_MODE_CLIENT, 40 TEST_MODE_SERVER 41 }; 42 43 /* Per-socket status */ 44 struct vsock_stat { 45 struct list_head list; 46 struct vsock_diag_msg msg; 47 }; 48 49 static const char *sock_type_str(int type) 50 { 51 switch (type) { 52 case SOCK_DGRAM: 53 return "DGRAM"; 54 case SOCK_STREAM: 55 return "STREAM"; 56 default: 57 return "INVALID TYPE"; 58 } 59 } 60 61 static const char *sock_state_str(int state) 62 { 63 switch (state) { 64 case TCP_CLOSE: 65 return "UNCONNECTED"; 66 case TCP_SYN_SENT: 67 return "CONNECTING"; 68 case TCP_ESTABLISHED: 69 return "CONNECTED"; 70 case TCP_CLOSING: 71 return "DISCONNECTING"; 72 case TCP_LISTEN: 73 return "LISTEN"; 74 default: 75 return "INVALID STATE"; 76 } 77 } 78 79 static const char *sock_shutdown_str(int shutdown) 80 { 81 switch (shutdown) { 82 case 1: 83 return "RCV_SHUTDOWN"; 84 case 2: 85 return "SEND_SHUTDOWN"; 86 case 3: 87 return "RCV_SHUTDOWN | SEND_SHUTDOWN"; 88 default: 89 return "0"; 90 } 91 } 92 93 static void print_vsock_addr(FILE *fp, unsigned int cid, unsigned int port) 94 { 95 if (cid == VMADDR_CID_ANY) 96 fprintf(fp, "*:"); 97 else 98 fprintf(fp, "%u:", cid); 99 100 if (port == VMADDR_PORT_ANY) 101 fprintf(fp, "*"); 102 else 103 fprintf(fp, "%u", port); 104 } 105 106 static void print_vsock_stat(FILE *fp, struct vsock_stat *st) 107 { 108 print_vsock_addr(fp, st->msg.vdiag_src_cid, st->msg.vdiag_src_port); 109 fprintf(fp, " "); 110 print_vsock_addr(fp, st->msg.vdiag_dst_cid, st->msg.vdiag_dst_port); 111 fprintf(fp, " %s %s %s %u\n", 112 sock_type_str(st->msg.vdiag_type), 113 sock_state_str(st->msg.vdiag_state), 114 sock_shutdown_str(st->msg.vdiag_shutdown), 115 st->msg.vdiag_ino); 116 } 117 118 static void print_vsock_stats(FILE *fp, struct list_head *head) 119 { 120 struct vsock_stat *st; 121 122 list_for_each_entry(st, head, list) 123 print_vsock_stat(fp, st); 124 } 125 126 static struct vsock_stat *find_vsock_stat(struct list_head *head, int fd) 127 { 128 struct vsock_stat *st; 129 struct stat stat; 130 131 if (fstat(fd, &stat) < 0) { 132 perror("fstat"); 133 exit(EXIT_FAILURE); 134 } 135 136 list_for_each_entry(st, head, list) 137 if (st->msg.vdiag_ino == stat.st_ino) 138 return st; 139 140 fprintf(stderr, "cannot find fd %d\n", fd); 141 exit(EXIT_FAILURE); 142 } 143 144 static void check_no_sockets(struct list_head *head) 145 { 146 if (!list_empty(head)) { 147 fprintf(stderr, "expected no sockets\n"); 148 print_vsock_stats(stderr, head); 149 exit(1); 150 } 151 } 152 153 static void check_num_sockets(struct list_head *head, int expected) 154 { 155 struct list_head *node; 156 int n = 0; 157 158 list_for_each(node, head) 159 n++; 160 161 if (n != expected) { 162 fprintf(stderr, "expected %d sockets, found %d\n", 163 expected, n); 164 print_vsock_stats(stderr, head); 165 exit(EXIT_FAILURE); 166 } 167 } 168 169 static void check_socket_state(struct vsock_stat *st, __u8 state) 170 { 171 if (st->msg.vdiag_state != state) { 172 fprintf(stderr, "expected socket state %#x, got %#x\n", 173 state, st->msg.vdiag_state); 174 exit(EXIT_FAILURE); 175 } 176 } 177 178 static void send_req(int fd) 179 { 180 struct sockaddr_nl nladdr = { 181 .nl_family = AF_NETLINK, 182 }; 183 struct { 184 struct nlmsghdr nlh; 185 struct vsock_diag_req vreq; 186 } req = { 187 .nlh = { 188 .nlmsg_len = sizeof(req), 189 .nlmsg_type = SOCK_DIAG_BY_FAMILY, 190 .nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP, 191 }, 192 .vreq = { 193 .sdiag_family = AF_VSOCK, 194 .vdiag_states = ~(__u32)0, 195 }, 196 }; 197 struct iovec iov = { 198 .iov_base = &req, 199 .iov_len = sizeof(req), 200 }; 201 struct msghdr msg = { 202 .msg_name = &nladdr, 203 .msg_namelen = sizeof(nladdr), 204 .msg_iov = &iov, 205 .msg_iovlen = 1, 206 }; 207 208 for (;;) { 209 if (sendmsg(fd, &msg, 0) < 0) { 210 if (errno == EINTR) 211 continue; 212 213 perror("sendmsg"); 214 exit(EXIT_FAILURE); 215 } 216 217 return; 218 } 219 } 220 221 static ssize_t recv_resp(int fd, void *buf, size_t len) 222 { 223 struct sockaddr_nl nladdr = { 224 .nl_family = AF_NETLINK, 225 }; 226 struct iovec iov = { 227 .iov_base = buf, 228 .iov_len = len, 229 }; 230 struct msghdr msg = { 231 .msg_name = &nladdr, 232 .msg_namelen = sizeof(nladdr), 233 .msg_iov = &iov, 234 .msg_iovlen = 1, 235 }; 236 ssize_t ret; 237 238 do { 239 ret = recvmsg(fd, &msg, 0); 240 } while (ret < 0 && errno == EINTR); 241 242 if (ret < 0) { 243 perror("recvmsg"); 244 exit(EXIT_FAILURE); 245 } 246 247 return ret; 248 } 249 250 static void add_vsock_stat(struct list_head *sockets, 251 const struct vsock_diag_msg *resp) 252 { 253 struct vsock_stat *st; 254 255 st = malloc(sizeof(*st)); 256 if (!st) { 257 perror("malloc"); 258 exit(EXIT_FAILURE); 259 } 260 261 st->msg = *resp; 262 list_add_tail(&st->list, sockets); 263 } 264 265 /* 266 * Read vsock stats into a list. 267 */ 268 static void read_vsock_stat(struct list_head *sockets) 269 { 270 long buf[8192 / sizeof(long)]; 271 int fd; 272 273 fd = socket(AF_NETLINK, SOCK_RAW, NETLINK_SOCK_DIAG); 274 if (fd < 0) { 275 perror("socket"); 276 exit(EXIT_FAILURE); 277 } 278 279 send_req(fd); 280 281 for (;;) { 282 const struct nlmsghdr *h; 283 ssize_t ret; 284 285 ret = recv_resp(fd, buf, sizeof(buf)); 286 if (ret == 0) 287 goto done; 288 if (ret < sizeof(*h)) { 289 fprintf(stderr, "short read of %zd bytes\n", ret); 290 exit(EXIT_FAILURE); 291 } 292 293 h = (struct nlmsghdr *)buf; 294 295 while (NLMSG_OK(h, ret)) { 296 if (h->nlmsg_type == NLMSG_DONE) 297 goto done; 298 299 if (h->nlmsg_type == NLMSG_ERROR) { 300 const struct nlmsgerr *err = NLMSG_DATA(h); 301 302 if (h->nlmsg_len < NLMSG_LENGTH(sizeof(*err))) 303 fprintf(stderr, "NLMSG_ERROR\n"); 304 else { 305 errno = -err->error; 306 perror("NLMSG_ERROR"); 307 } 308 309 exit(EXIT_FAILURE); 310 } 311 312 if (h->nlmsg_type != SOCK_DIAG_BY_FAMILY) { 313 fprintf(stderr, "unexpected nlmsg_type %#x\n", 314 h->nlmsg_type); 315 exit(EXIT_FAILURE); 316 } 317 if (h->nlmsg_len < 318 NLMSG_LENGTH(sizeof(struct vsock_diag_msg))) { 319 fprintf(stderr, "short vsock_diag_msg\n"); 320 exit(EXIT_FAILURE); 321 } 322 323 add_vsock_stat(sockets, NLMSG_DATA(h)); 324 325 h = NLMSG_NEXT(h, ret); 326 } 327 } 328 329 done: 330 close(fd); 331 } 332 333 static void free_sock_stat(struct list_head *sockets) 334 { 335 struct vsock_stat *st; 336 struct vsock_stat *next; 337 338 list_for_each_entry_safe(st, next, sockets, list) 339 free(st); 340 } 341 342 static void test_no_sockets(unsigned int peer_cid) 343 { 344 LIST_HEAD(sockets); 345 346 read_vsock_stat(&sockets); 347 348 check_no_sockets(&sockets); 349 350 free_sock_stat(&sockets); 351 } 352 353 static void test_listen_socket_server(unsigned int peer_cid) 354 { 355 union { 356 struct sockaddr sa; 357 struct sockaddr_vm svm; 358 } addr = { 359 .svm = { 360 .svm_family = AF_VSOCK, 361 .svm_port = 1234, 362 .svm_cid = VMADDR_CID_ANY, 363 }, 364 }; 365 LIST_HEAD(sockets); 366 struct vsock_stat *st; 367 int fd; 368 369 fd = socket(AF_VSOCK, SOCK_STREAM, 0); 370 371 if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) { 372 perror("bind"); 373 exit(EXIT_FAILURE); 374 } 375 376 if (listen(fd, 1) < 0) { 377 perror("listen"); 378 exit(EXIT_FAILURE); 379 } 380 381 read_vsock_stat(&sockets); 382 383 check_num_sockets(&sockets, 1); 384 st = find_vsock_stat(&sockets, fd); 385 check_socket_state(st, TCP_LISTEN); 386 387 close(fd); 388 free_sock_stat(&sockets); 389 } 390 391 static void test_connect_client(unsigned int peer_cid) 392 { 393 union { 394 struct sockaddr sa; 395 struct sockaddr_vm svm; 396 } addr = { 397 .svm = { 398 .svm_family = AF_VSOCK, 399 .svm_port = 1234, 400 .svm_cid = peer_cid, 401 }, 402 }; 403 int fd; 404 int ret; 405 LIST_HEAD(sockets); 406 struct vsock_stat *st; 407 408 control_expectln("LISTENING"); 409 410 fd = socket(AF_VSOCK, SOCK_STREAM, 0); 411 412 timeout_begin(TIMEOUT); 413 do { 414 ret = connect(fd, &addr.sa, sizeof(addr.svm)); 415 timeout_check("connect"); 416 } while (ret < 0 && errno == EINTR); 417 timeout_end(); 418 419 if (ret < 0) { 420 perror("connect"); 421 exit(EXIT_FAILURE); 422 } 423 424 read_vsock_stat(&sockets); 425 426 check_num_sockets(&sockets, 1); 427 st = find_vsock_stat(&sockets, fd); 428 check_socket_state(st, TCP_ESTABLISHED); 429 430 control_expectln("DONE"); 431 control_writeln("DONE"); 432 433 close(fd); 434 free_sock_stat(&sockets); 435 } 436 437 static void test_connect_server(unsigned int peer_cid) 438 { 439 union { 440 struct sockaddr sa; 441 struct sockaddr_vm svm; 442 } addr = { 443 .svm = { 444 .svm_family = AF_VSOCK, 445 .svm_port = 1234, 446 .svm_cid = VMADDR_CID_ANY, 447 }, 448 }; 449 union { 450 struct sockaddr sa; 451 struct sockaddr_vm svm; 452 } clientaddr; 453 socklen_t clientaddr_len = sizeof(clientaddr.svm); 454 LIST_HEAD(sockets); 455 struct vsock_stat *st; 456 int fd; 457 int client_fd; 458 459 fd = socket(AF_VSOCK, SOCK_STREAM, 0); 460 461 if (bind(fd, &addr.sa, sizeof(addr.svm)) < 0) { 462 perror("bind"); 463 exit(EXIT_FAILURE); 464 } 465 466 if (listen(fd, 1) < 0) { 467 perror("listen"); 468 exit(EXIT_FAILURE); 469 } 470 471 control_writeln("LISTENING"); 472 473 timeout_begin(TIMEOUT); 474 do { 475 client_fd = accept(fd, &clientaddr.sa, &clientaddr_len); 476 timeout_check("accept"); 477 } while (client_fd < 0 && errno == EINTR); 478 timeout_end(); 479 480 if (client_fd < 0) { 481 perror("accept"); 482 exit(EXIT_FAILURE); 483 } 484 if (clientaddr.sa.sa_family != AF_VSOCK) { 485 fprintf(stderr, "expected AF_VSOCK from accept(2), got %d\n", 486 clientaddr.sa.sa_family); 487 exit(EXIT_FAILURE); 488 } 489 if (clientaddr.svm.svm_cid != peer_cid) { 490 fprintf(stderr, "expected peer CID %u from accept(2), got %u\n", 491 peer_cid, clientaddr.svm.svm_cid); 492 exit(EXIT_FAILURE); 493 } 494 495 read_vsock_stat(&sockets); 496 497 check_num_sockets(&sockets, 2); 498 find_vsock_stat(&sockets, fd); 499 st = find_vsock_stat(&sockets, client_fd); 500 check_socket_state(st, TCP_ESTABLISHED); 501 502 control_writeln("DONE"); 503 control_expectln("DONE"); 504 505 close(client_fd); 506 close(fd); 507 free_sock_stat(&sockets); 508 } 509 510 static struct { 511 const char *name; 512 void (*run_client)(unsigned int peer_cid); 513 void (*run_server)(unsigned int peer_cid); 514 } test_cases[] = { 515 { 516 .name = "No sockets", 517 .run_server = test_no_sockets, 518 }, 519 { 520 .name = "Listen socket", 521 .run_server = test_listen_socket_server, 522 }, 523 { 524 .name = "Connect", 525 .run_client = test_connect_client, 526 .run_server = test_connect_server, 527 }, 528 {}, 529 }; 530 531 static void init_signals(void) 532 { 533 struct sigaction act = { 534 .sa_handler = sigalrm, 535 }; 536 537 sigaction(SIGALRM, &act, NULL); 538 signal(SIGPIPE, SIG_IGN); 539 } 540 541 static unsigned int parse_cid(const char *str) 542 { 543 char *endptr = NULL; 544 unsigned long int n; 545 546 errno = 0; 547 n = strtoul(str, &endptr, 10); 548 if (errno || *endptr != '\0') { 549 fprintf(stderr, "malformed CID \"%s\"\n", str); 550 exit(EXIT_FAILURE); 551 } 552 return n; 553 } 554 555 static const char optstring[] = ""; 556 static const struct option longopts[] = { 557 { 558 .name = "control-host", 559 .has_arg = required_argument, 560 .val = 'H', 561 }, 562 { 563 .name = "control-port", 564 .has_arg = required_argument, 565 .val = 'P', 566 }, 567 { 568 .name = "mode", 569 .has_arg = required_argument, 570 .val = 'm', 571 }, 572 { 573 .name = "peer-cid", 574 .has_arg = required_argument, 575 .val = 'p', 576 }, 577 { 578 .name = "help", 579 .has_arg = no_argument, 580 .val = '?', 581 }, 582 {}, 583 }; 584 585 static void usage(void) 586 { 587 fprintf(stderr, "Usage: vsock_diag_test [--help] [--control-host=<host>] --control-port=<port> --mode=client|server --peer-cid=<cid>\n" 588 "\n" 589 " Server: vsock_diag_test --control-port=1234 --mode=server --peer-cid=3\n" 590 " Client: vsock_diag_test --control-host=192.168.0.1 --control-port=1234 --mode=client --peer-cid=2\n" 591 "\n" 592 "Run vsock_diag.ko tests. Must be launched in both\n" 593 "guest and host. One side must use --mode=client and\n" 594 "the other side must use --mode=server.\n" 595 "\n" 596 "A TCP control socket connection is used to coordinate tests\n" 597 "between the client and the server. The server requires a\n" 598 "listen address and the client requires an address to\n" 599 "connect to.\n" 600 "\n" 601 "The CID of the other side must be given with --peer-cid=<cid>.\n"); 602 exit(EXIT_FAILURE); 603 } 604 605 int main(int argc, char **argv) 606 { 607 const char *control_host = NULL; 608 const char *control_port = NULL; 609 int mode = TEST_MODE_UNSET; 610 unsigned int peer_cid = VMADDR_CID_ANY; 611 int i; 612 613 init_signals(); 614 615 for (;;) { 616 int opt = getopt_long(argc, argv, optstring, longopts, NULL); 617 618 if (opt == -1) 619 break; 620 621 switch (opt) { 622 case 'H': 623 control_host = optarg; 624 break; 625 case 'm': 626 if (strcmp(optarg, "client") == 0) 627 mode = TEST_MODE_CLIENT; 628 else if (strcmp(optarg, "server") == 0) 629 mode = TEST_MODE_SERVER; 630 else { 631 fprintf(stderr, "--mode must be \"client\" or \"server\"\n"); 632 return EXIT_FAILURE; 633 } 634 break; 635 case 'p': 636 peer_cid = parse_cid(optarg); 637 break; 638 case 'P': 639 control_port = optarg; 640 break; 641 case '?': 642 default: 643 usage(); 644 } 645 } 646 647 if (!control_port) 648 usage(); 649 if (mode == TEST_MODE_UNSET) 650 usage(); 651 if (peer_cid == VMADDR_CID_ANY) 652 usage(); 653 654 if (!control_host) { 655 if (mode != TEST_MODE_SERVER) 656 usage(); 657 control_host = "0.0.0.0"; 658 } 659 660 control_init(control_host, control_port, mode == TEST_MODE_SERVER); 661 662 for (i = 0; test_cases[i].name; i++) { 663 void (*run)(unsigned int peer_cid); 664 665 printf("%s...", test_cases[i].name); 666 fflush(stdout); 667 668 if (mode == TEST_MODE_CLIENT) 669 run = test_cases[i].run_client; 670 else 671 run = test_cases[i].run_server; 672 673 if (run) 674 run(peer_cid); 675 676 printf("ok\n"); 677 } 678 679 control_cleanup(); 680 return EXIT_SUCCESS; 681 } 682