/*
 * Vhost User library
 *
 * Copyright IBM, Corp. 2007
 * Copyright (c) 2016 Red Hat, Inc.
 *
 * Authors:
 *  Anthony Liguori <aliguori@us.ibm.com>
 *  Marc-André Lureau <mlureau@redhat.com>
 *  Victor Kaplansky <victork@redhat.com>
 *
 * This work is licensed under the terms of the GNU GPL, version 2 or
 * later.  See the COPYING file in the top-level directory.
 */

/* this code avoids GLib dependency */
#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <stdarg.h>
#include <errno.h>
#include <string.h>
#include <assert.h>
#include <inttypes.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/eventfd.h>
#include <sys/mman.h>
#include <endian.h>

#if defined(__linux__)
#include <sys/syscall.h>
#include <fcntl.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>

#ifdef __NR_userfaultfd
#include <linux/userfaultfd.h>
#endif

#endif

#include "include/atomic.h"

#include "libvhost-user.h"

/* usually provided by GLib */
#if __GNUC__ > 2 || (__GNUC__ == 2 && __GNUC_MINOR__ > 4)
#if !defined(__clang__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 4)
#define G_GNUC_PRINTF(format_idx, arg_idx) \
    __attribute__((__format__(gnu_printf, format_idx, arg_idx)))
#else
#define G_GNUC_PRINTF(format_idx, arg_idx) \
    __attribute__((__format__(__printf__, format_idx, arg_idx)))
#endif
#else   /* !__GNUC__ */
#define G_GNUC_PRINTF(format_idx, arg_idx)
#endif  /* !__GNUC__ */
#ifndef MIN
#define MIN(x, y) ({                            \
            typeof(x) _min1 = (x);              \
            typeof(y) _min2 = (y);              \
            (void) (&_min1 == &_min2);          \
            _min1 < _min2 ? _min1 : _min2; })
#endif

/* Round number down to multiple */
#define ALIGN_DOWN(n, m) ((n) / (m) * (m))

/* Round number up to multiple */
#define ALIGN_UP(n, m) ALIGN_DOWN((n) + (m) - 1, (m))

#ifndef unlikely
#define unlikely(x)   __builtin_expect(!!(x), 0)
#endif

/* Align each region to cache line size in inflight buffer */
#define INFLIGHT_ALIGNMENT 64

/* The version of inflight buffer */
#define INFLIGHT_VERSION 1

/* The version of the protocol we support */
#define VHOST_USER_VERSION 1
#define LIBVHOST_USER_DEBUG 0

#define DPRINT(...)                             \
    do {                                        \
        if (LIBVHOST_USER_DEBUG) {              \
            fprintf(stderr, __VA_ARGS__);       \
        }                                       \
    } while (0)

static inline
bool has_feature(uint64_t features, unsigned int fbit)
{
    assert(fbit < 64);
    return !!(features & (1ULL << fbit));
}

static inline
bool vu_has_feature(VuDev *dev,
                    unsigned int fbit)
{
    return has_feature(dev->features, fbit);
}

static inline bool vu_has_protocol_feature(VuDev *dev, unsigned int fbit)
{
    return has_feature(dev->protocol_features, fbit);
}

const char *
vu_request_to_string(unsigned int req)
{
#define REQ(req) [req] = #req
    static const char *vu_request_str[] = {
        REQ(VHOST_USER_NONE),
        REQ(VHOST_USER_GET_FEATURES),
        REQ(VHOST_USER_SET_FEATURES),
        REQ(VHOST_USER_SET_OWNER),
        REQ(VHOST_USER_RESET_OWNER),
        REQ(VHOST_USER_SET_MEM_TABLE),
        REQ(VHOST_USER_SET_LOG_BASE),
        REQ(VHOST_USER_SET_LOG_FD),
        REQ(VHOST_USER_SET_VRING_NUM),
        REQ(VHOST_USER_SET_VRING_ADDR),
        REQ(VHOST_USER_SET_VRING_BASE),
        REQ(VHOST_USER_GET_VRING_BASE),
        REQ(VHOST_USER_SET_VRING_KICK),
        REQ(VHOST_USER_SET_VRING_CALL),
        REQ(VHOST_USER_SET_VRING_ERR),
        REQ(VHOST_USER_GET_PROTOCOL_FEATURES),
        REQ(VHOST_USER_SET_PROTOCOL_FEATURES),
        REQ(VHOST_USER_GET_QUEUE_NUM),
        REQ(VHOST_USER_SET_VRING_ENABLE),
        REQ(VHOST_USER_SEND_RARP),
        REQ(VHOST_USER_NET_SET_MTU),
        REQ(VHOST_USER_SET_SLAVE_REQ_FD),
        REQ(VHOST_USER_IOTLB_MSG),
        REQ(VHOST_USER_SET_VRING_ENDIAN),
        REQ(VHOST_USER_GET_CONFIG),
        REQ(VHOST_USER_SET_CONFIG),
        REQ(VHOST_USER_POSTCOPY_ADVISE),
        REQ(VHOST_USER_POSTCOPY_LISTEN),
        REQ(VHOST_USER_POSTCOPY_END),
        REQ(VHOST_USER_GET_INFLIGHT_FD),
        REQ(VHOST_USER_SET_INFLIGHT_FD),
        REQ(VHOST_USER_GPU_SET_SOCKET),
        REQ(VHOST_USER_VRING_KICK),
        REQ(VHOST_USER_GET_MAX_MEM_SLOTS),
        REQ(VHOST_USER_ADD_MEM_REG),
        REQ(VHOST_USER_REM_MEM_REG),
        REQ(VHOST_USER_MAX),
    };
#undef REQ

    if (req < VHOST_USER_MAX) {
        return vu_request_str[req];
    } else {
        return "unknown";
    }
}

static void G_GNUC_PRINTF(2, 3)
vu_panic(VuDev *dev, const char *msg, ...)
{
    char *buf = NULL;
    va_list ap;

    va_start(ap, msg);
    if (vasprintf(&buf, msg, ap) < 0) {
        buf = NULL;
    }
    va_end(ap);

    dev->broken = true;
    dev->panic(dev, buf);
    free(buf);

    /*
     * FIXME:
     * find a way to call virtio_error, or perhaps close the connection?
     */
}

/* Translate guest physical address to our virtual address.  */
void *
vu_gpa_to_va(VuDev *dev, uint64_t *plen, uint64_t guest_addr)
{
    int i;

    if (*plen == 0) {
        return NULL;
    }

    /* Find matching memory region. */
    for (i = 0; i < dev->nregions; i++) {
        VuDevRegion *r = &dev->regions[i];

        if ((guest_addr >= r->gpa) && (guest_addr < (r->gpa + r->size))) {
            if ((guest_addr + *plen) > (r->gpa + r->size)) {
                *plen = r->gpa + r->size - guest_addr;
            }
            return (void *)(uintptr_t)
                guest_addr - r->gpa + r->mmap_addr + r->mmap_offset;
        }
    }

    return NULL;
}

/* Translate qemu virtual address to our virtual address.  */
static void *
qva_to_va(VuDev *dev, uint64_t qemu_addr)
{
    int i;

    /* Find matching memory region. */
    for (i = 0; i < dev->nregions; i++) {
        VuDevRegion *r = &dev->regions[i];

        if ((qemu_addr >= r->qva) && (qemu_addr < (r->qva + r->size))) {
            return (void *)(uintptr_t)
                qemu_addr - r->qva + r->mmap_addr + r->mmap_offset;
        }
    }

    return NULL;
}

static void
vmsg_close_fds(VhostUserMsg *vmsg)
{
    int i;

    for (i = 0; i < vmsg->fd_num; i++) {
        close(vmsg->fds[i]);
    }
}

/* Set reply payload.u64 and clear request flags and fd_num */
static void vmsg_set_reply_u64(VhostUserMsg *vmsg, uint64_t val)
{
    vmsg->flags = 0; /* defaults will be set by vu_send_reply() */
    vmsg->size = sizeof(vmsg->payload.u64);
    vmsg->payload.u64 = val;
    vmsg->fd_num = 0;
}

/* A test to see if we have userfault available */
static bool
have_userfault(void)
{
#if defined(__linux__) && defined(__NR_userfaultfd) &&\
        defined(UFFD_FEATURE_MISSING_SHMEM) &&\
        defined(UFFD_FEATURE_MISSING_HUGETLBFS)
    /* Now test the kernel we're running on really has the features */
    int ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
    struct uffdio_api api_struct;
    if (ufd < 0) {
        return false;
    }

    api_struct.api = UFFD_API;
    api_struct.features = UFFD_FEATURE_MISSING_SHMEM |
                          UFFD_FEATURE_MISSING_HUGETLBFS;
    if (ioctl(ufd, UFFDIO_API, &api_struct)) {
        close(ufd);
        return false;
    }
    close(ufd);
    return true;

#else
    return false;
#endif
}

static bool
vu_message_read_default(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
{
    char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
    struct iovec iov = {
        .iov_base = (char *)vmsg,
        .iov_len = VHOST_USER_HDR_SIZE,
    };
    struct msghdr msg = {
        .msg_iov = &iov,
        .msg_iovlen = 1,
        .msg_control = control,
        .msg_controllen = sizeof(control),
    };
    size_t fd_size;
    struct cmsghdr *cmsg;
    int rc;

    do {
        rc = recvmsg(conn_fd, &msg, 0);
    } while (rc < 0 && (errno == EINTR || errno == EAGAIN));

    if (rc < 0) {
        vu_panic(dev, "Error while recvmsg: %s", strerror(errno));
        return false;
    }

    vmsg->fd_num = 0;
    for (cmsg = CMSG_FIRSTHDR(&msg);
         cmsg != NULL;
         cmsg = CMSG_NXTHDR(&msg, cmsg))
    {
        if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
            fd_size = cmsg->cmsg_len - CMSG_LEN(0);
            vmsg->fd_num = fd_size / sizeof(int);
            memcpy(vmsg->fds, CMSG_DATA(cmsg), fd_size);
            break;
        }
    }

    if (vmsg->size > sizeof(vmsg->payload)) {
        vu_panic(dev,
                 "Error: too big message request: %d, size: vmsg->size: %u, "
                 "while sizeof(vmsg->payload) = %zu\n",
                 vmsg->request, vmsg->size, sizeof(vmsg->payload));
        goto fail;
    }

    if (vmsg->size) {
        do {
            rc = read(conn_fd, &vmsg->payload, vmsg->size);
        } while (rc < 0 && (errno == EINTR || errno == EAGAIN));

        if (rc <= 0) {
            vu_panic(dev, "Error while reading: %s", strerror(errno));
            goto fail;
        }

        assert(rc == vmsg->size);
    }

    return true;

fail:
    vmsg_close_fds(vmsg);

    return false;
}

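/*
 * Send vmsg over conn_fd: the fixed-size header goes out with any file
 * descriptors attached as SCM_RIGHTS ancillary data, followed by the
 * payload (vmsg->data if set, otherwise the inline payload).
 */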
static bool
vu_message_write(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
{
    int rc;
    uint8_t *p = (uint8_t *)vmsg;
    char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
    struct iovec iov = {
        .iov_base = (char *)vmsg,
        .iov_len = VHOST_USER_HDR_SIZE,
    };
    struct msghdr msg = {
        .msg_iov = &iov,
        .msg_iovlen = 1,
        .msg_control = control,
    };
    struct cmsghdr *cmsg;

    memset(control, 0, sizeof(control));
    assert(vmsg->fd_num <= VHOST_MEMORY_BASELINE_NREGIONS);
    if (vmsg->fd_num > 0) {
        size_t fdsize = vmsg->fd_num * sizeof(int);
        msg.msg_controllen = CMSG_SPACE(fdsize);
        cmsg = CMSG_FIRSTHDR(&msg);
        cmsg->cmsg_len = CMSG_LEN(fdsize);
        cmsg->cmsg_level = SOL_SOCKET;
        cmsg->cmsg_type = SCM_RIGHTS;
        memcpy(CMSG_DATA(cmsg), vmsg->fds, fdsize);
    } else {
        msg.msg_controllen = 0;
    }

    do {
        rc = sendmsg(conn_fd, &msg, 0);
    } while (rc < 0 && (errno == EINTR || errno == EAGAIN));

    if (vmsg->size) {
        do {
            if (vmsg->data) {
                rc = write(conn_fd, vmsg->data, vmsg->size);
            } else {
                rc = write(conn_fd, p + VHOST_USER_HDR_SIZE, vmsg->size);
            }
        } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
    }

    if (rc <= 0) {
        vu_panic(dev, "Error while writing: %s", strerror(errno));
        return false;
    }

    return true;
}

static bool
vu_send_reply(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
{
    /* Set the version in the flags when sending the reply */
    vmsg->flags &= ~VHOST_USER_VERSION_MASK;
    vmsg->flags |= VHOST_USER_VERSION;
    vmsg->flags |= VHOST_USER_REPLY_MASK;

    return vu_message_write(dev, conn_fd, vmsg);
}

/*
 * Processes a reply on the slave channel.
 * Entered with slave_mutex held and releases it before exit.
 * Returns true on success.
 */
static bool
vu_process_message_reply(VuDev *dev, const VhostUserMsg *vmsg)
{
    VhostUserMsg msg_reply;
    bool result = false;

    if ((vmsg->flags & VHOST_USER_NEED_REPLY_MASK) == 0) {
        result = true;
        goto out;
    }

    if (!vu_message_read_default(dev, dev->slave_fd, &msg_reply)) {
        goto out;
    }

    if (msg_reply.request != vmsg->request) {
        DPRINT("Received unexpected msg type. Expected %d received %d",
               vmsg->request, msg_reply.request);
        goto out;
    }

    result = msg_reply.payload.u64 == 0;

out:
    pthread_mutex_unlock(&dev->slave_mutex);
    return result;
}

/* Kick the log_call_fd if required. */
static void
vu_log_kick(VuDev *dev)
{
    if (dev->log_call_fd != -1) {
        DPRINT("Kicking the QEMU's log...\n");
        if (eventfd_write(dev->log_call_fd, 1) < 0) {
            vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
        }
    }
}

static void
vu_log_page(uint8_t *log_table, uint64_t page)
{
    DPRINT("Logged dirty guest page: %"PRId64"\n", page);
    qatomic_or(&log_table[page / 8], 1 << (page % 8));
}

static void
vu_log_write(VuDev *dev, uint64_t address, uint64_t length)
{
    uint64_t page;

    if (!(dev->features & (1ULL << VHOST_F_LOG_ALL)) ||
        !dev->log_table || !length) {
        return;
    }

    assert(dev->log_size > ((address + length - 1) / VHOST_LOG_PAGE / 8));

    page = address / VHOST_LOG_PAGE;
    while (page * VHOST_LOG_PAGE < address + length) {
        vu_log_page(dev->log_table, page);
        page += 1;
    }

    vu_log_kick(dev);
}

static void
vu_kick_cb(VuDev *dev, int condition, void *data)
{
    int index = (intptr_t)data;
    VuVirtq *vq = &dev->vq[index];
    int sock = vq->kick_fd;
    eventfd_t kick_data;
    ssize_t rc;

    rc = eventfd_read(sock, &kick_data);
    if (rc == -1) {
        vu_panic(dev, "kick eventfd_read(): %s", strerror(errno));
        dev->remove_watch(dev, dev->vq[index].kick_fd);
    } else {
        DPRINT("Got kick_data: %016"PRIx64" handler:%p idx:%d\n",
               kick_data, vq->handler, index);
        if (vq->handler) {
            vq->handler(dev, index);
        }
    }
}

static bool
vu_get_features_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    vmsg->payload.u64 =
        /*
         * The following VIRTIO feature bits are supported by our virtqueue
         * implementation:
         */
        1ULL << VIRTIO_F_NOTIFY_ON_EMPTY |
        1ULL << VIRTIO_RING_F_INDIRECT_DESC |
        1ULL << VIRTIO_RING_F_EVENT_IDX |
        1ULL << VIRTIO_F_VERSION_1 |

        /* vhost-user feature bits */
        1ULL << VHOST_F_LOG_ALL |
        1ULL << VHOST_USER_F_PROTOCOL_FEATURES;

    if (dev->iface->get_features) {
        vmsg->payload.u64 |= dev->iface->get_features(dev);
    }

    vmsg->size = sizeof(vmsg->payload.u64);
    vmsg->fd_num = 0;

    DPRINT("Sending back to guest u64: 0x%016"PRIx64"\n", vmsg->payload.u64);

    return true;
}

static void
vu_set_enable_all_rings(VuDev *dev, bool enabled)
{
    uint16_t i;

    for (i = 0; i < dev->max_queues; i++) {
        dev->vq[i].enable = enabled;
    }
}

static bool
vu_set_features_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);

    dev->features = vmsg->payload.u64;
    if (!vu_has_feature(dev, VIRTIO_F_VERSION_1)) {
        /*
         * We only support devices conforming to VIRTIO 1.0 or
         * later
         */
        vu_panic(dev, "virtio legacy devices aren't supported by libvhost-user");
        return false;
    }

    if (!vu_has_feature(dev, VHOST_USER_F_PROTOCOL_FEATURES)) {
        vu_set_enable_all_rings(dev, true);
    }

    if (dev->iface->set_features) {
        dev->iface->set_features(dev, dev->features);
    }

    return false;
}

static bool
vu_set_owner_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    return false;
}

static void
vu_close_log(VuDev *dev)
{
    if (dev->log_table) {
        if (munmap(dev->log_table, dev->log_size) != 0) {
            perror("close log munmap() error");
        }

        dev->log_table = NULL;
    }
    if (dev->log_call_fd != -1) {
        close(dev->log_call_fd);
        dev->log_call_fd = -1;
    }
}

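/* Handle VHOST_USER_RESET_OWNER: stop the device by disabling every ring. */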
static bool
vu_reset_device_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    vu_set_enable_all_rings(dev, false);

    return false;
}

static bool
map_ring(VuDev *dev, VuVirtq *vq)
{
    vq->vring.desc = qva_to_va(dev, vq->vra.desc_user_addr);
    vq->vring.used = qva_to_va(dev, vq->vra.used_user_addr);
    vq->vring.avail = qva_to_va(dev, vq->vra.avail_user_addr);

    DPRINT("Setting virtq addresses:\n");
    DPRINT("    vring_desc  at %p\n", vq->vring.desc);
    DPRINT("    vring_used  at %p\n", vq->vring.used);
    DPRINT("    vring_avail at %p\n", vq->vring.avail);

    return !(vq->vring.desc && vq->vring.used && vq->vring.avail);
}

static bool
generate_faults(VuDev *dev) {
    int i;
    for (i = 0; i < dev->nregions; i++) {
        VuDevRegion *dev_region = &dev->regions[i];
        int ret;
#ifdef UFFDIO_REGISTER
        /*
         * We should already have an open ufd. Mark each memory
         * range as ufd.
         * Discard any mapping we have here; note I can't use MADV_REMOVE
         * or fallocate to make the hole since I don't want to lose
         * data that's already arrived in the shared process.
         * TODO: How to do hugepage
         */
        ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
                      dev_region->size + dev_region->mmap_offset,
                      MADV_DONTNEED);
        if (ret) {
            fprintf(stderr,
                    "%s: Failed to madvise(DONTNEED) region %d: %s\n",
                    __func__, i, strerror(errno));
        }
        /*
         * Turn off transparent hugepages so we don't lose wakeups
         * in neighbouring pages.
         * TODO: Turn this back on later.
         */
        ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
                      dev_region->size + dev_region->mmap_offset,
                      MADV_NOHUGEPAGE);
        if (ret) {
            /*
             * Note: This can happen legally on kernels that are configured
             * without madvise'able hugepages
             */
            fprintf(stderr,
                    "%s: Failed to madvise(NOHUGEPAGE) region %d: %s\n",
                    __func__, i, strerror(errno));
        }
        struct uffdio_register reg_struct;
        reg_struct.range.start = (uintptr_t)dev_region->mmap_addr;
        reg_struct.range.len = dev_region->size + dev_region->mmap_offset;
        reg_struct.mode = UFFDIO_REGISTER_MODE_MISSING;

        if (ioctl(dev->postcopy_ufd, UFFDIO_REGISTER, &reg_struct)) {
            vu_panic(dev, "%s: Failed to userfault region %d "
                          "@%" PRIx64 " + size:%" PRIx64 " offset: %" PRIx64
                          ": (ufd=%d)%s\n",
                     __func__, i,
                     dev_region->mmap_addr,
                     dev_region->size, dev_region->mmap_offset,
                     dev->postcopy_ufd, strerror(errno));
            return false;
        }
        if (!(reg_struct.ioctls & ((__u64)1 << _UFFDIO_COPY))) {
            vu_panic(dev, "%s Region (%d) doesn't support COPY",
                     __func__, i);
            return false;
        }
        DPRINT("%s: region %d: Registered userfault for %"
               PRIx64 " + %" PRIx64 "\n", __func__, i,
               (uint64_t)reg_struct.range.start,
               (uint64_t)reg_struct.range.len);
        /* Now it's registered we can let the client at it */
        if (mprotect((void *)(uintptr_t)dev_region->mmap_addr,
                     dev_region->size + dev_region->mmap_offset,
                     PROT_READ | PROT_WRITE)) {
            vu_panic(dev, "failed to mprotect region %d for postcopy (%s)",
                     i, strerror(errno));
            return false;
        }
        /* TODO: Stash 'zero' support flags somewhere */
#endif
    }

    return true;
}

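/*
 * Handle VHOST_USER_ADD_MEM_REG: mmap the region described by the message
 * and record it.  While postcopy-listening, the mapping address is sent
 * back so QEMU can translate userfault addresses.
 */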
static bool
vu_add_mem_reg(VuDev *dev, VhostUserMsg *vmsg) {
    int i;
    bool track_ramblocks = dev->postcopy_listening;
    VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
    VuDevRegion *dev_region = &dev->regions[dev->nregions];
    void *mmap_addr;

    if (vmsg->fd_num != 1) {
        vmsg_close_fds(vmsg);
        vu_panic(dev, "VHOST_USER_ADD_MEM_REG received %d fds - only 1 fd "
                      "should be sent for this message type", vmsg->fd_num);
        return false;
    }

    if (vmsg->size < VHOST_USER_MEM_REG_SIZE) {
        close(vmsg->fds[0]);
        vu_panic(dev, "VHOST_USER_ADD_MEM_REG requires a message size of at "
                      "least %zu bytes and only %d bytes were received",
                      VHOST_USER_MEM_REG_SIZE, vmsg->size);
        return false;
    }

    if (dev->nregions == VHOST_USER_MAX_RAM_SLOTS) {
        close(vmsg->fds[0]);
        vu_panic(dev, "failing attempt to hot add memory via "
                      "VHOST_USER_ADD_MEM_REG message because the backend has "
                      "no free ram slots available");
        return false;
    }

    /*
     * If we are in postcopy mode and we receive a u64 payload with a 0 value
     * we know all the postcopy client bases have been received, and we
     * should start generating faults.
     */
    if (track_ramblocks &&
        vmsg->size == sizeof(vmsg->payload.u64) &&
        vmsg->payload.u64 == 0) {
        (void)generate_faults(dev);
        return false;
    }

    DPRINT("Adding region: %u\n", dev->nregions);
    DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
           msg_region->guest_phys_addr);
    DPRINT("    memory_size:     0x%016"PRIx64"\n",
           msg_region->memory_size);
    DPRINT("    userspace_addr   0x%016"PRIx64"\n",
           msg_region->userspace_addr);
    DPRINT("    mmap_offset      0x%016"PRIx64"\n",
           msg_region->mmap_offset);

    dev_region->gpa = msg_region->guest_phys_addr;
    dev_region->size = msg_region->memory_size;
    dev_region->qva = msg_region->userspace_addr;
    dev_region->mmap_offset = msg_region->mmap_offset;

    /*
     * We don't use offset argument of mmap() since the
     * mapped address has to be page aligned, and we use huge
     * pages.
     */
    if (track_ramblocks) {
        /*
         * In postcopy we're using PROT_NONE here to catch anyone
         * accessing it before we userfault.
         */
        mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
                         PROT_NONE, MAP_SHARED | MAP_NORESERVE,
                         vmsg->fds[0], 0);
    } else {
        mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
                         PROT_READ | PROT_WRITE, MAP_SHARED | MAP_NORESERVE,
                         vmsg->fds[0], 0);
    }

    if (mmap_addr == MAP_FAILED) {
        vu_panic(dev, "region mmap error: %s", strerror(errno));
    } else {
        dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
        DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
               dev_region->mmap_addr);
    }

    close(vmsg->fds[0]);

    if (track_ramblocks) {
        /*
         * Return the address to QEMU so that it can translate the ufd
         * fault addresses back.
         */
        msg_region->userspace_addr = (uintptr_t)(mmap_addr +
                                                 dev_region->mmap_offset);

        /* Send the message back to qemu with the addresses filled in. */
        vmsg->fd_num = 0;
        DPRINT("Successfully added new region in postcopy\n");
        dev->nregions++;
        return true;
    } else {
        for (i = 0; i < dev->max_queues; i++) {
            if (dev->vq[i].vring.desc) {
                if (map_ring(dev, &dev->vq[i])) {
                    vu_panic(dev, "remapping queue %d for new memory region",
                             i);
                }
            }
        }

        DPRINT("Successfully added new region\n");
        dev->nregions++;
        return false;
    }
}

static inline bool reg_equal(VuDevRegion *vudev_reg,
                             VhostUserMemoryRegion *msg_reg)
{
    if (vudev_reg->gpa == msg_reg->guest_phys_addr &&
        vudev_reg->qva == msg_reg->userspace_addr &&
        vudev_reg->size == msg_reg->memory_size) {
        return true;
    }

    return false;
}

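/*
 * Handle VHOST_USER_REM_MEM_REG: unmap and drop every registered region
 * that matches the one described in the message (duplicates included).
 */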
static bool
vu_rem_mem_reg(VuDev *dev, VhostUserMsg *vmsg) {
    VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
    int i;
    bool found = false;

    if (vmsg->fd_num > 1) {
        vmsg_close_fds(vmsg);
        vu_panic(dev, "VHOST_USER_REM_MEM_REG received %d fds - at most 1 fd "
                      "should be sent for this message type", vmsg->fd_num);
        return false;
    }

    if (vmsg->size < VHOST_USER_MEM_REG_SIZE) {
        vmsg_close_fds(vmsg);
        vu_panic(dev, "VHOST_USER_REM_MEM_REG requires a message size of at "
                      "least %zu bytes and only %d bytes were received",
                      VHOST_USER_MEM_REG_SIZE, vmsg->size);
        return false;
    }

    DPRINT("Removing region:\n");
    DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
           msg_region->guest_phys_addr);
    DPRINT("    memory_size:     0x%016"PRIx64"\n",
           msg_region->memory_size);
    DPRINT("    userspace_addr   0x%016"PRIx64"\n",
           msg_region->userspace_addr);
    DPRINT("    mmap_offset      0x%016"PRIx64"\n",
           msg_region->mmap_offset);

    for (i = 0; i < dev->nregions; i++) {
        if (reg_equal(&dev->regions[i], msg_region)) {
            VuDevRegion *r = &dev->regions[i];
            void *m = (void *) (uintptr_t) r->mmap_addr;

            if (m) {
                munmap(m, r->size + r->mmap_offset);
            }

            /*
             * Shift all affected entries by 1 to close the hole at index i and
             * zero out the last entry.
             */
            memmove(dev->regions + i, dev->regions + i + 1,
                    sizeof(VuDevRegion) * (dev->nregions - i - 1));
            memset(dev->regions + dev->nregions - 1, 0, sizeof(VuDevRegion));
            DPRINT("Successfully removed a region\n");
            dev->nregions--;
            i--;

            found = true;

            /* Continue the search for eventual duplicates. */
        }
    }

    if (!found) {
        vu_panic(dev, "Specified region not found\n");
    }

    vmsg_close_fds(vmsg);

    return false;
}

static bool
vu_set_mem_table_exec_postcopy(VuDev *dev, VhostUserMsg *vmsg)
{
    int i;
    VhostUserMemory m = vmsg->payload.memory, *memory = &m;
    dev->nregions = memory->nregions;

    DPRINT("Nregions: %u\n", memory->nregions);
    for (i = 0; i < dev->nregions; i++) {
        void *mmap_addr;
        VhostUserMemoryRegion *msg_region = &memory->regions[i];
        VuDevRegion *dev_region = &dev->regions[i];

        DPRINT("Region %d\n", i);
        DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
               msg_region->guest_phys_addr);
        DPRINT("    memory_size:     0x%016"PRIx64"\n",
               msg_region->memory_size);
        DPRINT("    userspace_addr   0x%016"PRIx64"\n",
               msg_region->userspace_addr);
        DPRINT("    mmap_offset      0x%016"PRIx64"\n",
               msg_region->mmap_offset);

        dev_region->gpa = msg_region->guest_phys_addr;
        dev_region->size = msg_region->memory_size;
        dev_region->qva = msg_region->userspace_addr;
        dev_region->mmap_offset = msg_region->mmap_offset;

        /* We don't use offset argument of mmap() since the
         * mapped address has to be page aligned, and we use huge
         * pages.
         * In postcopy we're using PROT_NONE here to catch anyone
         * accessing it before we userfault
         */
        mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
                         PROT_NONE, MAP_SHARED | MAP_NORESERVE,
                         vmsg->fds[i], 0);

        if (mmap_addr == MAP_FAILED) {
            vu_panic(dev, "region mmap error: %s", strerror(errno));
        } else {
            dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
            DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
                   dev_region->mmap_addr);
        }

        /* Return the address to QEMU so that it can translate the ufd
         * fault addresses back.
         */
        msg_region->userspace_addr = (uintptr_t)(mmap_addr +
                                                 dev_region->mmap_offset);
        close(vmsg->fds[i]);
    }

    /* Send the message back to qemu with the addresses filled in */
    vmsg->fd_num = 0;
    if (!vu_send_reply(dev, dev->sock, vmsg)) {
        vu_panic(dev, "failed to respond to set-mem-table for postcopy");
        return false;
    }

    /* Wait for QEMU to confirm that it's registered the handler for the
     * faults.
     */
    if (!dev->read_msg(dev, dev->sock, vmsg) ||
        vmsg->size != sizeof(vmsg->payload.u64) ||
        vmsg->payload.u64 != 0) {
        vu_panic(dev, "failed to receive valid ack for postcopy set-mem-table");
        return false;
    }

    /* OK, now we can go and register the memory and generate faults */
    (void)generate_faults(dev);

    return false;
}

static bool
vu_set_mem_table_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    int i;
    VhostUserMemory m = vmsg->payload.memory, *memory = &m;

    for (i = 0; i < dev->nregions; i++) {
        VuDevRegion *r = &dev->regions[i];
        void *m = (void *) (uintptr_t) r->mmap_addr;

        if (m) {
            munmap(m, r->size + r->mmap_offset);
        }
    }
    dev->nregions = memory->nregions;

    if (dev->postcopy_listening) {
        return vu_set_mem_table_exec_postcopy(dev, vmsg);
    }

    DPRINT("Nregions: %u\n", memory->nregions);
    for (i = 0; i < dev->nregions; i++) {
        void *mmap_addr;
        VhostUserMemoryRegion *msg_region = &memory->regions[i];
        VuDevRegion *dev_region = &dev->regions[i];

        DPRINT("Region %d\n", i);
        DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
               msg_region->guest_phys_addr);
        DPRINT("    memory_size:     0x%016"PRIx64"\n",
               msg_region->memory_size);
        DPRINT("    userspace_addr   0x%016"PRIx64"\n",
               msg_region->userspace_addr);
        DPRINT("    mmap_offset      0x%016"PRIx64"\n",
               msg_region->mmap_offset);

        dev_region->gpa = msg_region->guest_phys_addr;
        dev_region->size = msg_region->memory_size;
        dev_region->qva = msg_region->userspace_addr;
        dev_region->mmap_offset = msg_region->mmap_offset;

        /* We don't use offset argument of mmap() since the
         * mapped address has to be page aligned, and we use huge
         * pages.  */
        mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
                         PROT_READ | PROT_WRITE, MAP_SHARED | MAP_NORESERVE,
                         vmsg->fds[i], 0);

        if (mmap_addr == MAP_FAILED) {
            vu_panic(dev, "region mmap error: %s", strerror(errno));
        } else {
            dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
            DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
                   dev_region->mmap_addr);
        }

        close(vmsg->fds[i]);
    }

    for (i = 0; i < dev->max_queues; i++) {
        if (dev->vq[i].vring.desc) {
            if (map_ring(dev, &dev->vq[i])) {
                vu_panic(dev, "remapping queue %d during setmemtable", i);
            }
        }
    }

    return false;
}

static bool
vu_set_log_base_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    int fd;
    uint64_t log_mmap_size, log_mmap_offset;
    void *rc;

    if (vmsg->fd_num != 1 ||
        vmsg->size != sizeof(vmsg->payload.log)) {
        vu_panic(dev, "Invalid log_base message");
        return true;
    }

    fd = vmsg->fds[0];
    log_mmap_offset = vmsg->payload.log.mmap_offset;
    log_mmap_size = vmsg->payload.log.mmap_size;
    DPRINT("Log mmap_offset: %"PRId64"\n", log_mmap_offset);
    DPRINT("Log mmap_size:   %"PRId64"\n", log_mmap_size);

    rc = mmap(0, log_mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd,
              log_mmap_offset);
    close(fd);
    if (rc == MAP_FAILED) {
        perror("log mmap error");
    }

    if (dev->log_table) {
        munmap(dev->log_table, dev->log_size);
    }
    dev->log_table = rc;
    dev->log_size = log_mmap_size;

    vmsg->size = sizeof(vmsg->payload.u64);
    vmsg->fd_num = 0;

    return true;
}

static bool
vu_set_log_fd_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    if (vmsg->fd_num != 1) {
        vu_panic(dev, "Invalid log_fd message");
        return false;
    }

    if (dev->log_call_fd != -1) {
        close(dev->log_call_fd);
    }
    dev->log_call_fd = vmsg->fds[0];
    DPRINT("Got log_call_fd: %d\n", vmsg->fds[0]);

    return false;
}

static bool
vu_set_vring_num_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    unsigned int index = vmsg->payload.state.index;
    unsigned int num = vmsg->payload.state.num;

    DPRINT("State.index: %u\n", index);
    DPRINT("State.num:   %u\n", num);
    dev->vq[index].vring.num = num;

    return false;
}

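/*
 * Handle VHOST_USER_SET_VRING_ADDR: record the ring addresses, map the
 * rings into our address space and resynchronise the used index.
 */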
static bool
vu_set_vring_addr_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    struct vhost_vring_addr addr = vmsg->payload.addr, *vra = &addr;
    unsigned int index = vra->index;
    VuVirtq *vq = &dev->vq[index];

    DPRINT("vhost_vring_addr:\n");
    DPRINT("    index:  %d\n", vra->index);
    DPRINT("    flags:  %d\n", vra->flags);
    DPRINT("    desc_user_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->desc_user_addr);
    DPRINT("    used_user_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->used_user_addr);
    DPRINT("    avail_user_addr:  0x%016" PRIx64 "\n", (uint64_t)vra->avail_user_addr);
    DPRINT("    log_guest_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->log_guest_addr);

    vq->vra = *vra;
    vq->vring.flags = vra->flags;
    vq->vring.log_guest_addr = vra->log_guest_addr;


    if (map_ring(dev, vq)) {
        vu_panic(dev, "Invalid vring_addr message");
        return false;
    }

    vq->used_idx = le16toh(vq->vring.used->idx);

    if (vq->last_avail_idx != vq->used_idx) {
        bool resume = dev->iface->queue_is_processed_in_order &&
            dev->iface->queue_is_processed_in_order(dev, index);

        DPRINT("Last avail index != used index: %u != %u%s\n",
               vq->last_avail_idx, vq->used_idx,
               resume ? ", resuming" : "");

        if (resume) {
            vq->shadow_avail_idx = vq->last_avail_idx = vq->used_idx;
        }
    }

    return false;
}

static bool
vu_set_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    unsigned int index = vmsg->payload.state.index;
    unsigned int num = vmsg->payload.state.num;

    DPRINT("State.index: %u\n", index);
    DPRINT("State.num:   %u\n", num);
    dev->vq[index].shadow_avail_idx = dev->vq[index].last_avail_idx = num;

    return false;
}

static bool
vu_get_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    unsigned int index = vmsg->payload.state.index;

    DPRINT("State.index: %u\n", index);
    vmsg->payload.state.num = dev->vq[index].last_avail_idx;
    vmsg->size = sizeof(vmsg->payload.state);

    dev->vq[index].started = false;
    if (dev->iface->queue_set_started) {
        dev->iface->queue_set_started(dev, index, false);
    }

    if (dev->vq[index].call_fd != -1) {
        close(dev->vq[index].call_fd);
        dev->vq[index].call_fd = -1;
    }
    if (dev->vq[index].kick_fd != -1) {
        dev->remove_watch(dev, dev->vq[index].kick_fd);
        close(dev->vq[index].kick_fd);
        dev->vq[index].kick_fd = -1;
    }

    return true;
}

static bool
vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
{
    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
    bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;

    if (index >= dev->max_queues) {
        vmsg_close_fds(vmsg);
        vu_panic(dev, "Invalid queue index: %u", index);
        return false;
    }

    if (nofd) {
        vmsg_close_fds(vmsg);
        return true;
    }

    if (vmsg->fd_num != 1) {
        vmsg_close_fds(vmsg);
        vu_panic(dev, "Invalid fds in request: %d", vmsg->request);
        return false;
    }

    return true;
}

static int
inflight_desc_compare(const void *a, const void *b)
{
    VuVirtqInflightDesc *desc0 = (VuVirtqInflightDesc *)a,
                        *desc1 = (VuVirtqInflightDesc *)b;

    if (desc1->counter > desc0->counter &&
        (desc1->counter - desc0->counter) < VIRTQUEUE_MAX_SIZE * 2) {
        return 1;
    }

    return -1;
}

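/*
 * Resume a queue from the shared inflight buffer after a reconnect:
 * count the descriptors that are still marked inflight, build the
 * resubmit list sorted by submission counter, and kick the queue so
 * processing restarts.
 */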
static int
vu_check_queue_inflights(VuDev *dev, VuVirtq *vq)
{
    int i = 0;

    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
        return 0;
    }

    if (unlikely(!vq->inflight)) {
        return -1;
    }

    if (unlikely(!vq->inflight->version)) {
        /* initialize the buffer */
        vq->inflight->version = INFLIGHT_VERSION;
        return 0;
    }

    vq->used_idx = le16toh(vq->vring.used->idx);
    vq->resubmit_num = 0;
    vq->resubmit_list = NULL;
    vq->counter = 0;

    if (unlikely(vq->inflight->used_idx != vq->used_idx)) {
        vq->inflight->desc[vq->inflight->last_batch_head].inflight = 0;

        barrier();

        vq->inflight->used_idx = vq->used_idx;
    }

    for (i = 0; i < vq->inflight->desc_num; i++) {
        if (vq->inflight->desc[i].inflight == 1) {
            vq->inuse++;
        }
    }

    vq->shadow_avail_idx = vq->last_avail_idx = vq->inuse + vq->used_idx;

    if (vq->inuse) {
        vq->resubmit_list = calloc(vq->inuse, sizeof(VuVirtqInflightDesc));
        if (!vq->resubmit_list) {
            return -1;
        }

        for (i = 0; i < vq->inflight->desc_num; i++) {
            if (vq->inflight->desc[i].inflight) {
                vq->resubmit_list[vq->resubmit_num].index = i;
                vq->resubmit_list[vq->resubmit_num].counter =
                    vq->inflight->desc[i].counter;
                vq->resubmit_num++;
            }
        }

        if (vq->resubmit_num > 1) {
            qsort(vq->resubmit_list, vq->resubmit_num,
                  sizeof(VuVirtqInflightDesc), inflight_desc_compare);
        }
        vq->counter = vq->resubmit_list[0].counter + 1;
    }

    /* in case of I/O hang after reconnecting */
    if (eventfd_write(vq->kick_fd, 1)) {
        return -1;
    }

    return 0;
}

static bool
vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
    bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;

    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);

    if (!vu_check_queue_msg_file(dev, vmsg)) {
        return false;
    }

    if (dev->vq[index].kick_fd != -1) {
        dev->remove_watch(dev, dev->vq[index].kick_fd);
        close(dev->vq[index].kick_fd);
        dev->vq[index].kick_fd = -1;
    }

    dev->vq[index].kick_fd = nofd ? -1 : vmsg->fds[0];
    DPRINT("Got kick_fd: %d for vq: %d\n", dev->vq[index].kick_fd, index);

    dev->vq[index].started = true;
    if (dev->iface->queue_set_started) {
        dev->iface->queue_set_started(dev, index, true);
    }

    if (dev->vq[index].kick_fd != -1 && dev->vq[index].handler) {
        dev->set_watch(dev, dev->vq[index].kick_fd, VU_WATCH_IN,
                       vu_kick_cb, (void *)(long)index);

        DPRINT("Waiting for kicks on fd: %d for vq: %d\n",
               dev->vq[index].kick_fd, index);
    }

    if (vu_check_queue_inflights(dev, &dev->vq[index])) {
        vu_panic(dev, "Failed to check inflights for vq: %d\n", index);
    }

    return false;
}

void vu_set_queue_handler(VuDev *dev, VuVirtq *vq,
                          vu_queue_handler_cb handler)
{
    int qidx = vq - dev->vq;

    vq->handler = handler;
    if (vq->kick_fd >= 0) {
        if (handler) {
            dev->set_watch(dev, vq->kick_fd, VU_WATCH_IN,
                           vu_kick_cb, (void *)(long)qidx);
        } else {
            dev->remove_watch(dev, vq->kick_fd);
        }
    }
}

bool vu_set_queue_host_notifier(VuDev *dev, VuVirtq *vq, int fd,
                                int size, int offset)
{
    int qidx = vq - dev->vq;
    int fd_num = 0;
    VhostUserMsg vmsg = {
        .request = VHOST_USER_SLAVE_VRING_HOST_NOTIFIER_MSG,
        .flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
        .size = sizeof(vmsg.payload.area),
        .payload.area = {
            .u64 = qidx & VHOST_USER_VRING_IDX_MASK,
            .size = size,
            .offset = offset,
        },
    };

    if (fd == -1) {
        vmsg.payload.area.u64 |= VHOST_USER_VRING_NOFD_MASK;
    } else {
        vmsg.fds[fd_num++] = fd;
    }

    vmsg.fd_num = fd_num;

    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD)) {
        return false;
    }

    pthread_mutex_lock(&dev->slave_mutex);
    if (!vu_message_write(dev, dev->slave_fd, &vmsg)) {
        pthread_mutex_unlock(&dev->slave_mutex);
        return false;
    }

    /* Also unlocks the slave_mutex */
    return vu_process_message_reply(dev, &vmsg);
}

static bool
vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
    bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;

    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);

    if (!vu_check_queue_msg_file(dev, vmsg)) {
        return false;
    }

    if (dev->vq[index].call_fd != -1) {
        close(dev->vq[index].call_fd);
        dev->vq[index].call_fd = -1;
    }

    dev->vq[index].call_fd = nofd ? -1 : vmsg->fds[0];

    /* in case of I/O hang after reconnecting */
    if (dev->vq[index].call_fd != -1 && eventfd_write(vmsg->fds[0], 1)) {
        return -1;
    }

    DPRINT("Got call_fd: %d for vq: %d\n", dev->vq[index].call_fd, index);

    return false;
}

static bool
vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
    bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;

    DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);

    if (!vu_check_queue_msg_file(dev, vmsg)) {
        return false;
    }

    if (dev->vq[index].err_fd != -1) {
        close(dev->vq[index].err_fd);
        dev->vq[index].err_fd = -1;
    }

    dev->vq[index].err_fd = nofd ? -1 : vmsg->fds[0];

    return false;
}

static bool
vu_get_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    /*
     * Note that we support, but intentionally do not set,
     * VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS. This means that
     * a device implementation can return it in its callback
     * (get_protocol_features) if it wants to use this for
     * simulation, but it is otherwise not desirable (if even
     * implemented by the master.)
     */
    uint64_t features = 1ULL << VHOST_USER_PROTOCOL_F_MQ |
                        1ULL << VHOST_USER_PROTOCOL_F_LOG_SHMFD |
                        1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ |
                        1ULL << VHOST_USER_PROTOCOL_F_HOST_NOTIFIER |
                        1ULL << VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD |
                        1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK |
                        1ULL << VHOST_USER_PROTOCOL_F_CONFIGURE_MEM_SLOTS;

    if (have_userfault()) {
        features |= 1ULL << VHOST_USER_PROTOCOL_F_PAGEFAULT;
    }

    if (dev->iface->get_config && dev->iface->set_config) {
        features |= 1ULL << VHOST_USER_PROTOCOL_F_CONFIG;
    }

    if (dev->iface->get_protocol_features) {
        features |= dev->iface->get_protocol_features(dev);
    }

    vmsg_set_reply_u64(vmsg, features);
    return true;
}

static bool
vu_set_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    uint64_t features = vmsg->payload.u64;

    DPRINT("u64: 0x%016"PRIx64"\n", features);

    dev->protocol_features = vmsg->payload.u64;

    if (vu_has_protocol_feature(dev,
                                VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
        (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_REQ) ||
         !vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_REPLY_ACK))) {
        /*
         * The use case for using messages for kick/call is simulation, to make
         * the kick and call synchronous. To actually get that behaviour, both
         * of the other features are required.
         * Theoretically, one could use only kick messages, or do them without
         * having F_REPLY_ACK, but too many (possibly pending) messages on the
         * socket will eventually cause the master to hang. To avoid this in
         * scenarios where it is not desired, enforce that the settings
         * actually enable the simulation case.
         */
        vu_panic(dev,
                 "F_IN_BAND_NOTIFICATIONS requires F_SLAVE_REQ && F_REPLY_ACK");
        return false;
    }

    if (dev->iface->set_protocol_features) {
        dev->iface->set_protocol_features(dev, features);
    }

    return false;
}

static bool
vu_get_queue_num_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    vmsg_set_reply_u64(vmsg, dev->max_queues);
    return true;
}

static bool
vu_set_vring_enable_exec(VuDev *dev, VhostUserMsg *vmsg)
{
    unsigned int index = vmsg->payload.state.index;
    unsigned int enable = vmsg->payload.state.num;

    DPRINT("State.index:  %u\n", index);
    DPRINT("State.enable: %u\n", enable);

    if (index >= dev->max_queues) {
        vu_panic(dev, "Invalid vring_enable index: %u", index);
        return false;
    }

    dev->vq[index].enable = enable;
    return false;
}

static bool
vu_set_slave_req_fd(VuDev *dev, VhostUserMsg *vmsg)
{
    if (vmsg->fd_num != 1) {
        vu_panic(dev, "Invalid slave_req_fd message (%d fd's)", vmsg->fd_num);
        return false;
    }

    if (dev->slave_fd != -1) {
        close(dev->slave_fd);
    }
    dev->slave_fd = vmsg->fds[0];
    DPRINT("Got slave_fd: %d\n", vmsg->fds[0]);

    return false;
}

static bool
vu_get_config(VuDev *dev, VhostUserMsg *vmsg)
{
    int ret = -1;

    if (dev->iface->get_config) {
        ret = dev->iface->get_config(dev, vmsg->payload.config.region,
                                     vmsg->payload.config.size);
    }

    if (ret) {
        /* resize to zero to indicate an error to master */
        vmsg->size = 0;
    }

    return true;
}

static bool
vu_set_config(VuDev *dev, VhostUserMsg *vmsg)
{
    int ret = -1;

    if (dev->iface->set_config) {
        ret = dev->iface->set_config(dev, vmsg->payload.config.region,
                                     vmsg->payload.config.offset,
                                     vmsg->payload.config.size,
                                     vmsg->payload.config.flags);
        if (ret) {
            vu_panic(dev, "Set virtio configuration space failed");
        }
    }

    return false;
}

static bool
vu_set_postcopy_advise(VuDev *dev, VhostUserMsg *vmsg)
{
    dev->postcopy_ufd = -1;
#ifdef UFFDIO_API
    struct uffdio_api api_struct;

    dev->postcopy_ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
    vmsg->size = 0;
#endif

    if (dev->postcopy_ufd == -1) {
        vu_panic(dev, "Userfaultfd not available: %s", strerror(errno));
        goto out;
    }

#ifdef UFFDIO_API
    api_struct.api = UFFD_API;
    api_struct.features = 0;
    if (ioctl(dev->postcopy_ufd, UFFDIO_API, &api_struct)) {
        vu_panic(dev, "Failed UFFDIO_API: %s", strerror(errno));
        close(dev->postcopy_ufd);
        dev->postcopy_ufd = -1;
        goto out;
    }
    /* TODO: Stash feature flags somewhere */
#endif

out:
    /* Return a ufd to the QEMU */
    vmsg->fd_num = 1;
    vmsg->fds[0] = dev->postcopy_ufd;
    return true; /* = send a reply */
}

static bool
vu_set_postcopy_listen(VuDev *dev, VhostUserMsg *vmsg)
{
    if (dev->nregions) {
        vu_panic(dev, "Regions already registered at postcopy-listen");
        vmsg_set_reply_u64(vmsg, -1);
        return true;
    }
    dev->postcopy_listening = true;

    vmsg_set_reply_u64(vmsg, 0);
    return true;
}

static bool
vu_set_postcopy_end(VuDev *dev, VhostUserMsg *vmsg)
{
    DPRINT("%s: Entry\n", __func__);
    dev->postcopy_listening = false;
    if (dev->postcopy_ufd > 0) {
        close(dev->postcopy_ufd);
        dev->postcopy_ufd = -1;
        DPRINT("%s: Done close\n", __func__);
    }

    vmsg_set_reply_u64(vmsg, 0);
    DPRINT("%s: exit\n", __func__);
    return true;
}

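/* Bytes reserved per queue in the shared inflight area, rounded up to INFLIGHT_ALIGNMENT. */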
static inline uint64_t
vu_inflight_queue_size(uint16_t queue_size)
{
    return ALIGN_UP(sizeof(VuDescStateSplit) * queue_size +
           sizeof(uint16_t), INFLIGHT_ALIGNMENT);
}

#ifdef MFD_ALLOW_SEALING
static void *
memfd_alloc(const char *name, size_t size, unsigned int flags, int *fd)
{
    void *ptr;
    int ret;

    *fd = memfd_create(name, MFD_ALLOW_SEALING);
    if (*fd < 0) {
        return NULL;
    }

    ret = ftruncate(*fd, size);
    if (ret < 0) {
        close(*fd);
        return NULL;
    }

    ret = fcntl(*fd, F_ADD_SEALS, flags);
    if (ret < 0) {
        close(*fd);
        return NULL;
    }

    ptr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, *fd, 0);
    if (ptr == MAP_FAILED) {
        close(*fd);
        return NULL;
    }

    return ptr;
}
#endif

static bool
vu_get_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
{
    int fd = -1;
    void *addr = NULL;
    uint64_t mmap_size;
    uint16_t num_queues, queue_size;

    if (vmsg->size != sizeof(vmsg->payload.inflight)) {
        vu_panic(dev, "Invalid get_inflight_fd message:%d", vmsg->size);
        vmsg->payload.inflight.mmap_size = 0;
        return true;
    }

    num_queues = vmsg->payload.inflight.num_queues;
    queue_size = vmsg->payload.inflight.queue_size;

    DPRINT("set_inflight_fd num_queues: %"PRId16"\n", num_queues);
    DPRINT("set_inflight_fd queue_size: %"PRId16"\n", queue_size);

    mmap_size = vu_inflight_queue_size(queue_size) * num_queues;

#ifdef MFD_ALLOW_SEALING
    addr = memfd_alloc("vhost-inflight", mmap_size,
                       F_SEAL_GROW | F_SEAL_SHRINK | F_SEAL_SEAL,
                       &fd);
#else
    vu_panic(dev, "Not implemented: memfd support is missing");
#endif

    if (!addr) {
        vu_panic(dev, "Failed to alloc vhost inflight area");
        vmsg->payload.inflight.mmap_size = 0;
        return true;
    }

    memset(addr, 0, mmap_size);

    dev->inflight_info.addr = addr;
    dev->inflight_info.size = vmsg->payload.inflight.mmap_size = mmap_size;
    dev->inflight_info.fd = vmsg->fds[0] = fd;
    vmsg->fd_num = 1;
    vmsg->payload.inflight.mmap_offset = 0;

    DPRINT("send inflight mmap_size: %"PRId64"\n",
           vmsg->payload.inflight.mmap_size);
    DPRINT("send inflight mmap offset: %"PRId64"\n",
           vmsg->payload.inflight.mmap_offset);

    return true;
}

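/*
 * Handle VHOST_USER_SET_INFLIGHT_FD: map the inflight buffer handed back
 * by QEMU and point each virtqueue at its slice of it.
 */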
static bool
vu_set_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
{
    int fd, i;
    uint64_t mmap_size, mmap_offset;
    uint16_t num_queues, queue_size;
    void *rc;

    if (vmsg->fd_num != 1 ||
        vmsg->size != sizeof(vmsg->payload.inflight)) {
        vu_panic(dev, "Invalid set_inflight_fd message size:%d fds:%d",
                 vmsg->size, vmsg->fd_num);
        return false;
    }

    fd = vmsg->fds[0];
    mmap_size = vmsg->payload.inflight.mmap_size;
    mmap_offset = vmsg->payload.inflight.mmap_offset;
    num_queues = vmsg->payload.inflight.num_queues;
    queue_size = vmsg->payload.inflight.queue_size;

    DPRINT("set_inflight_fd mmap_size: %"PRId64"\n", mmap_size);
    DPRINT("set_inflight_fd mmap_offset: %"PRId64"\n", mmap_offset);
    DPRINT("set_inflight_fd num_queues: %"PRId16"\n", num_queues);
    DPRINT("set_inflight_fd queue_size: %"PRId16"\n", queue_size);

    rc = mmap(0, mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED,
              fd, mmap_offset);

    if (rc == MAP_FAILED) {
        vu_panic(dev, "set_inflight_fd mmap error: %s", strerror(errno));
        return false;
    }

    if (dev->inflight_info.fd) {
        close(dev->inflight_info.fd);
    }

    if (dev->inflight_info.addr) {
        munmap(dev->inflight_info.addr, dev->inflight_info.size);
    }

    dev->inflight_info.fd = fd;
    dev->inflight_info.addr = rc;
    dev->inflight_info.size = mmap_size;

    for (i = 0; i < num_queues; i++) {
        dev->vq[i].inflight = (VuVirtqInflight *)rc;
        dev->vq[i].inflight->desc_num = queue_size;
        rc = (void *)((char *)rc + vu_inflight_queue_size(queue_size));
    }

    return false;
}

static bool
vu_handle_vring_kick(VuDev *dev, VhostUserMsg *vmsg)
{
    unsigned int index = vmsg->payload.state.index;

    if (index >= dev->max_queues) {
        vu_panic(dev, "Invalid queue index: %u", index);
        return false;
    }

    DPRINT("Got kick message: handler:%p idx:%u\n",
           dev->vq[index].handler, index);

    if (!dev->vq[index].started) {
        dev->vq[index].started = true;

        if (dev->iface->queue_set_started) {
            dev->iface->queue_set_started(dev, index, true);
        }
    }

    if (dev->vq[index].handler) {
        dev->vq[index].handler(dev, index);
    }

    return false;
}

static bool vu_handle_get_max_memslots(VuDev *dev, VhostUserMsg *vmsg)
{
    vmsg_set_reply_u64(vmsg, VHOST_USER_MAX_RAM_SLOTS);

    DPRINT("u64: 0x%016"PRIx64"\n", (uint64_t) VHOST_USER_MAX_RAM_SLOTS);

    return true;
}

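/*
 * Dispatch one decoded vhost-user message.  Returns true if the (possibly
 * modified) vmsg must be sent back as a reply.
 */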
static bool
vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
{
    int do_reply = 0;

    /* Print out generic part of the request. */
    DPRINT("================ Vhost user message ================\n");
    DPRINT("Request: %s (%d)\n", vu_request_to_string(vmsg->request),
           vmsg->request);
    DPRINT("Flags:   0x%x\n", vmsg->flags);
    DPRINT("Size:    %u\n", vmsg->size);

    if (vmsg->fd_num) {
        int i;
        DPRINT("Fds:");
        for (i = 0; i < vmsg->fd_num; i++) {
            DPRINT(" %d", vmsg->fds[i]);
        }
        DPRINT("\n");
    }

    if (dev->iface->process_msg &&
        dev->iface->process_msg(dev, vmsg, &do_reply)) {
        return do_reply;
    }

    switch (vmsg->request) {
    case VHOST_USER_GET_FEATURES:
        return vu_get_features_exec(dev, vmsg);
    case VHOST_USER_SET_FEATURES:
        return vu_set_features_exec(dev, vmsg);
    case VHOST_USER_GET_PROTOCOL_FEATURES:
        return vu_get_protocol_features_exec(dev, vmsg);
    case VHOST_USER_SET_PROTOCOL_FEATURES:
        return vu_set_protocol_features_exec(dev, vmsg);
    case VHOST_USER_SET_OWNER:
        return vu_set_owner_exec(dev, vmsg);
    case VHOST_USER_RESET_OWNER:
        return vu_reset_device_exec(dev, vmsg);
    case VHOST_USER_SET_MEM_TABLE:
        return vu_set_mem_table_exec(dev, vmsg);
    case VHOST_USER_SET_LOG_BASE:
        return vu_set_log_base_exec(dev, vmsg);
    case VHOST_USER_SET_LOG_FD:
        return vu_set_log_fd_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_NUM:
        return vu_set_vring_num_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_ADDR:
        return vu_set_vring_addr_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_BASE:
        return vu_set_vring_base_exec(dev, vmsg);
    case VHOST_USER_GET_VRING_BASE:
        return vu_get_vring_base_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_KICK:
        return vu_set_vring_kick_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_CALL:
        return vu_set_vring_call_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_ERR:
        return vu_set_vring_err_exec(dev, vmsg);
    case VHOST_USER_GET_QUEUE_NUM:
        return vu_get_queue_num_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_ENABLE:
        return vu_set_vring_enable_exec(dev, vmsg);
    case VHOST_USER_SET_SLAVE_REQ_FD:
        return vu_set_slave_req_fd(dev, vmsg);
    case VHOST_USER_GET_CONFIG:
        return vu_get_config(dev, vmsg);
    case VHOST_USER_SET_CONFIG:
        return vu_set_config(dev, vmsg);
    case VHOST_USER_NONE:
        /* if you need processing before exit, override iface->process_msg */
        exit(0);
    case VHOST_USER_POSTCOPY_ADVISE:
        return vu_set_postcopy_advise(dev, vmsg);
    case VHOST_USER_POSTCOPY_LISTEN:
        return vu_set_postcopy_listen(dev, vmsg);
    case VHOST_USER_POSTCOPY_END:
        return vu_set_postcopy_end(dev, vmsg);
    case VHOST_USER_GET_INFLIGHT_FD:
        return vu_get_inflight_fd(dev, vmsg);
    case VHOST_USER_SET_INFLIGHT_FD:
        return vu_set_inflight_fd(dev, vmsg);
    case VHOST_USER_VRING_KICK:
        return vu_handle_vring_kick(dev, vmsg);
    case VHOST_USER_GET_MAX_MEM_SLOTS:
        return vu_handle_get_max_memslots(dev, vmsg);
    case VHOST_USER_ADD_MEM_REG:
        return vu_add_mem_reg(dev, vmsg);
    case VHOST_USER_REM_MEM_REG:
        return vu_rem_mem_reg(dev, vmsg);
    default:
        vmsg_close_fds(vmsg);
        vu_panic(dev, "Unhandled request: %d", vmsg->request);
    }

    return false;
}

bool
vu_dispatch(VuDev *dev)
{
    VhostUserMsg vmsg = { 0, };
    int reply_requested;
    bool need_reply, success = false;

    if (!dev->read_msg(dev, dev->sock, &vmsg)) {
        goto end;
    }

    need_reply = vmsg.flags & VHOST_USER_NEED_REPLY_MASK;

    reply_requested = vu_process_message(dev, &vmsg);
    if (!reply_requested && need_reply) {
        vmsg_set_reply_u64(&vmsg, 0);
        reply_requested = 1;
    }

    if (!reply_requested) {
        success = true;
        goto end;
    }

    if (!vu_send_reply(dev, dev->sock, &vmsg)) {
        goto end;
    }

    success = true;

end:
    free(vmsg.data);
    return success;
}

void
vu_deinit(VuDev *dev)
{
    int i;

    for (i = 0; i < dev->nregions; i++) {
        VuDevRegion *r = &dev->regions[i];
        void *m = (void *) (uintptr_t) r->mmap_addr;
        if (m != MAP_FAILED) {
            munmap(m, r->size + r->mmap_offset);
        }
    }
    dev->nregions = 0;

    for (i = 0; i < dev->max_queues; i++) {
        VuVirtq *vq = &dev->vq[i];

        if (vq->call_fd != -1) {
            close(vq->call_fd);
            vq->call_fd = -1;
        }

        if (vq->kick_fd != -1) {
            dev->remove_watch(dev, vq->kick_fd);
            close(vq->kick_fd);
            vq->kick_fd = -1;
        }

        if (vq->err_fd != -1) {
            close(vq->err_fd);
            vq->err_fd = -1;
        }

        if (vq->resubmit_list) {
            free(vq->resubmit_list);
            vq->resubmit_list = NULL;
        }

        vq->inflight = NULL;
    }

    if (dev->inflight_info.addr) {
        munmap(dev->inflight_info.addr, dev->inflight_info.size);
        dev->inflight_info.addr = NULL;
    }

    if (dev->inflight_info.fd > 0) {
        close(dev->inflight_info.fd);
        dev->inflight_info.fd = -1;
    }

    vu_close_log(dev);
    if (dev->slave_fd != -1) {
        close(dev->slave_fd);
        dev->slave_fd = -1;
    }
    pthread_mutex_destroy(&dev->slave_mutex);

    if (dev->sock != -1) {
        close(dev->sock);
    }

    free(dev->vq);
    dev->vq = NULL;
}

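/*
 * Initialise a VuDev for use on an already connected vhost-user socket.
 * Returns false if the per-queue state cannot be allocated.
 */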

VuVirtq *
vu_get_queue(VuDev *dev, int qidx)
{
    assert(qidx < dev->max_queues);
    return &dev->vq[qidx];
}

bool
vu_queue_enabled(VuDev *dev, VuVirtq *vq)
{
    return vq->enable;
}

bool
vu_queue_started(const VuDev *dev, const VuVirtq *vq)
{
    return vq->started;
}

static inline uint16_t
vring_avail_flags(VuVirtq *vq)
{
    return le16toh(vq->vring.avail->flags);
}

static inline uint16_t
vring_avail_idx(VuVirtq *vq)
{
    vq->shadow_avail_idx = le16toh(vq->vring.avail->idx);

    return vq->shadow_avail_idx;
}

static inline uint16_t
vring_avail_ring(VuVirtq *vq, int i)
{
    return le16toh(vq->vring.avail->ring[i]);
}

static inline uint16_t
vring_get_used_event(VuVirtq *vq)
{
    return vring_avail_ring(vq, vq->vring.num);
}

static int
virtqueue_num_heads(VuDev *dev, VuVirtq *vq, unsigned int idx)
{
    uint16_t num_heads = vring_avail_idx(vq) - idx;

    /* Check it isn't doing very strange things with descriptor numbers. */
    if (num_heads > vq->vring.num) {
        vu_panic(dev, "Guest moved used index from %u to %u",
                 idx, vq->shadow_avail_idx);
        return -1;
    }
    if (num_heads) {
        /* On success, callers read a descriptor at vq->last_avail_idx.
         * Make sure descriptor read does not bypass avail index read. */
        smp_rmb();
    }

    return num_heads;
}

static bool
virtqueue_get_head(VuDev *dev, VuVirtq *vq,
                   unsigned int idx, unsigned int *head)
{
    /* Grab the next descriptor number they're advertising, and increment
     * the index we've seen. */
    *head = vring_avail_ring(vq, idx % vq->vring.num);

    /* If their number is silly, that's a fatal mistake. */
    if (*head >= vq->vring.num) {
        vu_panic(dev, "Guest says index %u is available", *head);
        return false;
    }

    return true;
}

static int
virtqueue_read_indirect_desc(VuDev *dev, struct vring_desc *desc,
                             uint64_t addr, size_t len)
{
    struct vring_desc *ori_desc;
    uint64_t read_len;

    if (len > (VIRTQUEUE_MAX_SIZE * sizeof(struct vring_desc))) {
        return -1;
    }

    if (len == 0) {
        return -1;
    }

    while (len) {
        read_len = len;
        ori_desc = vu_gpa_to_va(dev, &read_len, addr);
        if (!ori_desc) {
            return -1;
        }

        memcpy(desc, ori_desc, read_len);
        len -= read_len;
        addr += read_len;
        desc += read_len;
    }

    return 0;
}

enum {
    VIRTQUEUE_READ_DESC_ERROR = -1,
    VIRTQUEUE_READ_DESC_DONE = 0,   /* end of chain */
    VIRTQUEUE_READ_DESC_MORE = 1,   /* more buffers in chain */
};

static int
virtqueue_read_next_desc(VuDev *dev, struct vring_desc *desc,
                         int i, unsigned int max, unsigned int *next)
{
    /* If this descriptor says it doesn't chain, we're done. */
    if (!(le16toh(desc[i].flags) & VRING_DESC_F_NEXT)) {
        return VIRTQUEUE_READ_DESC_DONE;
    }

    /* Check they're not leading us off end of descriptors. */
    *next = le16toh(desc[i].next);
    /* Make sure compiler knows to grab that: we don't want it changing! */
    smp_wmb();

    if (*next >= max) {
        vu_panic(dev, "Desc next is %u", *next);
        return VIRTQUEUE_READ_DESC_ERROR;
    }

    return VIRTQUEUE_READ_DESC_MORE;
}

void
vu_queue_get_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int *in_bytes,
                         unsigned int *out_bytes,
                         unsigned max_in_bytes, unsigned max_out_bytes)
{
    unsigned int idx;
    unsigned int total_bufs, in_total, out_total;
    int rc;

    idx = vq->last_avail_idx;

    total_bufs = in_total = out_total = 0;
    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        goto done;
    }

    while ((rc = virtqueue_num_heads(dev, vq, idx)) > 0) {
        unsigned int max, desc_len, num_bufs, indirect = 0;
        uint64_t desc_addr, read_len;
        struct vring_desc *desc;
        struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
        unsigned int i;

        max = vq->vring.num;
        num_bufs = total_bufs;
        if (!virtqueue_get_head(dev, vq, idx++, &i)) {
            goto err;
        }
        desc = vq->vring.desc;

        if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
            if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
                vu_panic(dev, "Invalid size for indirect buffer table");
                goto err;
            }

            /* If we've got too many, that implies a descriptor loop. */
            if (num_bufs >= max) {
                vu_panic(dev, "Looped descriptor");
                goto err;
            }

            /* loop over the indirect descriptor table */
            indirect = 1;
            desc_addr = le64toh(desc[i].addr);
            desc_len = le32toh(desc[i].len);
            max = desc_len / sizeof(struct vring_desc);
            read_len = desc_len;
            desc = vu_gpa_to_va(dev, &read_len, desc_addr);
            if (unlikely(desc && read_len != desc_len)) {
                /* Failed to use zero copy */
                desc = NULL;
                if (!virtqueue_read_indirect_desc(dev, desc_buf,
                                                  desc_addr,
                                                  desc_len)) {
                    desc = desc_buf;
                }
            }
            if (!desc) {
                vu_panic(dev, "Invalid indirect buffer table");
                goto err;
            }
            num_bufs = i = 0;
        }

        do {
            /* If we've got too many, that implies a descriptor loop. */
            if (++num_bufs > max) {
                vu_panic(dev, "Looped descriptor");
                goto err;
            }

            if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
                in_total += le32toh(desc[i].len);
            } else {
                out_total += le32toh(desc[i].len);
            }
            if (in_total >= max_in_bytes && out_total >= max_out_bytes) {
                goto done;
            }
            rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
        } while (rc == VIRTQUEUE_READ_DESC_MORE);

        if (rc == VIRTQUEUE_READ_DESC_ERROR) {
            goto err;
        }

        if (!indirect) {
            total_bufs = num_bufs;
        } else {
            total_bufs++;
        }
    }
    if (rc < 0) {
        goto err;
    }
done:
    if (in_bytes) {
        *in_bytes = in_total;
    }
    if (out_bytes) {
        *out_bytes = out_total;
    }
    return;

err:
    in_total = out_total = 0;
    goto done;
}

bool
vu_queue_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int in_bytes,
                     unsigned int out_bytes)
{
    unsigned int in_total, out_total;

    vu_queue_get_avail_bytes(dev, vq, &in_total, &out_total,
                             in_bytes, out_bytes);

    return in_bytes <= in_total && out_bytes <= out_total;
}

/* Fetch avail_idx from VQ memory only when we really need to know if
 * guest has added some buffers. */
bool
vu_queue_empty(VuDev *dev, VuVirtq *vq)
{
    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        return true;
    }

    if (vq->shadow_avail_idx != vq->last_avail_idx) {
        return false;
    }

    return vring_avail_idx(vq) == vq->last_avail_idx;
}

static bool
vring_notify(VuDev *dev, VuVirtq *vq)
{
    uint16_t old, new;
    bool v;

    /* We need to expose used array entries before checking used event. */
    smp_mb();

    /* Always notify when queue is empty (when feature acknowledge) */
    if (vu_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
        !vq->inuse && vu_queue_empty(dev, vq)) {
        return true;
    }

    if (!vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        return !(vring_avail_flags(vq) & VRING_AVAIL_F_NO_INTERRUPT);
    }

    v = vq->signalled_used_valid;
    vq->signalled_used_valid = true;
    old = vq->signalled_used;
    new = vq->signalled_used = vq->used_idx;
    return !v || vring_need_event(vring_get_used_event(vq), new, old);
}

static void _vu_queue_notify(VuDev *dev, VuVirtq *vq, bool sync)
{
    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        return;
    }

    if (!vring_notify(dev, vq)) {
        DPRINT("skipped notify...\n");
        return;
    }

    if (vq->call_fd < 0 &&
        vu_has_protocol_feature(dev,
                                VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
        vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_REQ)) {
        VhostUserMsg vmsg = {
            .request = VHOST_USER_SLAVE_VRING_CALL,
            .flags = VHOST_USER_VERSION,
            .size = sizeof(vmsg.payload.state),
            .payload.state = {
                .index = vq - dev->vq,
            },
        };
        bool ack = sync &&
                   vu_has_protocol_feature(dev,
                                           VHOST_USER_PROTOCOL_F_REPLY_ACK);

        if (ack) {
            vmsg.flags |= VHOST_USER_NEED_REPLY_MASK;
        }

        vu_message_write(dev, dev->slave_fd, &vmsg);
        if (ack) {
            vu_message_read_default(dev, dev->slave_fd, &vmsg);
        }
        return;
    }

    if (eventfd_write(vq->call_fd, 1) < 0) {
        vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
    }
}

void vu_queue_notify(VuDev *dev, VuVirtq *vq)
{
    _vu_queue_notify(dev, vq, false);
}

void vu_queue_notify_sync(VuDev *dev, VuVirtq *vq)
{
    _vu_queue_notify(dev, vq, true);
}

static inline void
vring_used_flags_set_bit(VuVirtq *vq, int mask)
{
    uint16_t *flags;

    flags = (uint16_t *)((char*)vq->vring.used +
                         offsetof(struct vring_used, flags));
    *flags = htole16(le16toh(*flags) | mask);
}

static inline void
vring_used_flags_unset_bit(VuVirtq *vq, int mask)
{
    uint16_t *flags;

    flags = (uint16_t *)((char*)vq->vring.used +
                         offsetof(struct vring_used, flags));
    *flags = htole16(le16toh(*flags) & ~mask);
}

static inline void
vring_set_avail_event(VuVirtq *vq, uint16_t val)
{
    uint16_t *avail;

    if (!vq->notification) {
        return;
    }

    avail = (uint16_t *)&vq->vring.used->ring[vq->vring.num];
    *avail = htole16(val);
}

void
vu_queue_set_notification(VuDev *dev, VuVirtq *vq, int enable)
{
    vq->notification = enable;
    if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        vring_set_avail_event(vq, vring_avail_idx(vq));
    } else if (enable) {
        vring_used_flags_unset_bit(vq, VRING_USED_F_NO_NOTIFY);
    } else {
        vring_used_flags_set_bit(vq, VRING_USED_F_NO_NOTIFY);
    }
    if (enable) {
        /* Expose avail event/used flags before caller checks the avail idx. */
        smp_mb();
    }
}
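
/*
 * Illustrative sketch only, not part of the library: a common way to use
 * vu_queue_set_notification() is to suppress guest kicks while draining the
 * ring, then re-enable notifications and re-check emptiness so a buffer added
 * just before re-enabling is not missed.  The processing step is left as a
 * hypothetical placeholder.
 *
 *     VuVirtqElement *elem;
 *
 *     do {
 *         vu_queue_set_notification(dev, vq, 0);
 *         while ((elem = vu_queue_pop(dev, vq, sizeof(*elem)))) {
 *             // ... process the element, push it and free it ...
 *         }
 *         vu_queue_set_notification(dev, vq, 1);
 *     } while (!vu_queue_empty(dev, vq));
 */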

static bool
virtqueue_map_desc(VuDev *dev,
                   unsigned int *p_num_sg, struct iovec *iov,
                   unsigned int max_num_sg, bool is_write,
                   uint64_t pa, size_t sz)
{
    unsigned num_sg = *p_num_sg;

    assert(num_sg <= max_num_sg);

    if (!sz) {
        vu_panic(dev, "virtio: zero sized buffers are not allowed");
        return false;
    }

    while (sz) {
        uint64_t len = sz;

        if (num_sg == max_num_sg) {
            vu_panic(dev, "virtio: too many descriptors in indirect table");
            return false;
        }

        iov[num_sg].iov_base = vu_gpa_to_va(dev, &len, pa);
        if (iov[num_sg].iov_base == NULL) {
            vu_panic(dev, "virtio: invalid address for buffers");
            return false;
        }
        iov[num_sg].iov_len = len;
        num_sg++;
        sz -= len;
        pa += len;
    }

    *p_num_sg = num_sg;
    return true;
}

static void *
virtqueue_alloc_element(size_t sz,
                        unsigned out_num, unsigned in_num)
{
    VuVirtqElement *elem;
    size_t in_sg_ofs = ALIGN_UP(sz, __alignof__(elem->in_sg[0]));
    size_t out_sg_ofs = in_sg_ofs + in_num * sizeof(elem->in_sg[0]);
    size_t out_sg_end = out_sg_ofs + out_num * sizeof(elem->out_sg[0]);

    assert(sz >= sizeof(VuVirtqElement));
    elem = malloc(out_sg_end);
    elem->out_num = out_num;
    elem->in_num = in_num;
    elem->in_sg = (void *)elem + in_sg_ofs;
    elem->out_sg = (void *)elem + out_sg_ofs;
    return elem;
}

static void *
vu_queue_map_desc(VuDev *dev, VuVirtq *vq, unsigned int idx, size_t sz)
{
    struct vring_desc *desc = vq->vring.desc;
    uint64_t desc_addr, read_len;
    unsigned int desc_len;
    unsigned int max = vq->vring.num;
    unsigned int i = idx;
    VuVirtqElement *elem;
    unsigned int out_num = 0, in_num = 0;
    struct iovec iov[VIRTQUEUE_MAX_SIZE];
    struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
    int rc;

    if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
        if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
            vu_panic(dev, "Invalid size for indirect buffer table");
            return NULL;
        }

        /* loop over the indirect descriptor table */
        desc_addr = le64toh(desc[i].addr);
        desc_len = le32toh(desc[i].len);
        max = desc_len / sizeof(struct vring_desc);
        read_len = desc_len;
        desc = vu_gpa_to_va(dev, &read_len, desc_addr);
        if (unlikely(desc && read_len != desc_len)) {
            /* Failed to use zero copy */
            desc = NULL;
            if (!virtqueue_read_indirect_desc(dev, desc_buf,
                                              desc_addr,
                                              desc_len)) {
                desc = desc_buf;
            }
        }
        if (!desc) {
            vu_panic(dev, "Invalid indirect buffer table");
            return NULL;
        }
        i = 0;
    }

    /* Collect all the descriptors */
    do {
        if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
            if (!virtqueue_map_desc(dev, &in_num, iov + out_num,
                                    VIRTQUEUE_MAX_SIZE - out_num, true,
                                    le64toh(desc[i].addr),
                                    le32toh(desc[i].len))) {
                return NULL;
            }
        } else {
            if (in_num) {
                vu_panic(dev, "Incorrect order for descriptors");
                return NULL;
            }
            if (!virtqueue_map_desc(dev, &out_num, iov,
                                    VIRTQUEUE_MAX_SIZE, false,
                                    le64toh(desc[i].addr),
                                    le32toh(desc[i].len))) {
                return NULL;
            }
        }

        /* If we've got too many, that implies a descriptor loop. */
        if ((in_num + out_num) > max) {
            vu_panic(dev, "Looped descriptor");
            return NULL;
        }
        rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
    } while (rc == VIRTQUEUE_READ_DESC_MORE);

    if (rc == VIRTQUEUE_READ_DESC_ERROR) {
        vu_panic(dev, "read descriptor error");
        return NULL;
    }

    /* Now copy what we have collected and mapped */
    elem = virtqueue_alloc_element(sz, out_num, in_num);
    elem->index = idx;
    for (i = 0; i < out_num; i++) {
        elem->out_sg[i] = iov[i];
    }
    for (i = 0; i < in_num; i++) {
        elem->in_sg[i] = iov[out_num + i];
    }

    return elem;
}

static int
vu_queue_inflight_get(VuDev *dev, VuVirtq *vq, int desc_idx)
{
    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
        return 0;
    }

    if (unlikely(!vq->inflight)) {
        return -1;
    }

    vq->inflight->desc[desc_idx].counter = vq->counter++;
    vq->inflight->desc[desc_idx].inflight = 1;

    return 0;
}

static int
vu_queue_inflight_pre_put(VuDev *dev, VuVirtq *vq, int desc_idx)
{
    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
        return 0;
    }

    if (unlikely(!vq->inflight)) {
        return -1;
    }

    vq->inflight->last_batch_head = desc_idx;

    return 0;
}

static int
vu_queue_inflight_post_put(VuDev *dev, VuVirtq *vq, int desc_idx)
{
    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
        return 0;
    }

    if (unlikely(!vq->inflight)) {
        return -1;
    }

    barrier();

    vq->inflight->desc[desc_idx].inflight = 0;

    barrier();

    vq->inflight->used_idx = vq->used_idx;

    return 0;
}

void *
vu_queue_pop(VuDev *dev, VuVirtq *vq, size_t sz)
{
    int i;
    unsigned int head;
    VuVirtqElement *elem;

    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        return NULL;
    }

    if (unlikely(vq->resubmit_list && vq->resubmit_num > 0)) {
        i = (--vq->resubmit_num);
        elem = vu_queue_map_desc(dev, vq, vq->resubmit_list[i].index, sz);

        if (!vq->resubmit_num) {
            free(vq->resubmit_list);
            vq->resubmit_list = NULL;
        }

        return elem;
    }

    if (vu_queue_empty(dev, vq)) {
        return NULL;
    }
    /*
     * Needed after virtio_queue_empty(), see comment in
     * virtqueue_num_heads().
     */
    smp_rmb();

    if (vq->inuse >= vq->vring.num) {
        vu_panic(dev, "Virtqueue size exceeded");
        return NULL;
    }

    if (!virtqueue_get_head(dev, vq, vq->last_avail_idx++, &head)) {
        return NULL;
    }

    if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        vring_set_avail_event(vq, vq->last_avail_idx);
    }

    elem = vu_queue_map_desc(dev, vq, head, sz);

    if (!elem) {
        return NULL;
    }

    vq->inuse++;

    vu_queue_inflight_get(dev, vq, head);

    return elem;
}

static void
vu_queue_detach_element(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
                        size_t len)
{
    vq->inuse--;
    /* unmap, when DMA support is added */
}

void
vu_queue_unpop(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
               size_t len)
{
    vq->last_avail_idx--;
    vu_queue_detach_element(dev, vq, elem, len);
}

bool
vu_queue_rewind(VuDev *dev, VuVirtq *vq, unsigned int num)
{
    if (num > vq->inuse) {
        return false;
    }
    vq->last_avail_idx -= num;
    vq->inuse -= num;
    return true;
}

static inline
void vring_used_write(VuDev *dev, VuVirtq *vq,
                      struct vring_used_elem *uelem, int i)
{
    struct vring_used *used = vq->vring.used;

    used->ring[i] = *uelem;
    vu_log_write(dev, vq->vring.log_guest_addr +
                 offsetof(struct vring_used, ring[i]),
                 sizeof(used->ring[i]));
}


static void
vu_log_queue_fill(VuDev *dev, VuVirtq *vq,
                  const VuVirtqElement *elem,
                  unsigned int len)
{
    struct vring_desc *desc = vq->vring.desc;
    unsigned int i, max, min, desc_len;
    uint64_t desc_addr, read_len;
    struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
    unsigned num_bufs = 0;

    max = vq->vring.num;
    i = elem->index;

    if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
        if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
            vu_panic(dev, "Invalid size for indirect buffer table");
            return;
        }

        /* loop over the indirect descriptor table */
        desc_addr = le64toh(desc[i].addr);
        desc_len = le32toh(desc[i].len);
        max = desc_len / sizeof(struct vring_desc);
        read_len = desc_len;
        desc = vu_gpa_to_va(dev, &read_len, desc_addr);
        if (unlikely(desc && read_len != desc_len)) {
            /* Failed to use zero copy */
            desc = NULL;
            if (!virtqueue_read_indirect_desc(dev, desc_buf,
                                              desc_addr,
                                              desc_len)) {
                desc = desc_buf;
            }
        }
        if (!desc) {
            vu_panic(dev, "Invalid indirect buffer table");
            return;
        }
        i = 0;
    }

    do {
        if (++num_bufs > max) {
            vu_panic(dev, "Looped descriptor");
            return;
        }

        if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
            min = MIN(le32toh(desc[i].len), len);
            vu_log_write(dev, le64toh(desc[i].addr), min);
            len -= min;
        }

    } while (len > 0 &&
             (virtqueue_read_next_desc(dev, desc, i, max, &i)
              == VIRTQUEUE_READ_DESC_MORE));
}

void
vu_queue_fill(VuDev *dev, VuVirtq *vq,
              const VuVirtqElement *elem,
              unsigned int len, unsigned int idx)
{
    struct vring_used_elem uelem;

    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        return;
    }

    vu_log_queue_fill(dev, vq, elem, len);

    idx = (idx + vq->used_idx) % vq->vring.num;

    uelem.id = htole32(elem->index);
    uelem.len = htole32(len);
    vring_used_write(dev, vq, &uelem, idx);
}

static inline
void vring_used_idx_set(VuDev *dev, VuVirtq *vq, uint16_t val)
{
    vq->vring.used->idx = htole16(val);
    vu_log_write(dev,
                 vq->vring.log_guest_addr + offsetof(struct vring_used, idx),
                 sizeof(vq->vring.used->idx));

    vq->used_idx = val;
}

void
vu_queue_flush(VuDev *dev, VuVirtq *vq, unsigned int count)
{
    uint16_t old, new;

    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        return;
    }

    /* Make sure buffer is written before we update index. */
    smp_wmb();

    old = vq->used_idx;
    new = old + count;
    vring_used_idx_set(dev, vq, new);
    vq->inuse -= count;
    if (unlikely((int16_t)(new - vq->signalled_used) < (uint16_t)(new - old))) {
        vq->signalled_used_valid = false;
    }
}

void
vu_queue_push(VuDev *dev, VuVirtq *vq,
              const VuVirtqElement *elem, unsigned int len)
{
    vu_queue_fill(dev, vq, elem, len, 0);
    vu_queue_inflight_pre_put(dev, vq, elem->index);
    vu_queue_flush(dev, vq, 1);
    vu_queue_inflight_post_put(dev, vq, elem->index);
}
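
/*
 * Illustrative sketch only, not part of the library: the usual life cycle of
 * a request is pop -> use in_sg/out_sg -> push -> notify.  The element
 * returned by vu_queue_pop() is heap-allocated (see virtqueue_alloc_element())
 * and must be freed by the caller; vu_queue_push() takes care of filling and
 * flushing the used ring as well as the inflight bookkeeping.  The
 * process_request() handler below is hypothetical.
 *
 *     VuVirtqElement *elem = vu_queue_pop(dev, vq, sizeof(*elem));
 *
 *     if (elem) {
 *         size_t written = process_request(elem);   // hypothetical handler
 *         vu_queue_push(dev, vq, elem, written);
 *         vu_queue_notify(dev, vq);
 *         free(elem);
 *     }
 */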