xref: /openbmc/qemu/subprojects/libvhost-user/libvhost-user.c (revision 1406b7fc4bbaab42133b1ef03270179746e91723)
1  /*
2   * Vhost User library
3   *
4   * Copyright IBM, Corp. 2007
5   * Copyright (c) 2016 Red Hat, Inc.
6   *
7   * Authors:
8   *  Anthony Liguori <aliguori@us.ibm.com>
9   *  Marc-André Lureau <mlureau@redhat.com>
10   *  Victor Kaplansky <victork@redhat.com>
11   *
12   * This work is licensed under the terms of the GNU GPL, version 2 or
13   * later.  See the COPYING file in the top-level directory.
14   */
15  
16  #ifndef _GNU_SOURCE
17  #define _GNU_SOURCE
18  #endif
19  
20  /* this code avoids GLib dependency */
21  #include <stdlib.h>
22  #include <stdio.h>
23  #include <unistd.h>
24  #include <stdarg.h>
25  #include <errno.h>
26  #include <string.h>
27  #include <assert.h>
28  #include <inttypes.h>
29  #include <sys/types.h>
30  #include <sys/socket.h>
31  #include <sys/eventfd.h>
32  #include <sys/mman.h>
33  #include <endian.h>
34  
35   * Necessary to provide VIRTIO_F_VERSION_1 on systems
36   * with older linux headers. Must appear before
37   * <linux/vhost.h> below.
38   */
39  #include "standard-headers/linux/virtio_config.h"
40  
41  #if defined(__linux__)
42  #include <sys/syscall.h>
43  #include <fcntl.h>
44  #include <sys/ioctl.h>
45  #include <linux/vhost.h>
46  #include <sys/vfs.h>
47  #include <linux/magic.h>
48  
49  #ifdef __NR_userfaultfd
50  #include <linux/userfaultfd.h>
51  #endif
52  
53  #endif
54  
55  #include "include/atomic.h"
56  
57  #include "libvhost-user.h"
58  
59  /* usually provided by GLib */
60  #if     __GNUC__ > 2 || (__GNUC__ == 2 && __GNUC_MINOR__ > 4)
61  #if !defined(__clang__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 4)
62  #define G_GNUC_PRINTF(format_idx, arg_idx) \
63    __attribute__((__format__(gnu_printf, format_idx, arg_idx)))
64  #else
65  #define G_GNUC_PRINTF(format_idx, arg_idx) \
66    __attribute__((__format__(__printf__, format_idx, arg_idx)))
67  #endif
68  #else   /* !__GNUC__ */
69  #define G_GNUC_PRINTF(format_idx, arg_idx)
70  #endif  /* !__GNUC__ */
71  #ifndef MIN
72  #define MIN(x, y) ({                            \
73              __typeof__(x) _min1 = (x);          \
74              __typeof__(y) _min2 = (y);          \
75              (void) (&_min1 == &_min2);          \
76              _min1 < _min2 ? _min1 : _min2; })
77  #endif
78  
79  /* Round number down to multiple */
80  #define ALIGN_DOWN(n, m) ((n) / (m) * (m))
81  
82  /* Round number up to multiple */
83  #define ALIGN_UP(n, m) ALIGN_DOWN((n) + (m) - 1, (m))
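/*
 * Worked example (illustrative): with m = 0x200000 (2 MiB),
 * ALIGN_DOWN(0x201000, 0x200000) == 0x200000 and
 * ALIGN_UP(0x201000, 0x200000) == 0x400000.
 */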
84  
85  #ifndef unlikely
86  #define unlikely(x)   __builtin_expect(!!(x), 0)
87  #endif
88  
89  /* Align each region to cache line size in inflight buffer */
90  #define INFLIGHT_ALIGNMENT 64
91  
92  /* The version of inflight buffer */
93  #define INFLIGHT_VERSION 1
94  
95  /* The version of the protocol we support */
96  #define VHOST_USER_VERSION 1
97  #define LIBVHOST_USER_DEBUG 0
98  
99  #define DPRINT(...)                             \
100      do {                                        \
101          if (LIBVHOST_USER_DEBUG) {              \
102              fprintf(stderr, __VA_ARGS__);        \
103          }                                       \
104      } while (0)
105  
106  static inline
107  bool has_feature(uint64_t features, unsigned int fbit)
108  {
109      assert(fbit < 64);
110      return !!(features & (1ULL << fbit));
111  }
112  
113  static inline
114  bool vu_has_feature(VuDev *dev,
115                      unsigned int fbit)
116  {
117      return has_feature(dev->features, fbit);
118  }
119  
120  static inline bool vu_has_protocol_feature(VuDev *dev, unsigned int fbit)
121  {
122      return has_feature(dev->protocol_features, fbit);
123  }
124  
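/*
 * Map a vhost-user request code to its enum name; useful mainly for
 * DPRINT() tracing, e.g. DPRINT("%s\n", vu_request_to_string(vmsg->request)).
 */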
125  const char *
126  vu_request_to_string(unsigned int req)
127  {
128  #define REQ(req) [req] = #req
129      static const char *vu_request_str[] = {
130          REQ(VHOST_USER_NONE),
131          REQ(VHOST_USER_GET_FEATURES),
132          REQ(VHOST_USER_SET_FEATURES),
133          REQ(VHOST_USER_SET_OWNER),
134          REQ(VHOST_USER_RESET_OWNER),
135          REQ(VHOST_USER_SET_MEM_TABLE),
136          REQ(VHOST_USER_SET_LOG_BASE),
137          REQ(VHOST_USER_SET_LOG_FD),
138          REQ(VHOST_USER_SET_VRING_NUM),
139          REQ(VHOST_USER_SET_VRING_ADDR),
140          REQ(VHOST_USER_SET_VRING_BASE),
141          REQ(VHOST_USER_GET_VRING_BASE),
142          REQ(VHOST_USER_SET_VRING_KICK),
143          REQ(VHOST_USER_SET_VRING_CALL),
144          REQ(VHOST_USER_SET_VRING_ERR),
145          REQ(VHOST_USER_GET_PROTOCOL_FEATURES),
146          REQ(VHOST_USER_SET_PROTOCOL_FEATURES),
147          REQ(VHOST_USER_GET_QUEUE_NUM),
148          REQ(VHOST_USER_SET_VRING_ENABLE),
149          REQ(VHOST_USER_SEND_RARP),
150          REQ(VHOST_USER_NET_SET_MTU),
151          REQ(VHOST_USER_SET_BACKEND_REQ_FD),
152          REQ(VHOST_USER_IOTLB_MSG),
153          REQ(VHOST_USER_SET_VRING_ENDIAN),
154          REQ(VHOST_USER_GET_CONFIG),
155          REQ(VHOST_USER_SET_CONFIG),
156          REQ(VHOST_USER_POSTCOPY_ADVISE),
157          REQ(VHOST_USER_POSTCOPY_LISTEN),
158          REQ(VHOST_USER_POSTCOPY_END),
159          REQ(VHOST_USER_GET_INFLIGHT_FD),
160          REQ(VHOST_USER_SET_INFLIGHT_FD),
161          REQ(VHOST_USER_GPU_SET_SOCKET),
162          REQ(VHOST_USER_VRING_KICK),
163          REQ(VHOST_USER_GET_MAX_MEM_SLOTS),
164          REQ(VHOST_USER_ADD_MEM_REG),
165          REQ(VHOST_USER_REM_MEM_REG),
166          REQ(VHOST_USER_GET_SHARED_OBJECT),
167          REQ(VHOST_USER_MAX),
168      };
169  #undef REQ
170  
171      if (req < VHOST_USER_MAX) {
172          return vu_request_str[req];
173      } else {
174          return "unknown";
175      }
176  }
177  
178  static void G_GNUC_PRINTF(2, 3)
179  vu_panic(VuDev *dev, const char *msg, ...)
180  {
181      char *buf = NULL;
182      va_list ap;
183  
184      va_start(ap, msg);
185      if (vasprintf(&buf, msg, ap) < 0) {
186          buf = NULL;
187      }
188      va_end(ap);
189  
190      dev->broken = true;
191      dev->panic(dev, buf);
192      free(buf);
193  
194      /*
195       * FIXME:
196       * find a way to call virtio_error, or perhaps close the connection?
197       */
198  }
199  
200  /* Search for a memory region that covers this guest physical address. */
201  static VuDevRegion *
202  vu_gpa_to_mem_region(VuDev *dev, uint64_t guest_addr)
203  {
204      int low = 0;
205      int high = dev->nregions - 1;
206  
207      /*
208       * Memory regions cannot overlap in guest physical address space. Each
209       * GPA belongs to exactly one memory region, so there can only be one
210       * match.
211       *
212       * We store our memory regions ordered by GPA and can simply perform a
213       * binary search.
214       */
215      while (low <= high) {
216          unsigned int mid = low + (high - low) / 2;
217          VuDevRegion *cur = &dev->regions[mid];
218  
219          if (guest_addr >= cur->gpa && guest_addr < cur->gpa + cur->size) {
220              return cur;
221          }
222          if (guest_addr >= cur->gpa + cur->size) {
223              low = mid + 1;
224          }
225          if (guest_addr < cur->gpa) {
226              high = mid - 1;
227          }
228      }
229      return NULL;
230  }
231  
232  /* Translate guest physical address to our virtual address.  */
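/*
 * Usage sketch (illustrative, not upstream documentation): *plen is capped
 * to the end of the containing region, so callers that walk a guest buffer
 * which may span several regions typically loop:
 *
 *     uint64_t len = remaining;
 *     void *va = vu_gpa_to_va(dev, &len, gpa);
 *     if (!va) { ...error... }
 *     ...process len bytes at va...
 *     gpa += len; remaining -= len;
 */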
233  void *
234  vu_gpa_to_va(VuDev *dev, uint64_t *plen, uint64_t guest_addr)
235  {
236      VuDevRegion *r;
237  
238      if (*plen == 0) {
239          return NULL;
240      }
241  
242      r = vu_gpa_to_mem_region(dev, guest_addr);
243      if (!r) {
244          return NULL;
245      }
246  
247      if ((guest_addr + *plen) > (r->gpa + r->size)) {
248          *plen = r->gpa + r->size - guest_addr;
249      }
250      return (void *)(uintptr_t)guest_addr - r->gpa + r->mmap_addr +
251             r->mmap_offset;
252  }
253  
254  /* Translate qemu virtual address to our virtual address.  */
255  static void *
256  qva_to_va(VuDev *dev, uint64_t qemu_addr)
257  {
258      unsigned int i;
259  
260      /* Find matching memory region.  */
261      for (i = 0; i < dev->nregions; i++) {
262          VuDevRegion *r = &dev->regions[i];
263  
264          if ((qemu_addr >= r->qva) && (qemu_addr < (r->qva + r->size))) {
265              return (void *)(uintptr_t)
266                  qemu_addr - r->qva + r->mmap_addr + r->mmap_offset;
267          }
268      }
269  
270      return NULL;
271  }
272  
273  static void
274  vu_remove_all_mem_regs(VuDev *dev)
275  {
276      unsigned int i;
277  
278      for (i = 0; i < dev->nregions; i++) {
279          VuDevRegion *r = &dev->regions[i];
280  
281          munmap((void *)(uintptr_t)r->mmap_addr, r->size + r->mmap_offset);
282      }
283      dev->nregions = 0;
284  }
285  
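/*
 * Note: map_ring() below returns true on failure (one of the ring
 * addresses did not translate) and false on success.
 */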
286  static bool
287  map_ring(VuDev *dev, VuVirtq *vq)
288  {
289      vq->vring.desc = qva_to_va(dev, vq->vra.desc_user_addr);
290      vq->vring.used = qva_to_va(dev, vq->vra.used_user_addr);
291      vq->vring.avail = qva_to_va(dev, vq->vra.avail_user_addr);
292  
293      DPRINT("Setting virtq addresses:\n");
294      DPRINT("    vring_desc  at %p\n", vq->vring.desc);
295      DPRINT("    vring_used  at %p\n", vq->vring.used);
296      DPRINT("    vring_avail at %p\n", vq->vring.avail);
297  
298      return !(vq->vring.desc && vq->vring.used && vq->vring.avail);
299  }
300  
301  static bool
302  vu_is_vq_usable(VuDev *dev, VuVirtq *vq)
303  {
304      if (unlikely(dev->broken)) {
305          return false;
306      }
307  
308      if (likely(vq->vring.avail)) {
309          return true;
310      }
311  
312      /*
313       * In corner cases, we might temporarily remove a memory region that
314       * mapped a ring. When removing a memory region we make sure to
315       * unmap any rings that would be impacted. Let's try to remap if we
316       * already succeeded mapping this ring once.
317       */
318      if (!vq->vra.desc_user_addr || !vq->vra.used_user_addr ||
319          !vq->vra.avail_user_addr) {
320          return false;
321      }
322      if (map_ring(dev, vq)) {
323          vu_panic(dev, "remapping queue on access");
324          return false;
325      }
326      return true;
327  }
328  
329  static void
330  unmap_rings(VuDev *dev, VuDevRegion *r)
331  {
332      int i;
333  
334      for (i = 0; i < dev->max_queues; i++) {
335          VuVirtq *vq = &dev->vq[i];
336          const uintptr_t desc = (uintptr_t)vq->vring.desc;
337          const uintptr_t used = (uintptr_t)vq->vring.used;
338          const uintptr_t avail = (uintptr_t)vq->vring.avail;
339  
340          if (desc < r->mmap_addr || desc >= r->mmap_addr + r->size) {
341              continue;
342          }
343          if (used < r->mmap_addr || used >= r->mmap_addr + r->size) {
344              continue;
345          }
346          if (avail < r->mmap_addr || avail >= r->mmap_addr + r->size) {
347              continue;
348          }
349  
350          DPRINT("Unmapping rings of queue %d\n", i);
351          vq->vring.desc = NULL;
352          vq->vring.used = NULL;
353          vq->vring.avail = NULL;
354      }
355  }
356  
357  static size_t
358  get_fd_hugepagesize(int fd)
359  {
360  #if defined(__linux__)
361      struct statfs fs;
362      int ret;
363  
364      do {
365          ret = fstatfs(fd, &fs);
366      } while (ret != 0 && errno == EINTR);
367  
368      if (!ret && (unsigned int)fs.f_type == HUGETLBFS_MAGIC) {
369          return fs.f_bsize;
370      }
371  #endif
372      return 0;
373  }
374  
375  static void
376  _vu_add_mem_reg(VuDev *dev, VhostUserMemoryRegion *msg_region, int fd)
377  {
378      const uint64_t start_gpa = msg_region->guest_phys_addr;
379      const uint64_t end_gpa = start_gpa + msg_region->memory_size;
380      int prot = PROT_READ | PROT_WRITE;
381      uint64_t mmap_offset, fd_offset;
382      size_t hugepagesize;
383      VuDevRegion *r;
384      void *mmap_addr;
385      int low = 0;
386      int high = dev->nregions - 1;
387      unsigned int idx;
388  
389      DPRINT("Adding region %d\n", dev->nregions);
390      DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
391             msg_region->guest_phys_addr);
392      DPRINT("    memory_size:     0x%016"PRIx64"\n",
393             msg_region->memory_size);
394      DPRINT("    userspace_addr:  0x%016"PRIx64"\n",
395             msg_region->userspace_addr);
396      DPRINT("    old mmap_offset: 0x%016"PRIx64"\n",
397             msg_region->mmap_offset);
398  
399      if (dev->postcopy_listening) {
400          /*
401           * In postcopy we're using PROT_NONE here to catch anyone
402           * accessing it before we userfault
403           */
404          prot = PROT_NONE;
405      }
406  
407      /*
408       * We will add memory regions into the array sorted by GPA. Perform a
409       * binary search to locate the insertion point: it will be at the low
410       * index.
411       */
412      while (low <= high) {
413          unsigned int mid = low + (high - low)  / 2;
414          VuDevRegion *cur = &dev->regions[mid];
415  
416          /* Overlap of GPA addresses. */
417          if (start_gpa < cur->gpa + cur->size && cur->gpa < end_gpa) {
418              vu_panic(dev, "regions with overlapping guest physical addresses");
419              return;
420          }
421          if (start_gpa >= cur->gpa + cur->size) {
422              low = mid + 1;
423          }
424          if (start_gpa < cur->gpa) {
425              high = mid - 1;
426          }
427      }
428      idx = low;
429  
430      /*
431       * Convert most of msg_region->mmap_offset to fd_offset. In almost all
432       * cases, this will leave us with mmap_offset == 0, mmap()'ing only
433       * what we really need. Only if a memory region would partially cover
434       * hugetlb pages, we'd get mmap_offset != 0, which usually doesn't happen
435       * anymore (i.e., modern QEMU).
436       *
437       * Note that mmap() with hugetlb would fail if the offset into the file
438       * is not aligned to the huge page size.
439       */
440      hugepagesize = get_fd_hugepagesize(fd);
441      if (hugepagesize) {
442          fd_offset = ALIGN_DOWN(msg_region->mmap_offset, hugepagesize);
443          mmap_offset = msg_region->mmap_offset - fd_offset;
444      } else {
445          fd_offset = msg_region->mmap_offset;
446          mmap_offset = 0;
447      }
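    /*
     * Example (illustrative): with a 2 MiB hugetlb page size and
     * msg_region->mmap_offset == 0x201000, fd_offset becomes 0x200000 and
     * mmap_offset becomes 0x1000, keeping the file offset passed to mmap()
     * below aligned to the huge page size.
     */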
448  
449      DPRINT("    fd_offset:       0x%016"PRIx64"\n",
450             fd_offset);
451      DPRINT("    new mmap_offset: 0x%016"PRIx64"\n",
452             mmap_offset);
453  
454      mmap_addr = mmap(0, msg_region->memory_size + mmap_offset,
455                       prot, MAP_SHARED | MAP_NORESERVE, fd, fd_offset);
456      if (mmap_addr == MAP_FAILED) {
457          vu_panic(dev, "region mmap error: %s", strerror(errno));
458          return;
459      }
460      DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
461             (uint64_t)(uintptr_t)mmap_addr);
462  
463  #if defined(__linux__)
464      /* Don't include all guest memory in a coredump. */
465      madvise(mmap_addr, msg_region->memory_size + mmap_offset,
466              MADV_DONTDUMP);
467  #endif
468  
469      /* Shift all affected entries by 1 to open a hole at idx. */
470      r = &dev->regions[idx];
471      memmove(r + 1, r, sizeof(VuDevRegion) * (dev->nregions - idx));
472      r->gpa = msg_region->guest_phys_addr;
473      r->size = msg_region->memory_size;
474      r->qva = msg_region->userspace_addr;
475      r->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
476      r->mmap_offset = mmap_offset;
477      dev->nregions++;
478  
479      if (dev->postcopy_listening) {
480          /*
481           * Return the address to QEMU so that it can translate the ufd
482           * fault addresses back.
483           */
484          msg_region->userspace_addr = r->mmap_addr + r->mmap_offset;
485      }
486  }
487  
488  static void
489  vmsg_close_fds(VhostUserMsg *vmsg)
490  {
491      int i;
492  
493      for (i = 0; i < vmsg->fd_num; i++) {
494          close(vmsg->fds[i]);
495      }
496  }
497  
498  /* Set reply payload.u64 and clear request flags and fd_num */
499  static void vmsg_set_reply_u64(VhostUserMsg *vmsg, uint64_t val)
500  {
501      vmsg->flags = 0; /* defaults will be set by vu_send_reply() */
502      vmsg->size = sizeof(vmsg->payload.u64);
503      vmsg->payload.u64 = val;
504      vmsg->fd_num = 0;
505  }
506  
507  /* A test to see if we have userfault available */
508  static bool
509  have_userfault(void)
510  {
511  #if defined(__linux__) && defined(__NR_userfaultfd) &&\
512          defined(UFFD_FEATURE_MISSING_SHMEM) &&\
513          defined(UFFD_FEATURE_MISSING_HUGETLBFS)
514      /* Now test the kernel we're running on really has the features */
515      int ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
516      struct uffdio_api api_struct;
517      if (ufd < 0) {
518          return false;
519      }
520  
521      api_struct.api = UFFD_API;
522      api_struct.features = UFFD_FEATURE_MISSING_SHMEM |
523                            UFFD_FEATURE_MISSING_HUGETLBFS;
524      if (ioctl(ufd, UFFDIO_API, &api_struct)) {
525          close(ufd);
526          return false;
527      }
528      close(ufd);
529      return true;
530  
531  #else
532      return false;
533  #endif
534  }
535  
536  static bool
537  vu_message_read_default(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
538  {
539      char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
540      struct iovec iov = {
541          .iov_base = (char *)vmsg,
542          .iov_len = VHOST_USER_HDR_SIZE,
543      };
544      struct msghdr msg = {
545          .msg_iov = &iov,
546          .msg_iovlen = 1,
547          .msg_control = control,
548          .msg_controllen = sizeof(control),
549      };
550      size_t fd_size;
551      struct cmsghdr *cmsg;
552      int rc;
553  
554      do {
555          rc = recvmsg(conn_fd, &msg, 0);
556      } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
557  
558      if (rc < 0) {
559          vu_panic(dev, "Error while recvmsg: %s", strerror(errno));
560          return false;
561      }
562  
563      vmsg->fd_num = 0;
564      for (cmsg = CMSG_FIRSTHDR(&msg);
565           cmsg != NULL;
566           cmsg = CMSG_NXTHDR(&msg, cmsg))
567      {
568          if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
569              fd_size = cmsg->cmsg_len - CMSG_LEN(0);
570              vmsg->fd_num = fd_size / sizeof(int);
571              assert(vmsg->fd_num <= VHOST_MEMORY_BASELINE_NREGIONS);
572              memcpy(vmsg->fds, CMSG_DATA(cmsg), fd_size);
573              break;
574          }
575      }
576  
577      if (vmsg->size > sizeof(vmsg->payload)) {
578          vu_panic(dev,
579                   "Error: too big message request: %d, size: vmsg->size: %u, "
580                   "while sizeof(vmsg->payload) = %zu\n",
581                   vmsg->request, vmsg->size, sizeof(vmsg->payload));
582          goto fail;
583      }
584  
585      if (vmsg->size) {
586          do {
587              rc = read(conn_fd, &vmsg->payload, vmsg->size);
588          } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
589  
590          if (rc <= 0) {
591              vu_panic(dev, "Error while reading: %s", strerror(errno));
592              goto fail;
593          }
594  
595          assert((uint32_t)rc == vmsg->size);
596      }
597  
598      return true;
599  
600  fail:
601      vmsg_close_fds(vmsg);
602  
603      return false;
604  }
605  
606  static bool
607  vu_message_write(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
608  {
609      int rc;
610      uint8_t *p = (uint8_t *)vmsg;
611      char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
612      struct iovec iov = {
613          .iov_base = (char *)vmsg,
614          .iov_len = VHOST_USER_HDR_SIZE,
615      };
616      struct msghdr msg = {
617          .msg_iov = &iov,
618          .msg_iovlen = 1,
619          .msg_control = control,
620      };
621      struct cmsghdr *cmsg;
622  
623      memset(control, 0, sizeof(control));
624      assert(vmsg->fd_num <= VHOST_MEMORY_BASELINE_NREGIONS);
625      if (vmsg->fd_num > 0) {
626          size_t fdsize = vmsg->fd_num * sizeof(int);
627          msg.msg_controllen = CMSG_SPACE(fdsize);
628          cmsg = CMSG_FIRSTHDR(&msg);
629          cmsg->cmsg_len = CMSG_LEN(fdsize);
630          cmsg->cmsg_level = SOL_SOCKET;
631          cmsg->cmsg_type = SCM_RIGHTS;
632          memcpy(CMSG_DATA(cmsg), vmsg->fds, fdsize);
633      } else {
634          msg.msg_controllen = 0;
635          msg.msg_control = NULL;
636      }
637  
638      do {
639          rc = sendmsg(conn_fd, &msg, 0);
640      } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
641  
642      if (rc <= 0) {
643          vu_panic(dev, "Error while writing: %s", strerror(errno));
644          return false;
645      }
646  
647      if (vmsg->size) {
648          do {
649              if (vmsg->data) {
650                  rc = write(conn_fd, vmsg->data, vmsg->size);
651              } else {
652                  rc = write(conn_fd, p + VHOST_USER_HDR_SIZE, vmsg->size);
653              }
654          } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
655      }
656  
657      if (rc <= 0) {
658          vu_panic(dev, "Error while writing: %s", strerror(errno));
659          return false;
660      }
661  
662      return true;
663  }
664  
665  static bool
666  vu_send_reply(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
667  {
668      /* Set the version in the flags when sending the reply */
669      vmsg->flags &= ~VHOST_USER_VERSION_MASK;
670      vmsg->flags |= VHOST_USER_VERSION;
671      vmsg->flags |= VHOST_USER_REPLY_MASK;
672  
673      return vu_message_write(dev, conn_fd, vmsg);
674  }
675  
676  /*
677   * Processes a reply on the backend channel.
678   * Entered with backend_mutex held and releases it before exit.
679   * Returns true on success.
680   */
681  static bool
682  vu_process_message_reply(VuDev *dev, const VhostUserMsg *vmsg)
683  {
684      VhostUserMsg msg_reply;
685      bool result = false;
686  
687      if ((vmsg->flags & VHOST_USER_NEED_REPLY_MASK) == 0) {
688          result = true;
689          goto out;
690      }
691  
692      if (!vu_message_read_default(dev, dev->backend_fd, &msg_reply)) {
693          goto out;
694      }
695  
696      if (msg_reply.request != vmsg->request) {
697          DPRINT("Received unexpected msg type. Expected %d received %d",
698                 vmsg->request, msg_reply.request);
699          goto out;
700      }
701  
702      result = msg_reply.payload.u64 == 0;
703  
704  out:
705      pthread_mutex_unlock(&dev->backend_mutex);
706      return result;
707  }
708  
709  /* Kick the log_call_fd if required. */
710  static void
711  vu_log_kick(VuDev *dev)
712  {
713      if (dev->log_call_fd != -1) {
714          DPRINT("Kicking the QEMU's log...\n");
715          if (eventfd_write(dev->log_call_fd, 1) < 0) {
716              vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
717          }
718      }
719  }
720  
721  static void
722  vu_log_page(uint8_t *log_table, uint64_t page)
723  {
724      DPRINT("Logged dirty guest page: %"PRId64"\n", page);
725      qatomic_or(&log_table[page / 8], 1 << (page % 8));
726  }
727  
728  static void
729  vu_log_write(VuDev *dev, uint64_t address, uint64_t length)
730  {
731      uint64_t page;
732  
733      if (!(dev->features & (1ULL << VHOST_F_LOG_ALL)) ||
734          !dev->log_table || !length) {
735          return;
736      }
737  
738      assert(dev->log_size > ((address + length - 1) / VHOST_LOG_PAGE / 8));
739  
740      page = address / VHOST_LOG_PAGE;
741      while (page * VHOST_LOG_PAGE < address + length) {
742          vu_log_page(dev->log_table, page);
743          page += 1;
744      }
745  
746      vu_log_kick(dev);
747  }
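/*
 * Example (illustrative, assuming VHOST_LOG_PAGE is 4 KiB): a write of
 * 0x2000 bytes at guest physical address 0x3000 covers pages 3 and 4, so
 * bits 3 and 4 of log_table[0] are set via vu_log_page() above.
 */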
748  
749  static void
750  vu_kick_cb(VuDev *dev, int condition, void *data)
751  {
752      int index = (intptr_t)data;
753      VuVirtq *vq = &dev->vq[index];
754      int sock = vq->kick_fd;
755      eventfd_t kick_data;
756      ssize_t rc;
757  
758      rc = eventfd_read(sock, &kick_data);
759      if (rc == -1) {
760          vu_panic(dev, "kick eventfd_read(): %s", strerror(errno));
761          dev->remove_watch(dev, dev->vq[index].kick_fd);
762      } else {
763          DPRINT("Got kick_data: %016"PRIx64" handler:%p idx:%d\n",
764                 kick_data, vq->handler, index);
765          if (vq->handler) {
766              vq->handler(dev, index);
767          }
768      }
769  }
770  
771  static bool
772  vu_get_features_exec(VuDev *dev, VhostUserMsg *vmsg)
773  {
774      vmsg->payload.u64 =
775          /*
776           * The following VIRTIO feature bits are supported by our virtqueue
777           * implementation:
778           */
779          1ULL << VIRTIO_F_NOTIFY_ON_EMPTY |
780          1ULL << VIRTIO_RING_F_INDIRECT_DESC |
781          1ULL << VIRTIO_RING_F_EVENT_IDX |
782          1ULL << VIRTIO_F_VERSION_1 |
783  
784          /* vhost-user feature bits */
785          1ULL << VHOST_F_LOG_ALL |
786          1ULL << VHOST_USER_F_PROTOCOL_FEATURES;
787  
788      if (dev->iface->get_features) {
789          vmsg->payload.u64 |= dev->iface->get_features(dev);
790      }
791  
792      vmsg->size = sizeof(vmsg->payload.u64);
793      vmsg->fd_num = 0;
794  
795      DPRINT("Sending back to guest u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
796  
797      return true;
798  }
799  
800  static void
801  vu_set_enable_all_rings(VuDev *dev, bool enabled)
802  {
803      uint16_t i;
804  
805      for (i = 0; i < dev->max_queues; i++) {
806          dev->vq[i].enable = enabled;
807      }
808  }
809  
810  static bool
811  vu_set_features_exec(VuDev *dev, VhostUserMsg *vmsg)
812  {
813      DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
814  
815      dev->features = vmsg->payload.u64;
816      if (!vu_has_feature(dev, VIRTIO_F_VERSION_1)) {
817          /*
818           * We only support devices conforming to VIRTIO 1.0 or
819           * later
820           */
821          vu_panic(dev, "virtio legacy devices aren't supported by libvhost-user");
822          return false;
823      }
824  
825      if (!(dev->features & VHOST_USER_F_PROTOCOL_FEATURES)) {
826          vu_set_enable_all_rings(dev, true);
827      }
828  
829      if (dev->iface->set_features) {
830          dev->iface->set_features(dev, dev->features);
831      }
832  
833      return false;
834  }
835  
836  static bool
837  vu_set_owner_exec(VuDev *dev, VhostUserMsg *vmsg)
838  {
839      return false;
840  }
841  
842  static void
843  vu_close_log(VuDev *dev)
844  {
845      if (dev->log_table) {
846          if (munmap(dev->log_table, dev->log_size) != 0) {
847              perror("close log munmap() error");
848          }
849  
850          dev->log_table = NULL;
851      }
852      if (dev->log_call_fd != -1) {
853          close(dev->log_call_fd);
854          dev->log_call_fd = -1;
855      }
856  }
857  
858  static bool
859  vu_reset_device_exec(VuDev *dev, VhostUserMsg *vmsg)
860  {
861      vu_set_enable_all_rings(dev, false);
862  
863      return false;
864  }
865  
866  static bool
867  generate_faults(VuDev *dev) {
868      unsigned int i;
869      for (i = 0; i < dev->nregions; i++) {
870  #ifdef UFFDIO_REGISTER
871          VuDevRegion *dev_region = &dev->regions[i];
872          int ret;
873          struct uffdio_register reg_struct;
874  
875          /*
876           * We should already have an open ufd. Register each memory
877           * range with it.
878           * Discard any mapping we have here; note I can't use MADV_REMOVE
879           * or fallocate to make the hole since I don't want to lose
880           * data that's already arrived in the shared process.
881           * TODO: How to do hugepage
882           */
883          ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
884                        dev_region->size + dev_region->mmap_offset,
885                        MADV_DONTNEED);
886          if (ret) {
887              fprintf(stderr,
888                      "%s: Failed to madvise(DONTNEED) region %d: %s\n",
889                      __func__, i, strerror(errno));
890          }
891          /*
892           * Turn off transparent hugepages so we don't lose wakeups
893           * in neighbouring pages.
894           * TODO: Turn this back on later.
895           */
896          ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
897                        dev_region->size + dev_region->mmap_offset,
898                        MADV_NOHUGEPAGE);
899          if (ret) {
900              /*
901               * Note: This can happen legally on kernels that are configured
902               * without madvise'able hugepages
903               */
904              fprintf(stderr,
905                      "%s: Failed to madvise(NOHUGEPAGE) region %d: %s\n",
906                      __func__, i, strerror(errno));
907          }
908  
909          reg_struct.range.start = (uintptr_t)dev_region->mmap_addr;
910          reg_struct.range.len = dev_region->size + dev_region->mmap_offset;
911          reg_struct.mode = UFFDIO_REGISTER_MODE_MISSING;
912  
913          if (ioctl(dev->postcopy_ufd, UFFDIO_REGISTER, &reg_struct)) {
914              vu_panic(dev, "%s: Failed to userfault region %d "
915                            "@%" PRIx64 " + size:%" PRIx64 " offset: %" PRIx64
916                            ": (ufd=%d)%s\n",
917                       __func__, i,
918                       dev_region->mmap_addr,
919                       dev_region->size, dev_region->mmap_offset,
920                       dev->postcopy_ufd, strerror(errno));
921              return false;
922          }
923          if (!(reg_struct.ioctls & (1ULL << _UFFDIO_COPY))) {
924              vu_panic(dev, "%s Region (%d) doesn't support COPY",
925                       __func__, i);
926              return false;
927          }
928          DPRINT("%s: region %d: Registered userfault for %"
929                 PRIx64 " + %" PRIx64 "\n", __func__, i,
930                 (uint64_t)reg_struct.range.start,
931                 (uint64_t)reg_struct.range.len);
932          /* Now it's registered we can let the client at it */
933          if (mprotect((void *)(uintptr_t)dev_region->mmap_addr,
934                       dev_region->size + dev_region->mmap_offset,
935                       PROT_READ | PROT_WRITE)) {
936              vu_panic(dev, "failed to mprotect region %d for postcopy (%s)",
937                       i, strerror(errno));
938              return false;
939          }
940          /* TODO: Stash 'zero' support flags somewhere */
941  #endif
942      }
943  
944      return true;
945  }
946  
947  static bool
948  vu_add_mem_reg(VuDev *dev, VhostUserMsg *vmsg) {
949      VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
950  
951      if (vmsg->fd_num != 1) {
952          vmsg_close_fds(vmsg);
953          vu_panic(dev, "VHOST_USER_ADD_MEM_REG received %d fds - only 1 fd "
954                        "should be sent for this message type", vmsg->fd_num);
955          return false;
956      }
957  
958      if (vmsg->size < VHOST_USER_MEM_REG_SIZE) {
959          close(vmsg->fds[0]);
960          vu_panic(dev, "VHOST_USER_ADD_MEM_REG requires a message size of at "
961                        "least %zu bytes and only %d bytes were received",
962                        VHOST_USER_MEM_REG_SIZE, vmsg->size);
963          return false;
964      }
965  
966      if (dev->nregions == VHOST_USER_MAX_RAM_SLOTS) {
967          close(vmsg->fds[0]);
968          vu_panic(dev, "failing attempt to hot add memory via "
969                        "VHOST_USER_ADD_MEM_REG message because the backend has "
970                        "no free ram slots available");
971          return false;
972      }
973  
974      /*
975       * If we are in postcopy mode and we receive a u64 payload with a 0 value
976       * we know all the postcopy client bases have been received, and we
977       * should start generating faults.
978       */
979      if (dev->postcopy_listening &&
980          vmsg->size == sizeof(vmsg->payload.u64) &&
981          vmsg->payload.u64 == 0) {
982          (void)generate_faults(dev);
983          return false;
984      }
985  
986      _vu_add_mem_reg(dev, msg_region, vmsg->fds[0]);
987      close(vmsg->fds[0]);
988  
989      if (dev->postcopy_listening) {
990          /* Send the message back to qemu with the addresses filled in. */
991          vmsg->fd_num = 0;
992          DPRINT("Successfully added new region in postcopy\n");
993          return true;
994      }
995      DPRINT("Successfully added new region\n");
996      return false;
997  }
998  
999  static inline bool reg_equal(VuDevRegion *vudev_reg,
1000                               VhostUserMemoryRegion *msg_reg)
1001  {
1002      if (vudev_reg->gpa == msg_reg->guest_phys_addr &&
1003          vudev_reg->qva == msg_reg->userspace_addr &&
1004          vudev_reg->size == msg_reg->memory_size) {
1005          return true;
1006      }
1007  
1008      return false;
1009  }
1010  
1011  static bool
1012  vu_rem_mem_reg(VuDev *dev, VhostUserMsg *vmsg) {
1013      VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
1014      unsigned int idx;
1015      VuDevRegion *r;
1016  
1017      if (vmsg->fd_num > 1) {
1018          vmsg_close_fds(vmsg);
1019          vu_panic(dev, "VHOST_USER_REM_MEM_REG received %d fds - at most 1 fd "
1020                        "should be sent for this message type", vmsg->fd_num);
1021          return false;
1022      }
1023  
1024      if (vmsg->size < VHOST_USER_MEM_REG_SIZE) {
1025          vmsg_close_fds(vmsg);
1026          vu_panic(dev, "VHOST_USER_REM_MEM_REG requires a message size of at "
1027                        "least %zu bytes and only %d bytes were received",
1028                        VHOST_USER_MEM_REG_SIZE, vmsg->size);
1029          return false;
1030      }
1031  
1032      DPRINT("Removing region:\n");
1033      DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
1034             msg_region->guest_phys_addr);
1035      DPRINT("    memory_size:     0x%016"PRIx64"\n",
1036             msg_region->memory_size);
1037      DPRINT("    userspace_addr   0x%016"PRIx64"\n",
1038             msg_region->userspace_addr);
1039      DPRINT("    mmap_offset      0x%016"PRIx64"\n",
1040             msg_region->mmap_offset);
1041  
1042      r = vu_gpa_to_mem_region(dev, msg_region->guest_phys_addr);
1043      if (!r || !reg_equal(r, msg_region)) {
1044          vmsg_close_fds(vmsg);
1045          vu_panic(dev, "Specified region not found\n");
1046          return false;
1047      }
1048  
1049      /*
1050       * There might be valid cases where we temporarily remove memory regions
1051       * to re-add them later, or remove memory regions and don't use the rings
1052       * anymore before we set the ring addresses and restart the device.
1053       *
1054       * Unmap all affected rings, remapping them on demand later. This should
1055       * be a corner case.
1056       */
1057      unmap_rings(dev, r);
1058  
1059      munmap((void *)(uintptr_t)r->mmap_addr, r->size + r->mmap_offset);
1060  
1061      idx = r - dev->regions;
1062      assert(idx < dev->nregions);
1063      /* Shift all affected entries by 1 to close the hole. */
1064      memmove(r, r + 1, sizeof(VuDevRegion) * (dev->nregions - idx - 1));
1065      DPRINT("Successfully removed a region\n");
1066      dev->nregions--;
1067  
1068      vmsg_close_fds(vmsg);
1069  
1070      return false;
1071  }
1072  
1073  static bool
1074  vu_get_shared_object(VuDev *dev, VhostUserMsg *vmsg)
1075  {
1076      int fd_num = 0;
1077      int dmabuf_fd = -1;
1078      if (dev->iface->get_shared_object) {
1079          dmabuf_fd = dev->iface->get_shared_object(
1080              dev, &vmsg->payload.object.uuid[0]);
1081      }
1082      if (dmabuf_fd != -1) {
1083          DPRINT("dmabuf_fd found for requested UUID\n");
1084          vmsg->fds[fd_num++] = dmabuf_fd;
1085      }
1086      vmsg->fd_num = fd_num;
1087  
1088      return true;
1089  }
1090  
1091  static bool
1092  vu_set_mem_table_exec(VuDev *dev, VhostUserMsg *vmsg)
1093  {
1094      VhostUserMemory m = vmsg->payload.memory, *memory = &m;
1095      unsigned int i;
1096  
1097      vu_remove_all_mem_regs(dev);
1098  
1099      DPRINT("Nregions: %u\n", memory->nregions);
1100      for (i = 0; i < memory->nregions; i++) {
1101          _vu_add_mem_reg(dev, &memory->regions[i], vmsg->fds[i]);
1102          close(vmsg->fds[i]);
1103      }
1104  
1105      if (dev->postcopy_listening) {
1106          /* Send the message back to qemu with the addresses filled in */
1107          vmsg->fd_num = 0;
1108          if (!vu_send_reply(dev, dev->sock, vmsg)) {
1109              vu_panic(dev, "failed to respond to set-mem-table for postcopy");
1110              return false;
1111          }
1112  
1113          /*
1114           * Wait for QEMU to confirm that it's registered the handler for the
1115           * faults.
1116           */
1117          if (!dev->read_msg(dev, dev->sock, vmsg) ||
1118              vmsg->size != sizeof(vmsg->payload.u64) ||
1119              vmsg->payload.u64 != 0) {
1120              vu_panic(dev, "failed to receive valid ack for postcopy set-mem-table");
1121              return false;
1122          }
1123  
1124          /* OK, now we can go and register the memory and generate faults */
1125          (void)generate_faults(dev);
1126          return false;
1127      }
1128  
1129      for (i = 0; i < dev->max_queues; i++) {
1130          if (dev->vq[i].vring.desc) {
1131              if (map_ring(dev, &dev->vq[i])) {
1132                  vu_panic(dev, "remapping queue %d during setmemtable", i);
1133              }
1134          }
1135      }
1136  
1137      return false;
1138  }
1139  
1140  static bool
1141  vu_set_log_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1142  {
1143      int fd;
1144      uint64_t log_mmap_size, log_mmap_offset;
1145      void *rc;
1146  
1147      if (vmsg->fd_num != 1 ||
1148          vmsg->size != sizeof(vmsg->payload.log)) {
1149          vu_panic(dev, "Invalid log_base message");
1150          return true;
1151      }
1152  
1153      fd = vmsg->fds[0];
1154      log_mmap_offset = vmsg->payload.log.mmap_offset;
1155      log_mmap_size = vmsg->payload.log.mmap_size;
1156      DPRINT("Log mmap_offset: %"PRId64"\n", log_mmap_offset);
1157      DPRINT("Log mmap_size:   %"PRId64"\n", log_mmap_size);
1158  
1159      rc = mmap(0, log_mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd,
1160                log_mmap_offset);
1161      close(fd);
1162      if (rc == MAP_FAILED) {
1163          perror("log mmap error");
1164      }
1165  
1166      if (dev->log_table) {
1167          munmap(dev->log_table, dev->log_size);
1168      }
1169      dev->log_table = rc;
1170      dev->log_size = log_mmap_size;
1171  
1172      vmsg->size = sizeof(vmsg->payload.u64);
1173      vmsg->fd_num = 0;
1174  
1175      return true;
1176  }
1177  
1178  static bool
1179  vu_set_log_fd_exec(VuDev *dev, VhostUserMsg *vmsg)
1180  {
1181      if (vmsg->fd_num != 1) {
1182          vu_panic(dev, "Invalid log_fd message");
1183          return false;
1184      }
1185  
1186      if (dev->log_call_fd != -1) {
1187          close(dev->log_call_fd);
1188      }
1189      dev->log_call_fd = vmsg->fds[0];
1190      DPRINT("Got log_call_fd: %d\n", vmsg->fds[0]);
1191  
1192      return false;
1193  }
1194  
1195  static bool
1196  vu_set_vring_num_exec(VuDev *dev, VhostUserMsg *vmsg)
1197  {
1198      unsigned int index = vmsg->payload.state.index;
1199      unsigned int num = vmsg->payload.state.num;
1200  
1201      DPRINT("State.index: %u\n", index);
1202      DPRINT("State.num:   %u\n", num);
1203      dev->vq[index].vring.num = num;
1204  
1205      return false;
1206  }
1207  
1208  static bool
1209  vu_set_vring_addr_exec(VuDev *dev, VhostUserMsg *vmsg)
1210  {
1211      struct vhost_vring_addr addr = vmsg->payload.addr, *vra = &addr;
1212      unsigned int index = vra->index;
1213      VuVirtq *vq = &dev->vq[index];
1214  
1215      DPRINT("vhost_vring_addr:\n");
1216      DPRINT("    index:  %d\n", vra->index);
1217      DPRINT("    flags:  %d\n", vra->flags);
1218      DPRINT("    desc_user_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->desc_user_addr);
1219      DPRINT("    used_user_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->used_user_addr);
1220      DPRINT("    avail_user_addr:  0x%016" PRIx64 "\n", (uint64_t)vra->avail_user_addr);
1221      DPRINT("    log_guest_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->log_guest_addr);
1222  
1223      vq->vra = *vra;
1224      vq->vring.flags = vra->flags;
1225      vq->vring.log_guest_addr = vra->log_guest_addr;
1226  
1227  
1228      if (map_ring(dev, vq)) {
1229          vu_panic(dev, "Invalid vring_addr message");
1230          return false;
1231      }
1232  
1233      vq->used_idx = le16toh(vq->vring.used->idx);
1234  
1235      if (vq->last_avail_idx != vq->used_idx) {
1236          bool resume = dev->iface->queue_is_processed_in_order &&
1237              dev->iface->queue_is_processed_in_order(dev, index);
1238  
1239          DPRINT("Last avail index != used index: %u != %u%s\n",
1240                 vq->last_avail_idx, vq->used_idx,
1241                 resume ? ", resuming" : "");
1242  
1243          if (resume) {
1244              vq->shadow_avail_idx = vq->last_avail_idx = vq->used_idx;
1245          }
1246      }
1247  
1248      return false;
1249  }
1250  
1251  static bool
1252  vu_set_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1253  {
1254      unsigned int index = vmsg->payload.state.index;
1255      unsigned int num = vmsg->payload.state.num;
1256  
1257      DPRINT("State.index: %u\n", index);
1258      DPRINT("State.num:   %u\n", num);
1259      dev->vq[index].shadow_avail_idx = dev->vq[index].last_avail_idx = num;
1260  
1261      return false;
1262  }
1263  
1264  static bool
1265  vu_get_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1266  {
1267      unsigned int index = vmsg->payload.state.index;
1268  
1269      DPRINT("State.index: %u\n", index);
1270      vmsg->payload.state.num = dev->vq[index].last_avail_idx;
1271      vmsg->size = sizeof(vmsg->payload.state);
1272  
1273      dev->vq[index].started = false;
1274      if (dev->iface->queue_set_started) {
1275          dev->iface->queue_set_started(dev, index, false);
1276      }
1277  
1278      if (dev->vq[index].call_fd != -1) {
1279          close(dev->vq[index].call_fd);
1280          dev->vq[index].call_fd = -1;
1281      }
1282      if (dev->vq[index].kick_fd != -1) {
1283          dev->remove_watch(dev, dev->vq[index].kick_fd);
1284          close(dev->vq[index].kick_fd);
1285          dev->vq[index].kick_fd = -1;
1286      }
1287  
1288      return true;
1289  }
1290  
1291  static bool
1292  vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
1293  {
1294      int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1295      bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1296  
1297      if (index >= dev->max_queues) {
1298          vmsg_close_fds(vmsg);
1299          vu_panic(dev, "Invalid queue index: %u", index);
1300          return false;
1301      }
1302  
1303      if (nofd) {
1304          vmsg_close_fds(vmsg);
1305          return true;
1306      }
1307  
1308      if (vmsg->fd_num != 1) {
1309          vmsg_close_fds(vmsg);
1310          vu_panic(dev, "Invalid fds in request: %d", vmsg->request);
1311          return false;
1312      }
1313  
1314      return true;
1315  }
1316  
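/*
 * qsort() comparator for resubmit entries: descriptors with larger
 * submission counters (allowing for counter wraparound) sort towards the
 * front, so resubmit_list[0] ends up holding the most recently submitted
 * descriptor and vu_check_queue_inflights() can resume counting from
 * resubmit_list[0].counter + 1.
 */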
1317  static int
1318  inflight_desc_compare(const void *a, const void *b)
1319  {
1320      VuVirtqInflightDesc *desc0 = (VuVirtqInflightDesc *)a,
1321                          *desc1 = (VuVirtqInflightDesc *)b;
1322  
1323      if (desc1->counter > desc0->counter &&
1324          (desc1->counter - desc0->counter) < VIRTQUEUE_MAX_SIZE * 2) {
1325          return 1;
1326      }
1327  
1328      return -1;
1329  }
1330  
1331  static int
1332  vu_check_queue_inflights(VuDev *dev, VuVirtq *vq)
1333  {
1334      int i = 0;
1335  
1336      if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
1337          return 0;
1338      }
1339  
1340      if (unlikely(!vq->inflight)) {
1341          return -1;
1342      }
1343  
1344      if (unlikely(!vq->inflight->version)) {
1345          /* initialize the buffer */
1346          vq->inflight->version = INFLIGHT_VERSION;
1347          return 0;
1348      }
1349  
1350      vq->used_idx = le16toh(vq->vring.used->idx);
1351      vq->resubmit_num = 0;
1352      vq->resubmit_list = NULL;
1353      vq->counter = 0;
1354  
1355      if (unlikely(vq->inflight->used_idx != vq->used_idx)) {
1356          vq->inflight->desc[vq->inflight->last_batch_head].inflight = 0;
1357  
1358          barrier();
1359  
1360          vq->inflight->used_idx = vq->used_idx;
1361      }
1362  
1363      for (i = 0; i < vq->inflight->desc_num; i++) {
1364          if (vq->inflight->desc[i].inflight == 1) {
1365              vq->inuse++;
1366          }
1367      }
1368  
1369      vq->shadow_avail_idx = vq->last_avail_idx = vq->inuse + vq->used_idx;
1370  
1371      if (vq->inuse) {
1372          vq->resubmit_list = calloc(vq->inuse, sizeof(VuVirtqInflightDesc));
1373          if (!vq->resubmit_list) {
1374              return -1;
1375          }
1376  
1377          for (i = 0; i < vq->inflight->desc_num; i++) {
1378              if (vq->inflight->desc[i].inflight) {
1379                  vq->resubmit_list[vq->resubmit_num].index = i;
1380                  vq->resubmit_list[vq->resubmit_num].counter =
1381                                          vq->inflight->desc[i].counter;
1382                  vq->resubmit_num++;
1383              }
1384          }
1385  
1386          if (vq->resubmit_num > 1) {
1387              qsort(vq->resubmit_list, vq->resubmit_num,
1388                    sizeof(VuVirtqInflightDesc), inflight_desc_compare);
1389          }
1390          vq->counter = vq->resubmit_list[0].counter + 1;
1391      }
1392  
1393      /* in case of I/O hang after reconnecting */
1394      if (eventfd_write(vq->kick_fd, 1)) {
1395          return -1;
1396      }
1397  
1398      return 0;
1399  }
1400  
1401  static bool
1402  vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
1403  {
1404      int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1405      bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1406  
1407      DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1408  
1409      if (!vu_check_queue_msg_file(dev, vmsg)) {
1410          return false;
1411      }
1412  
1413      if (dev->vq[index].kick_fd != -1) {
1414          dev->remove_watch(dev, dev->vq[index].kick_fd);
1415          close(dev->vq[index].kick_fd);
1416          dev->vq[index].kick_fd = -1;
1417      }
1418  
1419      dev->vq[index].kick_fd = nofd ? -1 : vmsg->fds[0];
1420      DPRINT("Got kick_fd: %d for vq: %d\n", dev->vq[index].kick_fd, index);
1421  
1422      dev->vq[index].started = true;
1423      if (dev->iface->queue_set_started) {
1424          dev->iface->queue_set_started(dev, index, true);
1425      }
1426  
1427      if (dev->vq[index].kick_fd != -1 && dev->vq[index].handler) {
1428          dev->set_watch(dev, dev->vq[index].kick_fd, VU_WATCH_IN,
1429                         vu_kick_cb, (void *)(long)index);
1430  
1431          DPRINT("Waiting for kicks on fd: %d for vq: %d\n",
1432                 dev->vq[index].kick_fd, index);
1433      }
1434  
1435      if (vu_check_queue_inflights(dev, &dev->vq[index])) {
1436          vu_panic(dev, "Failed to check inflights for vq: %d\n", index);
1437      }
1438  
1439      return false;
1440  }
1441  
1442  void vu_set_queue_handler(VuDev *dev, VuVirtq *vq,
1443                            vu_queue_handler_cb handler)
1444  {
1445      int qidx = vq - dev->vq;
1446  
1447      vq->handler = handler;
1448      if (vq->kick_fd >= 0) {
1449          if (handler) {
1450              dev->set_watch(dev, vq->kick_fd, VU_WATCH_IN,
1451                             vu_kick_cb, (void *)(long)qidx);
1452          } else {
1453              dev->remove_watch(dev, vq->kick_fd);
1454          }
1455      }
1456  }
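/*
 * Usage sketch (illustrative): a device implementation typically installs
 * its per-queue handler from its queue_set_started() callback, e.g.
 *
 *     vu_set_queue_handler(dev, vu_get_queue(dev, qidx),
 *                          started ? my_queue_handler : NULL);
 *
 * where my_queue_handler is the device's own vu_queue_handler_cb.
 */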
1457  
1458  bool vu_set_queue_host_notifier(VuDev *dev, VuVirtq *vq, int fd,
1459                                  int size, int offset)
1460  {
1461      int qidx = vq - dev->vq;
1462      int fd_num = 0;
1463      VhostUserMsg vmsg = {
1464          .request = VHOST_USER_BACKEND_VRING_HOST_NOTIFIER_MSG,
1465          .flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
1466          .size = sizeof(vmsg.payload.area),
1467          .payload.area = {
1468              .u64 = qidx & VHOST_USER_VRING_IDX_MASK,
1469              .size = size,
1470              .offset = offset,
1471          },
1472      };
1473  
1474      if (fd == -1) {
1475          vmsg.payload.area.u64 |= VHOST_USER_VRING_NOFD_MASK;
1476      } else {
1477          vmsg.fds[fd_num++] = fd;
1478      }
1479  
1480      vmsg.fd_num = fd_num;
1481  
1482      if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_BACKEND_SEND_FD)) {
1483          return false;
1484      }
1485  
1486      pthread_mutex_lock(&dev->backend_mutex);
1487      if (!vu_message_write(dev, dev->backend_fd, &vmsg)) {
1488          pthread_mutex_unlock(&dev->backend_mutex);
1489          return false;
1490      }
1491  
1492      /* Also unlocks the backend_mutex */
1493      return vu_process_message_reply(dev, &vmsg);
1494  }
1495  
1496  bool
1497  vu_lookup_shared_object(VuDev *dev, unsigned char uuid[UUID_LEN],
1498                          int *dmabuf_fd)
1499  {
1500      bool result = false;
1501      VhostUserMsg msg_reply;
1502      VhostUserMsg msg = {
1503          .request = VHOST_USER_BACKEND_SHARED_OBJECT_LOOKUP,
1504          .size = sizeof(msg.payload.object),
1505          .flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
1506      };
1507  
1508      memcpy(msg.payload.object.uuid, uuid, sizeof(uuid[0]) * UUID_LEN);
1509  
1510      if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SHARED_OBJECT)) {
1511          return false;
1512      }
1513  
1514      pthread_mutex_lock(&dev->backend_mutex);
1515      if (!vu_message_write(dev, dev->backend_fd, &msg)) {
1516          goto out;
1517      }
1518  
1519      if (!vu_message_read_default(dev, dev->backend_fd, &msg_reply)) {
1520          goto out;
1521      }
1522  
1523      if (msg_reply.request != msg.request) {
1524          DPRINT("Received unexpected msg type. Expected %d, received %d",
1525                 msg.request, msg_reply.request);
1526          goto out;
1527      }
1528  
1529      if (msg_reply.fd_num != 1) {
1530          DPRINT("Received unexpected number of fds. Expected 1, received %d",
1531                 msg_reply.fd_num);
1532          goto out;
1533      }
1534  
1535      *dmabuf_fd = msg_reply.fds[0];
1536      result = *dmabuf_fd > 0 && msg_reply.payload.u64 == 0;
1537  out:
1538      pthread_mutex_unlock(&dev->backend_mutex);
1539  
1540      return result;
1541  }
1542  
1543  static bool
1544  vu_send_message(VuDev *dev, VhostUserMsg *vmsg)
1545  {
1546      bool result = false;
1547      pthread_mutex_lock(&dev->backend_mutex);
1548      if (!vu_message_write(dev, dev->backend_fd, vmsg)) {
1549          goto out;
1550      }
1551  
1552      result = true;
1553  out:
1554      pthread_mutex_unlock(&dev->backend_mutex);
1555  
1556      return result;
1557  }
1558  
1559  bool
1560  vu_add_shared_object(VuDev *dev, unsigned char uuid[UUID_LEN])
1561  {
1562      VhostUserMsg msg = {
1563          .request = VHOST_USER_BACKEND_SHARED_OBJECT_ADD,
1564          .size = sizeof(msg.payload.object),
1565          .flags = VHOST_USER_VERSION,
1566      };
1567  
1568      memcpy(msg.payload.object.uuid, uuid, sizeof(uuid[0]) * UUID_LEN);
1569  
1570      if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SHARED_OBJECT)) {
1571          return false;
1572      }
1573  
1574      return vu_send_message(dev, &msg);
1575  }
1576  
1577  bool
1578  vu_rm_shared_object(VuDev *dev, unsigned char uuid[UUID_LEN])
1579  {
1580      VhostUserMsg msg = {
1581          .request = VHOST_USER_BACKEND_SHARED_OBJECT_REMOVE,
1582          .size = sizeof(msg.payload.object),
1583          .flags = VHOST_USER_VERSION,
1584      };
1585  
1586      memcpy(msg.payload.object.uuid, uuid, sizeof(uuid[0]) * UUID_LEN);
1587  
1588      if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SHARED_OBJECT)) {
1589          return false;
1590      }
1591  
1592      return vu_send_message(dev, &msg);
1593  }
1594  
1595  static bool
1596  vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
1597  {
1598      int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1599      bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1600  
1601      DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1602  
1603      if (!vu_check_queue_msg_file(dev, vmsg)) {
1604          return false;
1605      }
1606  
1607      if (dev->vq[index].call_fd != -1) {
1608          close(dev->vq[index].call_fd);
1609          dev->vq[index].call_fd = -1;
1610      }
1611  
1612      dev->vq[index].call_fd = nofd ? -1 : vmsg->fds[0];
1613  
1614      /* in case of I/O hang after reconnecting */
1615      if (dev->vq[index].call_fd != -1 && eventfd_write(vmsg->fds[0], 1)) {
1616          return -1;
1617      }
1618  
1619      DPRINT("Got call_fd: %d for vq: %d\n", dev->vq[index].call_fd, index);
1620  
1621      return false;
1622  }
1623  
1624  static bool
1625  vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
1626  {
1627      int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1628      bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1629  
1630      DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1631  
1632      if (!vu_check_queue_msg_file(dev, vmsg)) {
1633          return false;
1634      }
1635  
1636      if (dev->vq[index].err_fd != -1) {
1637          close(dev->vq[index].err_fd);
1638          dev->vq[index].err_fd = -1;
1639      }
1640  
1641      dev->vq[index].err_fd = nofd ? -1 : vmsg->fds[0];
1642  
1643      return false;
1644  }
1645  
1646  static bool
1647  vu_get_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
1648  {
1649      /*
1650       * Note that we support, but intentionally do not set,
1651       * VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS. This means that
1652       * a device implementation can return it in its callback
1653       * (get_protocol_features) if it wants to use this for
1654       * simulation, but it is otherwise not desirable (and may not even
1655       * be implemented by the frontend).
1656       */
1657      uint64_t features = 1ULL << VHOST_USER_PROTOCOL_F_MQ |
1658                          1ULL << VHOST_USER_PROTOCOL_F_LOG_SHMFD |
1659                          1ULL << VHOST_USER_PROTOCOL_F_BACKEND_REQ |
1660                          1ULL << VHOST_USER_PROTOCOL_F_HOST_NOTIFIER |
1661                          1ULL << VHOST_USER_PROTOCOL_F_BACKEND_SEND_FD |
1662                          1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK |
1663                          1ULL << VHOST_USER_PROTOCOL_F_CONFIGURE_MEM_SLOTS;
1664  
1665      if (have_userfault()) {
1666          features |= 1ULL << VHOST_USER_PROTOCOL_F_PAGEFAULT;
1667      }
1668  
1669      if (dev->iface->get_config && dev->iface->set_config) {
1670          features |= 1ULL << VHOST_USER_PROTOCOL_F_CONFIG;
1671      }
1672  
1673      if (dev->iface->get_protocol_features) {
1674          features |= dev->iface->get_protocol_features(dev);
1675      }
1676  
1677  #ifndef MFD_ALLOW_SEALING
1678      /*
1679       * If MFD_ALLOW_SEALING is not defined, we are not able to handle
1680       * VHOST_USER_GET_INFLIGHT_FD messages, since we can't create a memfd.
1681       * Those messages are used only if VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD
1682       * is negotiated. A device implementation can enable it, so let's mask
1683       * it to avoid a runtime panic.
1684       */
1685      features &= ~(1ULL << VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD);
1686  #endif
1687  
1688      vmsg_set_reply_u64(vmsg, features);
1689      return true;
1690  }
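
/*
 * Illustrative sketch of the kind of callback a device implementation might
 * plug into its VuDevIface as get_protocol_features, adding bits on top of
 * the defaults computed above.  The function name is an assumption; only the
 * callback shape is taken from the call site above.
 *
 *     static uint64_t
 *     example_get_protocol_features(VuDev *dev)
 *     {
 *         // e.g. a simulation/testing backend opting into in-band
 *         // notifications, as described in the comment above
 *         return 1ULL << VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS;
 *     }
 */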
1691  
1692  static bool
1693  vu_set_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
1694  {
1695      uint64_t features = vmsg->payload.u64;
1696  
1697      DPRINT("u64: 0x%016"PRIx64"\n", features);
1698  
1699      dev->protocol_features = vmsg->payload.u64;
1700  
1701      if (vu_has_protocol_feature(dev,
1702                                  VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
1703          (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_BACKEND_REQ) ||
1704           !vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_REPLY_ACK))) {
1705          /*
1706           * The use case for using messages for kick/call is simulation, to make
1707           * the kick and call synchronous. To actually get that behaviour, both
1708           * of the other features are required.
1709           * Theoretically, one could use only kick messages, or use them
1710           * without F_REPLY_ACK, but too many (possibly pending) messages on
1711           * the socket will eventually cause the frontend to hang. To avoid
1712           * this where it is not desired, enforce a feature combination that
1713           * actually enables the simulation case.
1714           */
1715          vu_panic(dev,
1716                   "F_IN_BAND_NOTIFICATIONS requires F_BACKEND_REQ && F_REPLY_ACK");
1717          return false;
1718      }
1719  
1720      if (dev->iface->set_protocol_features) {
1721          dev->iface->set_protocol_features(dev, features);
1722      }
1723  
1724      return false;
1725  }
1726  
1727  static bool
1728  vu_get_queue_num_exec(VuDev *dev, VhostUserMsg *vmsg)
1729  {
1730      vmsg_set_reply_u64(vmsg, dev->max_queues);
1731      return true;
1732  }
1733  
1734  static bool
1735  vu_set_vring_enable_exec(VuDev *dev, VhostUserMsg *vmsg)
1736  {
1737      unsigned int index = vmsg->payload.state.index;
1738      unsigned int enable = vmsg->payload.state.num;
1739  
1740      DPRINT("State.index: %u\n", index);
1741      DPRINT("State.enable:   %u\n", enable);
1742  
1743      if (index >= dev->max_queues) {
1744          vu_panic(dev, "Invalid vring_enable index: %u", index);
1745          return false;
1746      }
1747  
1748      dev->vq[index].enable = enable;
1749      return false;
1750  }
1751  
1752  static bool
1753  vu_set_backend_req_fd(VuDev *dev, VhostUserMsg *vmsg)
1754  {
1755      if (vmsg->fd_num != 1) {
1756          vu_panic(dev, "Invalid backend_req_fd message (%d fd's)", vmsg->fd_num);
1757          return false;
1758      }
1759  
1760      if (dev->backend_fd != -1) {
1761          close(dev->backend_fd);
1762      }
1763      dev->backend_fd = vmsg->fds[0];
1764      DPRINT("Got backend_fd: %d\n", vmsg->fds[0]);
1765  
1766      return false;
1767  }
1768  
1769  static bool
1770  vu_get_config(VuDev *dev, VhostUserMsg *vmsg)
1771  {
1772      int ret = -1;
1773  
1774      if (dev->iface->get_config) {
1775          ret = dev->iface->get_config(dev, vmsg->payload.config.region,
1776                                       vmsg->payload.config.size);
1777      }
1778  
1779      if (ret) {
1780          /* resize to zero to indicate an error to frontend */
1781          vmsg->size = 0;
1782      }
1783  
1784      return true;
1785  }
1786  
1787  static bool
1788  vu_set_config(VuDev *dev, VhostUserMsg *vmsg)
1789  {
1790      int ret = -1;
1791  
1792      if (dev->iface->set_config) {
1793          ret = dev->iface->set_config(dev, vmsg->payload.config.region,
1794                                       vmsg->payload.config.offset,
1795                                       vmsg->payload.config.size,
1796                                       vmsg->payload.config.flags);
1797          if (ret) {
1798              vu_panic(dev, "Set virtio configuration space failed");
1799          }
1800      }
1801  
1802      return false;
1803  }
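
/*
 * Illustrative sketch of config-space callbacks matching the call sites in
 * vu_get_config()/vu_set_config() above.  The names and the 64-byte backing
 * store are assumptions; only the parameter shapes and the "0 on success,
 * non-zero on error" convention come from the code above.
 *
 *     static uint8_t example_config[64];
 *
 *     static int
 *     example_get_config(VuDev *dev, uint8_t *config, uint32_t len)
 *     {
 *         if (len > sizeof(example_config)) {
 *             return -1;              // reported to the frontend as size 0
 *         }
 *         memcpy(config, example_config, len);
 *         return 0;
 *     }
 *
 *     static int
 *     example_set_config(VuDev *dev, const uint8_t *data, uint32_t offset,
 *                        uint32_t size, uint32_t flags)
 *     {
 *         if (offset > sizeof(example_config) ||
 *             size > sizeof(example_config) - offset) {
 *             return -1;              // triggers the vu_panic() above
 *         }
 *         memcpy(example_config + offset, data, size);
 *         return 0;
 *     }
 */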
1804  
1805  static bool
1806  vu_set_postcopy_advise(VuDev *dev, VhostUserMsg *vmsg)
1807  {
1808  #ifdef UFFDIO_API
1809      struct uffdio_api api_struct;
1810  
1811      dev->postcopy_ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
1812      vmsg->size = 0;
1813  #else
1814      dev->postcopy_ufd = -1;
1815  #endif
1816  
1817      if (dev->postcopy_ufd == -1) {
1818          vu_panic(dev, "Userfaultfd not available: %s", strerror(errno));
1819          goto out;
1820      }
1821  
1822  #ifdef UFFDIO_API
1823      api_struct.api = UFFD_API;
1824      api_struct.features = 0;
1825      if (ioctl(dev->postcopy_ufd, UFFDIO_API, &api_struct)) {
1826          vu_panic(dev, "Failed UFFDIO_API: %s", strerror(errno));
1827          close(dev->postcopy_ufd);
1828          dev->postcopy_ufd = -1;
1829          goto out;
1830      }
1831      /* TODO: Stash feature flags somewhere */
1832  #endif
1833  
1834  out:
1835      /* Return the ufd to QEMU (the frontend) */
1836      vmsg->fd_num = 1;
1837      vmsg->fds[0] = dev->postcopy_ufd;
1838      return true; /* = send a reply */
1839  }
1840  
1841  static bool
1842  vu_set_postcopy_listen(VuDev *dev, VhostUserMsg *vmsg)
1843  {
1844      if (dev->nregions) {
1845          vu_panic(dev, "Regions already registered at postcopy-listen");
1846          vmsg_set_reply_u64(vmsg, -1);
1847          return true;
1848      }
1849      dev->postcopy_listening = true;
1850  
1851      vmsg_set_reply_u64(vmsg, 0);
1852      return true;
1853  }
1854  
1855  static bool
1856  vu_set_postcopy_end(VuDev *dev, VhostUserMsg *vmsg)
1857  {
1858      DPRINT("%s: Entry\n", __func__);
1859      dev->postcopy_listening = false;
1860      if (dev->postcopy_ufd > 0) {
1861          close(dev->postcopy_ufd);
1862          dev->postcopy_ufd = -1;
1863          DPRINT("%s: Done close\n", __func__);
1864      }
1865  
1866      vmsg_set_reply_u64(vmsg, 0);
1867      DPRINT("%s: exit\n", __func__);
1868      return true;
1869  }
1870  
1871  static inline uint64_t
1872  vu_inflight_queue_size(uint16_t queue_size)
1873  {
1874      return ALIGN_UP(sizeof(VuDescStateSplit) * queue_size +
1875             sizeof(uint16_t), INFLIGHT_ALIGNMENT);
1876  }
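
/*
 * Worked example of the computation above, assuming the 16-byte
 * VuDescStateSplit layout from libvhost-user.h.  For a 128-entry queue:
 *
 *     sizeof(VuDescStateSplit) * 128 + sizeof(uint16_t) = 16 * 128 + 2 = 2050
 *     ALIGN_UP(2050, INFLIGHT_ALIGNMENT)                = 2112 bytes per queue
 */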
1877  
1878  #ifdef MFD_ALLOW_SEALING
1879  static void *
1880  memfd_alloc(const char *name, size_t size, unsigned int flags, int *fd)
1881  {
1882      void *ptr;
1883      int ret;
1884  
1885      *fd = memfd_create(name, MFD_ALLOW_SEALING);
1886      if (*fd < 0) {
1887          return NULL;
1888      }
1889  
1890      ret = ftruncate(*fd, size);
1891      if (ret < 0) {
1892          close(*fd);
1893          return NULL;
1894      }
1895  
1896      ret = fcntl(*fd, F_ADD_SEALS, flags);
1897      if (ret < 0) {
1898          close(*fd);
1899          return NULL;
1900      }
1901  
1902      ptr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, *fd, 0);
1903      if (ptr == MAP_FAILED) {
1904          close(*fd);
1905          return NULL;
1906      }
1907  
1908      return ptr;
1909  }
1910  #endif
1911  
1912  static bool
1913  vu_get_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
1914  {
1915      int fd = -1;
1916      void *addr = NULL;
1917      uint64_t mmap_size;
1918      uint16_t num_queues, queue_size;
1919  
1920      if (vmsg->size != sizeof(vmsg->payload.inflight)) {
1921          vu_panic(dev, "Invalid get_inflight_fd message:%d", vmsg->size);
1922          vmsg->payload.inflight.mmap_size = 0;
1923          return true;
1924      }
1925  
1926      num_queues = vmsg->payload.inflight.num_queues;
1927      queue_size = vmsg->payload.inflight.queue_size;
1928  
1929      DPRINT("get_inflight_fd num_queues: %"PRId16"\n", num_queues);
1930      DPRINT("get_inflight_fd queue_size: %"PRId16"\n", queue_size);
1931  
1932      mmap_size = vu_inflight_queue_size(queue_size) * num_queues;
1933  
1934  #ifdef MFD_ALLOW_SEALING
1935      addr = memfd_alloc("vhost-inflight", mmap_size,
1936                         F_SEAL_GROW | F_SEAL_SHRINK | F_SEAL_SEAL,
1937                         &fd);
1938  #else
1939      vu_panic(dev, "Not implemented: memfd support is missing");
1940  #endif
1941  
1942      if (!addr) {
1943          vu_panic(dev, "Failed to alloc vhost inflight area");
1944          vmsg->payload.inflight.mmap_size = 0;
1945          return true;
1946      }
1947  
1948      memset(addr, 0, mmap_size);
1949  
1950      dev->inflight_info.addr = addr;
1951      dev->inflight_info.size = vmsg->payload.inflight.mmap_size = mmap_size;
1952      dev->inflight_info.fd = vmsg->fds[0] = fd;
1953      vmsg->fd_num = 1;
1954      vmsg->payload.inflight.mmap_offset = 0;
1955  
1956      DPRINT("send inflight mmap_size: %"PRId64"\n",
1957             vmsg->payload.inflight.mmap_size);
1958      DPRINT("send inflight mmap offset: %"PRId64"\n",
1959             vmsg->payload.inflight.mmap_offset);
1960  
1961      return true;
1962  }
1963  
1964  static bool
1965  vu_set_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
1966  {
1967      int fd, i;
1968      uint64_t mmap_size, mmap_offset;
1969      uint16_t num_queues, queue_size;
1970      void *rc;
1971  
1972      if (vmsg->fd_num != 1 ||
1973          vmsg->size != sizeof(vmsg->payload.inflight)) {
1974          vu_panic(dev, "Invalid set_inflight_fd message size:%d fds:%d",
1975                   vmsg->size, vmsg->fd_num);
1976          return false;
1977      }
1978  
1979      fd = vmsg->fds[0];
1980      mmap_size = vmsg->payload.inflight.mmap_size;
1981      mmap_offset = vmsg->payload.inflight.mmap_offset;
1982      num_queues = vmsg->payload.inflight.num_queues;
1983      queue_size = vmsg->payload.inflight.queue_size;
1984  
1985      DPRINT("set_inflight_fd mmap_size: %"PRId64"\n", mmap_size);
1986      DPRINT("set_inflight_fd mmap_offset: %"PRId64"\n", mmap_offset);
1987      DPRINT("set_inflight_fd num_queues: %"PRId16"\n", num_queues);
1988      DPRINT("set_inflight_fd queue_size: %"PRId16"\n", queue_size);
1989  
1990      rc = mmap(0, mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED,
1991                fd, mmap_offset);
1992  
1993      if (rc == MAP_FAILED) {
1994          vu_panic(dev, "set_inflight_fd mmap error: %s", strerror(errno));
1995          return false;
1996      }
1997  
1998      if (dev->inflight_info.fd) {
1999          close(dev->inflight_info.fd);
2000      }
2001  
2002      if (dev->inflight_info.addr) {
2003          munmap(dev->inflight_info.addr, dev->inflight_info.size);
2004      }
2005  
2006      dev->inflight_info.fd = fd;
2007      dev->inflight_info.addr = rc;
2008      dev->inflight_info.size = mmap_size;
2009  
2010      for (i = 0; i < num_queues; i++) {
2011          dev->vq[i].inflight = (VuVirtqInflight *)rc;
2012          dev->vq[i].inflight->desc_num = queue_size;
2013          rc = (void *)((char *)rc + vu_inflight_queue_size(queue_size));
2014      }
2015  
2016      return false;
2017  }
2018  
2019  static bool
2020  vu_handle_vring_kick(VuDev *dev, VhostUserMsg *vmsg)
2021  {
2022      unsigned int index = vmsg->payload.state.index;
2023  
2024      if (index >= dev->max_queues) {
2025          vu_panic(dev, "Invalid queue index: %u", index);
2026          return false;
2027      }
2028  
2029      DPRINT("Got kick message: handler:%p idx:%u\n",
2030             dev->vq[index].handler, index);
2031  
2032      if (!dev->vq[index].started) {
2033          dev->vq[index].started = true;
2034  
2035          if (dev->iface->queue_set_started) {
2036              dev->iface->queue_set_started(dev, index, true);
2037          }
2038      }
2039  
2040      if (dev->vq[index].handler) {
2041          dev->vq[index].handler(dev, index);
2042      }
2043  
2044      return false;
2045  }
2046  
2047  static bool vu_handle_get_max_memslots(VuDev *dev, VhostUserMsg *vmsg)
2048  {
2049      vmsg_set_reply_u64(vmsg, VHOST_USER_MAX_RAM_SLOTS);
2050  
2051      DPRINT("u64: 0x%016"PRIx64"\n", (uint64_t) VHOST_USER_MAX_RAM_SLOTS);
2052  
2053      return true;
2054  }
2055  
2056  static bool
2057  vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
2058  {
2059      int do_reply = 0;
2060  
2061      /* Print out generic part of the request. */
2062      DPRINT("================ Vhost user message ================\n");
2063      DPRINT("Request: %s (%d)\n", vu_request_to_string(vmsg->request),
2064             vmsg->request);
2065      DPRINT("Flags:   0x%x\n", vmsg->flags);
2066      DPRINT("Size:    %u\n", vmsg->size);
2067  
2068      if (vmsg->fd_num) {
2069          int i;
2070          DPRINT("Fds:");
2071          for (i = 0; i < vmsg->fd_num; i++) {
2072              DPRINT(" %d", vmsg->fds[i]);
2073          }
2074          DPRINT("\n");
2075      }
2076  
2077      if (dev->iface->process_msg &&
2078          dev->iface->process_msg(dev, vmsg, &do_reply)) {
2079          return do_reply;
2080      }
2081  
2082      switch (vmsg->request) {
2083      case VHOST_USER_GET_FEATURES:
2084          return vu_get_features_exec(dev, vmsg);
2085      case VHOST_USER_SET_FEATURES:
2086          return vu_set_features_exec(dev, vmsg);
2087      case VHOST_USER_GET_PROTOCOL_FEATURES:
2088          return vu_get_protocol_features_exec(dev, vmsg);
2089      case VHOST_USER_SET_PROTOCOL_FEATURES:
2090          return vu_set_protocol_features_exec(dev, vmsg);
2091      case VHOST_USER_SET_OWNER:
2092          return vu_set_owner_exec(dev, vmsg);
2093      case VHOST_USER_RESET_OWNER:
2094          return vu_reset_device_exec(dev, vmsg);
2095      case VHOST_USER_SET_MEM_TABLE:
2096          return vu_set_mem_table_exec(dev, vmsg);
2097      case VHOST_USER_SET_LOG_BASE:
2098          return vu_set_log_base_exec(dev, vmsg);
2099      case VHOST_USER_SET_LOG_FD:
2100          return vu_set_log_fd_exec(dev, vmsg);
2101      case VHOST_USER_SET_VRING_NUM:
2102          return vu_set_vring_num_exec(dev, vmsg);
2103      case VHOST_USER_SET_VRING_ADDR:
2104          return vu_set_vring_addr_exec(dev, vmsg);
2105      case VHOST_USER_SET_VRING_BASE:
2106          return vu_set_vring_base_exec(dev, vmsg);
2107      case VHOST_USER_GET_VRING_BASE:
2108          return vu_get_vring_base_exec(dev, vmsg);
2109      case VHOST_USER_SET_VRING_KICK:
2110          return vu_set_vring_kick_exec(dev, vmsg);
2111      case VHOST_USER_SET_VRING_CALL:
2112          return vu_set_vring_call_exec(dev, vmsg);
2113      case VHOST_USER_SET_VRING_ERR:
2114          return vu_set_vring_err_exec(dev, vmsg);
2115      case VHOST_USER_GET_QUEUE_NUM:
2116          return vu_get_queue_num_exec(dev, vmsg);
2117      case VHOST_USER_SET_VRING_ENABLE:
2118          return vu_set_vring_enable_exec(dev, vmsg);
2119      case VHOST_USER_SET_BACKEND_REQ_FD:
2120          return vu_set_backend_req_fd(dev, vmsg);
2121      case VHOST_USER_GET_CONFIG:
2122          return vu_get_config(dev, vmsg);
2123      case VHOST_USER_SET_CONFIG:
2124          return vu_set_config(dev, vmsg);
2125      case VHOST_USER_NONE:
2126          /* if you need processing before exit, override iface->process_msg */
2127          exit(0);
2128      case VHOST_USER_POSTCOPY_ADVISE:
2129          return vu_set_postcopy_advise(dev, vmsg);
2130      case VHOST_USER_POSTCOPY_LISTEN:
2131          return vu_set_postcopy_listen(dev, vmsg);
2132      case VHOST_USER_POSTCOPY_END:
2133          return vu_set_postcopy_end(dev, vmsg);
2134      case VHOST_USER_GET_INFLIGHT_FD:
2135          return vu_get_inflight_fd(dev, vmsg);
2136      case VHOST_USER_SET_INFLIGHT_FD:
2137          return vu_set_inflight_fd(dev, vmsg);
2138      case VHOST_USER_VRING_KICK:
2139          return vu_handle_vring_kick(dev, vmsg);
2140      case VHOST_USER_GET_MAX_MEM_SLOTS:
2141          return vu_handle_get_max_memslots(dev, vmsg);
2142      case VHOST_USER_ADD_MEM_REG:
2143          return vu_add_mem_reg(dev, vmsg);
2144      case VHOST_USER_REM_MEM_REG:
2145          return vu_rem_mem_reg(dev, vmsg);
2146      case VHOST_USER_GET_SHARED_OBJECT:
2147          return vu_get_shared_object(dev, vmsg);
2148      default:
2149          vmsg_close_fds(vmsg);
2150          vu_panic(dev, "Unhandled request: %d", vmsg->request);
2151      }
2152  
2153      return false;
2154  }
2155  
2156  bool
2157  vu_dispatch(VuDev *dev)
2158  {
2159      VhostUserMsg vmsg = { 0, };
2160      int reply_requested;
2161      bool need_reply, success = false;
2162  
2163      if (!dev->read_msg(dev, dev->sock, &vmsg)) {
2164          goto end;
2165      }
2166  
2167      need_reply = vmsg.flags & VHOST_USER_NEED_REPLY_MASK;
2168  
2169      reply_requested = vu_process_message(dev, &vmsg);
2170      if (!reply_requested && need_reply) {
2171          vmsg_set_reply_u64(&vmsg, 0);
2172          reply_requested = 1;
2173      }
2174  
2175      if (!reply_requested) {
2176          success = true;
2177          goto end;
2178      }
2179  
2180      if (!vu_send_reply(dev, dev->sock, &vmsg)) {
2181          goto end;
2182      }
2183  
2184      success = true;
2185  
2186  end:
2187      free(vmsg.data);
2188      return success;
2189  }
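
/*
 * Illustrative sketch of a minimal loop built on vu_dispatch().  A real
 * backend usually drives this from the event loop registered through its
 * set_watch callback; the blocking loop and the run_backend name are
 * assumptions made for the example.
 *
 *     static void
 *     run_backend(VuDev *dev)
 *     {
 *         // vu_dispatch() reads one message from dev->sock, processes it and
 *         // sends a reply when one was requested; false means read/write
 *         // failure or a panicked device.
 *         while (vu_dispatch(dev)) {
 *             // keep serving vhost-user requests
 *         }
 *         vu_deinit(dev);
 *     }
 */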
2190  
2191  void
2192  vu_deinit(VuDev *dev)
2193  {
2194      unsigned int i;
2195  
2196      vu_remove_all_mem_regs(dev);
2197  
2198      for (i = 0; i < dev->max_queues; i++) {
2199          VuVirtq *vq = &dev->vq[i];
2200  
2201          if (vq->call_fd != -1) {
2202              close(vq->call_fd);
2203              vq->call_fd = -1;
2204          }
2205  
2206          if (vq->kick_fd != -1) {
2207              dev->remove_watch(dev, vq->kick_fd);
2208              close(vq->kick_fd);
2209              vq->kick_fd = -1;
2210          }
2211  
2212          if (vq->err_fd != -1) {
2213              close(vq->err_fd);
2214              vq->err_fd = -1;
2215          }
2216  
2217          if (vq->resubmit_list) {
2218              free(vq->resubmit_list);
2219              vq->resubmit_list = NULL;
2220          }
2221  
2222          vq->inflight = NULL;
2223      }
2224  
2225      if (dev->inflight_info.addr) {
2226          munmap(dev->inflight_info.addr, dev->inflight_info.size);
2227          dev->inflight_info.addr = NULL;
2228      }
2229  
2230      if (dev->inflight_info.fd > 0) {
2231          close(dev->inflight_info.fd);
2232          dev->inflight_info.fd = -1;
2233      }
2234  
2235      vu_close_log(dev);
2236      if (dev->backend_fd != -1) {
2237          close(dev->backend_fd);
2238          dev->backend_fd = -1;
2239      }
2240      pthread_mutex_destroy(&dev->backend_mutex);
2241  
2242      if (dev->sock != -1) {
2243          close(dev->sock);
2244      }
2245  
2246      free(dev->vq);
2247      dev->vq = NULL;
2248      free(dev->regions);
2249      dev->regions = NULL;
2250  }
2251  
2252  bool
2253  vu_init(VuDev *dev,
2254          uint16_t max_queues,
2255          int socket,
2256          vu_panic_cb panic,
2257          vu_read_msg_cb read_msg,
2258          vu_set_watch_cb set_watch,
2259          vu_remove_watch_cb remove_watch,
2260          const VuDevIface *iface)
2261  {
2262      uint16_t i;
2263  
2264      assert(max_queues > 0);
2265      assert(socket >= 0);
2266      assert(set_watch);
2267      assert(remove_watch);
2268      assert(iface);
2269      assert(panic);
2270  
2271      memset(dev, 0, sizeof(*dev));
2272  
2273      dev->sock = socket;
2274      dev->panic = panic;
2275      dev->read_msg = read_msg ? read_msg : vu_message_read_default;
2276      dev->set_watch = set_watch;
2277      dev->remove_watch = remove_watch;
2278      dev->iface = iface;
2279      dev->log_call_fd = -1;
2280      pthread_mutex_init(&dev->backend_mutex, NULL);
2281      dev->backend_fd = -1;
2282      dev->max_queues = max_queues;
2283  
2284      dev->regions = malloc(VHOST_USER_MAX_RAM_SLOTS * sizeof(dev->regions[0]));
2285      if (!dev->regions) {
2286          DPRINT("%s: failed to malloc mem regions\n", __func__);
2287          return false;
2288      }
2289  
2290      dev->vq = malloc(max_queues * sizeof(dev->vq[0]));
2291      if (!dev->vq) {
2292          DPRINT("%s: failed to malloc virtqueues\n", __func__);
2293          free(dev->regions);
2294          dev->regions = NULL;
2295          return false;
2296      }
2297  
2298      for (i = 0; i < max_queues; i++) {
2299          dev->vq[i] = (VuVirtq) {
2300              .call_fd = -1, .kick_fd = -1, .err_fd = -1,
2301              .notification = true,
2302          };
2303      }
2304  
2305      return true;
2306  }
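
/*
 * Illustrative sketch of bringing up a device with vu_init().  The
 * example_panic/example_set_watch/example_remove_watch callbacks,
 * example_iface and the conn_fd socket are backend-provided and assumed
 * here; only the parameter order is taken from the definition above.
 *
 *     VuDev dev;
 *
 *     if (!vu_init(&dev, 1, conn_fd,             // one virtqueue
 *                  example_panic, NULL,          // NULL: default read_msg
 *                  example_set_watch, example_remove_watch,
 *                  &example_iface)) {
 *         // allocating the region/virtqueue arrays failed
 *     }
 */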
2307  
2308  VuVirtq *
2309  vu_get_queue(VuDev *dev, int qidx)
2310  {
2311      assert(qidx < dev->max_queues);
2312      return &dev->vq[qidx];
2313  }
2314  
2315  bool
2316  vu_queue_enabled(VuDev *dev, VuVirtq *vq)
2317  {
2318      return vq->enable;
2319  }
2320  
2321  bool
2322  vu_queue_started(const VuDev *dev, const VuVirtq *vq)
2323  {
2324      return vq->started;
2325  }
2326  
2327  static inline uint16_t
2328  vring_avail_flags(VuVirtq *vq)
2329  {
2330      return le16toh(vq->vring.avail->flags);
2331  }
2332  
2333  static inline uint16_t
2334  vring_avail_idx(VuVirtq *vq)
2335  {
2336      vq->shadow_avail_idx = le16toh(vq->vring.avail->idx);
2337  
2338      return vq->shadow_avail_idx;
2339  }
2340  
2341  static inline uint16_t
2342  vring_avail_ring(VuVirtq *vq, int i)
2343  {
2344      return le16toh(vq->vring.avail->ring[i]);
2345  }
2346  
2347  static inline uint16_t
2348  vring_get_used_event(VuVirtq *vq)
2349  {
2350      return vring_avail_ring(vq, vq->vring.num);
2351  }
2352  
2353  static int
2354  virtqueue_num_heads(VuDev *dev, VuVirtq *vq, unsigned int idx)
2355  {
2356      uint16_t num_heads = vring_avail_idx(vq) - idx;
2357  
2358      /* Check it isn't doing very strange things with descriptor numbers. */
2359      if (num_heads > vq->vring.num) {
2360          vu_panic(dev, "Guest moved used index from %u to %u",
2361                   idx, vq->shadow_avail_idx);
2362          return -1;
2363      }
2364      if (num_heads) {
2365          /* On success, callers read a descriptor at vq->last_avail_idx.
2366           * Make sure descriptor read does not bypass avail index read. */
2367          smp_rmb();
2368      }
2369  
2370      return num_heads;
2371  }
2372  
2373  static bool
2374  virtqueue_get_head(VuDev *dev, VuVirtq *vq,
2375                     unsigned int idx, unsigned int *head)
2376  {
2377      /* Grab the next descriptor number they're advertising, and increment
2378       * the index we've seen. */
2379      *head = vring_avail_ring(vq, idx % vq->vring.num);
2380  
2381      /* If their number is silly, that's a fatal mistake. */
2382      if (*head >= vq->vring.num) {
2383          vu_panic(dev, "Guest says index %u is available", *head);
2384          return false;
2385      }
2386  
2387      return true;
2388  }
2389  
2390  static int
2391  virtqueue_read_indirect_desc(VuDev *dev, struct vring_desc *desc,
2392                               uint64_t addr, size_t len)
2393  {
2394      struct vring_desc *ori_desc;
2395      uint64_t read_len;
2396  
2397      if (len > (VIRTQUEUE_MAX_SIZE * sizeof(struct vring_desc))) {
2398          return -1;
2399      }
2400  
2401      if (len == 0) {
2402          return -1;
2403      }
2404  
2405      while (len) {
2406          read_len = len;
2407          ori_desc = vu_gpa_to_va(dev, &read_len, addr);
2408          if (!ori_desc) {
2409              return -1;
2410          }
2411  
2412          memcpy(desc, ori_desc, read_len);
2413          len -= read_len;
2414          addr += read_len;
2415          desc += read_len;
2416      }
2417  
2418      return 0;
2419  }
2420  
2421  enum {
2422      VIRTQUEUE_READ_DESC_ERROR = -1,
2423      VIRTQUEUE_READ_DESC_DONE = 0,   /* end of chain */
2424      VIRTQUEUE_READ_DESC_MORE = 1,   /* more buffers in chain */
2425  };
2426  
2427  static int
2428  virtqueue_read_next_desc(VuDev *dev, struct vring_desc *desc,
2429                           int i, unsigned int max, unsigned int *next)
2430  {
2431      /* If this descriptor says it doesn't chain, we're done. */
2432      if (!(le16toh(desc[i].flags) & VRING_DESC_F_NEXT)) {
2433          return VIRTQUEUE_READ_DESC_DONE;
2434      }
2435  
2436      /* Check they're not leading us off end of descriptors. */
2437      *next = le16toh(desc[i].next);
2438      /* Make sure compiler knows to grab that: we don't want it changing! */
2439      smp_wmb();
2440  
2441      if (*next >= max) {
2442          vu_panic(dev, "Desc next is %u", *next);
2443          return VIRTQUEUE_READ_DESC_ERROR;
2444      }
2445  
2446      return VIRTQUEUE_READ_DESC_MORE;
2447  }
2448  
2449  void
2450  vu_queue_get_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int *in_bytes,
2451                           unsigned int *out_bytes,
2452                           unsigned max_in_bytes, unsigned max_out_bytes)
2453  {
2454      unsigned int idx;
2455      unsigned int total_bufs, in_total, out_total;
2456      int rc;
2457  
2458      idx = vq->last_avail_idx;
2459  
2460      total_bufs = in_total = out_total = 0;
2461      if (!vu_is_vq_usable(dev, vq)) {
2462          goto done;
2463      }
2464  
2465      while ((rc = virtqueue_num_heads(dev, vq, idx)) > 0) {
2466          unsigned int max, desc_len, num_bufs, indirect = 0;
2467          uint64_t desc_addr, read_len;
2468          struct vring_desc *desc;
2469          struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
2470          unsigned int i;
2471  
2472          max = vq->vring.num;
2473          num_bufs = total_bufs;
2474          if (!virtqueue_get_head(dev, vq, idx++, &i)) {
2475              goto err;
2476          }
2477          desc = vq->vring.desc;
2478  
2479          if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
2480              if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
2481                  vu_panic(dev, "Invalid size for indirect buffer table");
2482                  goto err;
2483              }
2484  
2485              /* If we've got too many, that implies a descriptor loop. */
2486              if (num_bufs >= max) {
2487                  vu_panic(dev, "Looped descriptor");
2488                  goto err;
2489              }
2490  
2491              /* loop over the indirect descriptor table */
2492              indirect = 1;
2493              desc_addr = le64toh(desc[i].addr);
2494              desc_len = le32toh(desc[i].len);
2495              max = desc_len / sizeof(struct vring_desc);
2496              read_len = desc_len;
2497              desc = vu_gpa_to_va(dev, &read_len, desc_addr);
2498              if (unlikely(desc && read_len != desc_len)) {
2499                  /* Failed to use zero copy */
2500                  desc = NULL;
2501                  if (!virtqueue_read_indirect_desc(dev, desc_buf,
2502                                                    desc_addr,
2503                                                    desc_len)) {
2504                      desc = desc_buf;
2505                  }
2506              }
2507              if (!desc) {
2508                  vu_panic(dev, "Invalid indirect buffer table");
2509                  goto err;
2510              }
2511              num_bufs = i = 0;
2512          }
2513  
2514          do {
2515              /* If we've got too many, that implies a descriptor loop. */
2516              if (++num_bufs > max) {
2517                  vu_panic(dev, "Looped descriptor");
2518                  goto err;
2519              }
2520  
2521              if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
2522                  in_total += le32toh(desc[i].len);
2523              } else {
2524                  out_total += le32toh(desc[i].len);
2525              }
2526              if (in_total >= max_in_bytes && out_total >= max_out_bytes) {
2527                  goto done;
2528              }
2529              rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
2530          } while (rc == VIRTQUEUE_READ_DESC_MORE);
2531  
2532          if (rc == VIRTQUEUE_READ_DESC_ERROR) {
2533              goto err;
2534          }
2535  
2536          if (!indirect) {
2537              total_bufs = num_bufs;
2538          } else {
2539              total_bufs++;
2540          }
2541      }
2542      if (rc < 0) {
2543          goto err;
2544      }
2545  done:
2546      if (in_bytes) {
2547          *in_bytes = in_total;
2548      }
2549      if (out_bytes) {
2550          *out_bytes = out_total;
2551      }
2552      return;
2553  
2554  err:
2555      in_total = out_total = 0;
2556      goto done;
2557  }
2558  
2559  bool
2560  vu_queue_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int in_bytes,
2561                       unsigned int out_bytes)
2562  {
2563      unsigned int in_total, out_total;
2564  
2565      vu_queue_get_avail_bytes(dev, vq, &in_total, &out_total,
2566                               in_bytes, out_bytes);
2567  
2568      return in_bytes <= in_total && out_bytes <= out_total;
2569  }
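
/*
 * Illustrative use of the helper above, e.g. from a queue handler that needs
 * a fixed request layout.  The byte counts are assumptions for the example.
 *
 *     // proceed only once the guest has queued at least 16 device-writable
 *     // (in) bytes and 64 device-readable (out) bytes
 *     if (!vu_queue_avail_bytes(dev, vq, 16, 64)) {
 *         return;     // try again on the next kick
 *     }
 */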
2570  
2571  /* Fetch avail_idx from VQ memory only when we really need to know if
2572   * the guest has added some buffers. */
2573  bool
2574  vu_queue_empty(VuDev *dev, VuVirtq *vq)
2575  {
2576      if (!vu_is_vq_usable(dev, vq)) {
2577          return true;
2578      }
2579  
2580      if (vq->shadow_avail_idx != vq->last_avail_idx) {
2581          return false;
2582      }
2583  
2584      return vring_avail_idx(vq) == vq->last_avail_idx;
2585  }
2586  
2587  static bool
2588  vring_notify(VuDev *dev, VuVirtq *vq)
2589  {
2590      uint16_t old, new;
2591      bool v;
2592  
2593      /* We need to expose used array entries before checking used event. */
2594      smp_mb();
2595  
2596      /* Always notify when the queue is empty, if the feature was acknowledged */
2597      if (vu_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
2598          !vq->inuse && vu_queue_empty(dev, vq)) {
2599          return true;
2600      }
2601  
2602      if (!vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
2603          return !(vring_avail_flags(vq) & VRING_AVAIL_F_NO_INTERRUPT);
2604      }
2605  
2606      v = vq->signalled_used_valid;
2607      vq->signalled_used_valid = true;
2608      old = vq->signalled_used;
2609      new = vq->signalled_used = vq->used_idx;
2610      return !v || vring_need_event(vring_get_used_event(vq), new, old);
2611  }
2612  
2613  static void _vu_queue_notify(VuDev *dev, VuVirtq *vq, bool sync)
2614  {
2615      if (!vu_is_vq_usable(dev, vq)) {
2616          return;
2617      }
2618  
2619      if (!vring_notify(dev, vq)) {
2620          DPRINT("skipped notify...\n");
2621          return;
2622      }
2623  
2624      if (vq->call_fd < 0 &&
2625          vu_has_protocol_feature(dev,
2626                                  VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
2627          vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_BACKEND_REQ)) {
2628          VhostUserMsg vmsg = {
2629              .request = VHOST_USER_BACKEND_VRING_CALL,
2630              .flags = VHOST_USER_VERSION,
2631              .size = sizeof(vmsg.payload.state),
2632              .payload.state = {
2633                  .index = vq - dev->vq,
2634              },
2635          };
2636          bool ack = sync &&
2637                     vu_has_protocol_feature(dev,
2638                                             VHOST_USER_PROTOCOL_F_REPLY_ACK);
2639  
2640          if (ack) {
2641              vmsg.flags |= VHOST_USER_NEED_REPLY_MASK;
2642          }
2643  
2644          vu_message_write(dev, dev->backend_fd, &vmsg);
2645          if (ack) {
2646              vu_message_read_default(dev, dev->backend_fd, &vmsg);
2647          }
2648          return;
2649      }
2650  
2651      if (eventfd_write(vq->call_fd, 1) < 0) {
2652          vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
2653      }
2654  }
2655  
2656  void vu_queue_notify(VuDev *dev, VuVirtq *vq)
2657  {
2658      _vu_queue_notify(dev, vq, false);
2659  }
2660  
2661  void vu_queue_notify_sync(VuDev *dev, VuVirtq *vq)
2662  {
2663      _vu_queue_notify(dev, vq, true);
2664  }
2665  
2666  void vu_config_change_msg(VuDev *dev)
2667  {
2668      VhostUserMsg vmsg = {
2669          .request = VHOST_USER_BACKEND_CONFIG_CHANGE_MSG,
2670          .flags = VHOST_USER_VERSION,
2671      };
2672  
2673      vu_message_write(dev, dev->backend_fd, &vmsg);
2674  }
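
/*
 * Illustrative use of vu_config_change_msg(): after the backend updates the
 * state it exposes through its get_config callback, it can tell the frontend
 * to re-read the config space.  The surrounding device logic is assumed.
 *
 *     // ... update the device's configuration data locally ...
 *     vu_config_change_msg(dev);
 */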
2675  
2676  static inline void
2677  vring_used_flags_set_bit(VuVirtq *vq, int mask)
2678  {
2679      uint16_t *flags;
2680  
2681      flags = (uint16_t *)((char*)vq->vring.used +
2682                           offsetof(struct vring_used, flags));
2683      *flags = htole16(le16toh(*flags) | mask);
2684  }
2685  
2686  static inline void
2687  vring_used_flags_unset_bit(VuVirtq *vq, int mask)
2688  {
2689      uint16_t *flags;
2690  
2691      flags = (uint16_t *)((char*)vq->vring.used +
2692                           offsetof(struct vring_used, flags));
2693      *flags = htole16(le16toh(*flags) & ~mask);
2694  }
2695  
2696  static inline void
2697  vring_set_avail_event(VuVirtq *vq, uint16_t val)
2698  {
2699      uint16_t val_le = htole16(val);
2700  
2701      if (!vq->notification) {
2702          return;
2703      }
2704  
2705      memcpy(&vq->vring.used->ring[vq->vring.num], &val_le, sizeof(uint16_t));
2706  }
2707  
2708  void
2709  vu_queue_set_notification(VuDev *dev, VuVirtq *vq, int enable)
2710  {
2711      vq->notification = enable;
2712      if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
2713          vring_set_avail_event(vq, vring_avail_idx(vq));
2714      } else if (enable) {
2715          vring_used_flags_unset_bit(vq, VRING_USED_F_NO_NOTIFY);
2716      } else {
2717          vring_used_flags_set_bit(vq, VRING_USED_F_NO_NOTIFY);
2718      }
2719      if (enable) {
2720          /* Expose avail event/used flags before caller checks the avail idx. */
2721          smp_mb();
2722      }
2723  }
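
/*
 * Illustrative sketch of the usual notification-suppression pattern built on
 * the helpers above: disable notifications while draining the queue, then
 * re-enable and re-check to close the race with the guest.  The processing
 * step is an assumption for the example.
 *
 *     VuVirtqElement *elem;
 *
 *     do {
 *         vu_queue_set_notification(dev, vq, 0);
 *         while ((elem = vu_queue_pop(dev, vq, sizeof(*elem)))) {
 *             // ... process the element, vu_queue_push(), free(elem) ...
 *         }
 *         vu_queue_set_notification(dev, vq, 1);
 *     } while (!vu_queue_empty(dev, vq));
 */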
2724  
2725  static bool
2726  virtqueue_map_desc(VuDev *dev,
2727                     unsigned int *p_num_sg, struct iovec *iov,
2728                     unsigned int max_num_sg, bool is_write,
2729                     uint64_t pa, size_t sz)
2730  {
2731      unsigned num_sg = *p_num_sg;
2732  
2733      assert(num_sg <= max_num_sg);
2734  
2735      if (!sz) {
2736          vu_panic(dev, "virtio: zero sized buffers are not allowed");
2737          return false;
2738      }
2739  
2740      while (sz) {
2741          uint64_t len = sz;
2742  
2743          if (num_sg == max_num_sg) {
2744              vu_panic(dev, "virtio: too many descriptors in indirect table");
2745              return false;
2746          }
2747  
2748          iov[num_sg].iov_base = vu_gpa_to_va(dev, &len, pa);
2749          if (iov[num_sg].iov_base == NULL) {
2750              vu_panic(dev, "virtio: invalid address for buffers");
2751              return false;
2752          }
2753          iov[num_sg].iov_len = len;
2754          num_sg++;
2755          sz -= len;
2756          pa += len;
2757      }
2758  
2759      *p_num_sg = num_sg;
2760      return true;
2761  }
2762  
2763  static void *
2764  virtqueue_alloc_element(size_t sz,
2765                                       unsigned out_num, unsigned in_num)
2766  {
2767      VuVirtqElement *elem;
2768      size_t in_sg_ofs = ALIGN_UP(sz, __alignof__(elem->in_sg[0]));
2769      size_t out_sg_ofs = in_sg_ofs + in_num * sizeof(elem->in_sg[0]);
2770      size_t out_sg_end = out_sg_ofs + out_num * sizeof(elem->out_sg[0]);
2771  
2772      assert(sz >= sizeof(VuVirtqElement));
2773      elem = malloc(out_sg_end);
2774      if (!elem) {
2775          DPRINT("%s: failed to malloc virtqueue element\n", __func__);
2776          return NULL;
2777      }
2778      elem->out_num = out_num;
2779      elem->in_num = in_num;
2780      elem->in_sg = (void *)elem + in_sg_ofs;
2781      elem->out_sg = (void *)elem + out_sg_ofs;
2782      return elem;
2783  }
2784  
2785  static void *
2786  vu_queue_map_desc(VuDev *dev, VuVirtq *vq, unsigned int idx, size_t sz)
2787  {
2788      struct vring_desc *desc = vq->vring.desc;
2789      uint64_t desc_addr, read_len;
2790      unsigned int desc_len;
2791      unsigned int max = vq->vring.num;
2792      unsigned int i = idx;
2793      VuVirtqElement *elem;
2794      unsigned int out_num = 0, in_num = 0;
2795      struct iovec iov[VIRTQUEUE_MAX_SIZE];
2796      struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
2797      int rc;
2798  
2799      if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
2800          if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
2801              vu_panic(dev, "Invalid size for indirect buffer table");
2802              return NULL;
2803          }
2804  
2805          /* loop over the indirect descriptor table */
2806          desc_addr = le64toh(desc[i].addr);
2807          desc_len = le32toh(desc[i].len);
2808          max = desc_len / sizeof(struct vring_desc);
2809          read_len = desc_len;
2810          desc = vu_gpa_to_va(dev, &read_len, desc_addr);
2811          if (unlikely(desc && read_len != desc_len)) {
2812              /* Failed to use zero copy */
2813              desc = NULL;
2814              if (!virtqueue_read_indirect_desc(dev, desc_buf,
2815                                                desc_addr,
2816                                                desc_len)) {
2817                  desc = desc_buf;
2818              }
2819          }
2820          if (!desc) {
2821              vu_panic(dev, "Invalid indirect buffer table");
2822              return NULL;
2823          }
2824          i = 0;
2825      }
2826  
2827      /* Collect all the descriptors */
2828      do {
2829          if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
2830              if (!virtqueue_map_desc(dev, &in_num, iov + out_num,
2831                                 VIRTQUEUE_MAX_SIZE - out_num, true,
2832                                 le64toh(desc[i].addr),
2833                                 le32toh(desc[i].len))) {
2834                  return NULL;
2835              }
2836          } else {
2837              if (in_num) {
2838                  vu_panic(dev, "Incorrect order for descriptors");
2839                  return NULL;
2840              }
2841              if (!virtqueue_map_desc(dev, &out_num, iov,
2842                                 VIRTQUEUE_MAX_SIZE, false,
2843                                 le64toh(desc[i].addr),
2844                                 le32toh(desc[i].len))) {
2845                  return NULL;
2846              }
2847          }
2848  
2849          /* If we've got too many, that implies a descriptor loop. */
2850          if ((in_num + out_num) > max) {
2851              vu_panic(dev, "Looped descriptor");
2852              return NULL;
2853          }
2854          rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
2855      } while (rc == VIRTQUEUE_READ_DESC_MORE);
2856  
2857      if (rc == VIRTQUEUE_READ_DESC_ERROR) {
2858          vu_panic(dev, "read descriptor error");
2859          return NULL;
2860      }
2861  
2862      /* Now copy what we have collected and mapped */
2863      elem = virtqueue_alloc_element(sz, out_num, in_num);
2864      if (!elem) {
2865          return NULL;
2866      }
2867      elem->index = idx;
2868      for (i = 0; i < out_num; i++) {
2869          elem->out_sg[i] = iov[i];
2870      }
2871      for (i = 0; i < in_num; i++) {
2872          elem->in_sg[i] = iov[out_num + i];
2873      }
2874  
2875      return elem;
2876  }
2877  
2878  static int
2879  vu_queue_inflight_get(VuDev *dev, VuVirtq *vq, int desc_idx)
2880  {
2881      if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2882          return 0;
2883      }
2884  
2885      if (unlikely(!vq->inflight)) {
2886          return -1;
2887      }
2888  
2889      vq->inflight->desc[desc_idx].counter = vq->counter++;
2890      vq->inflight->desc[desc_idx].inflight = 1;
2891  
2892      return 0;
2893  }
2894  
2895  static int
2896  vu_queue_inflight_pre_put(VuDev *dev, VuVirtq *vq, int desc_idx)
2897  {
2898      if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2899          return 0;
2900      }
2901  
2902      if (unlikely(!vq->inflight)) {
2903          return -1;
2904      }
2905  
2906      vq->inflight->last_batch_head = desc_idx;
2907  
2908      return 0;
2909  }
2910  
2911  static int
2912  vu_queue_inflight_post_put(VuDev *dev, VuVirtq *vq, int desc_idx)
2913  {
2914      if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2915          return 0;
2916      }
2917  
2918      if (unlikely(!vq->inflight)) {
2919          return -1;
2920      }
2921  
2922      barrier();
2923  
2924      vq->inflight->desc[desc_idx].inflight = 0;
2925  
2926      barrier();
2927  
2928      vq->inflight->used_idx = vq->used_idx;
2929  
2930      return 0;
2931  }
2932  
2933  void *
2934  vu_queue_pop(VuDev *dev, VuVirtq *vq, size_t sz)
2935  {
2936      int i;
2937      unsigned int head;
2938      VuVirtqElement *elem;
2939  
2940      if (!vu_is_vq_usable(dev, vq)) {
2941          return NULL;
2942      }
2943  
2944      if (unlikely(vq->resubmit_list && vq->resubmit_num > 0)) {
2945          i = (--vq->resubmit_num);
2946          elem = vu_queue_map_desc(dev, vq, vq->resubmit_list[i].index, sz);
2947  
2948          if (!vq->resubmit_num) {
2949              free(vq->resubmit_list);
2950              vq->resubmit_list = NULL;
2951          }
2952  
2953          return elem;
2954      }
2955  
2956      if (vu_queue_empty(dev, vq)) {
2957          return NULL;
2958      }
2959      /*
2960       * Needed after vu_queue_empty(), see comment in
2961       * virtqueue_num_heads().
2962       */
2963      smp_rmb();
2964  
2965      if (vq->inuse >= vq->vring.num) {
2966          vu_panic(dev, "Virtqueue size exceeded");
2967          return NULL;
2968      }
2969  
2970      if (!virtqueue_get_head(dev, vq, vq->last_avail_idx++, &head)) {
2971          return NULL;
2972      }
2973  
2974      if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
2975          vring_set_avail_event(vq, vq->last_avail_idx);
2976      }
2977  
2978      elem = vu_queue_map_desc(dev, vq, head, sz);
2979  
2980      if (!elem) {
2981          return NULL;
2982      }
2983  
2984      vq->inuse++;
2985  
2986      vu_queue_inflight_get(dev, vq, head);
2987  
2988      return elem;
2989  }
2990  
2991  static void
2992  vu_queue_detach_element(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
2993                          size_t len)
2994  {
2995      vq->inuse--;
2996      /* unmap, when DMA support is added */
2997  }
2998  
2999  void
3000  vu_queue_unpop(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
3001                 size_t len)
3002  {
3003      vq->last_avail_idx--;
3004      vu_queue_detach_element(dev, vq, elem, len);
3005  }
3006  
3007  bool
3008  vu_queue_rewind(VuDev *dev, VuVirtq *vq, unsigned int num)
3009  {
3010      if (num > vq->inuse) {
3011          return false;
3012      }
3013      vq->last_avail_idx -= num;
3014      vq->inuse -= num;
3015      return true;
3016  }
3017  
3018  static inline
3019  void vring_used_write(VuDev *dev, VuVirtq *vq,
3020                        struct vring_used_elem *uelem, int i)
3021  {
3022      struct vring_used *used = vq->vring.used;
3023  
3024      used->ring[i] = *uelem;
3025      vu_log_write(dev, vq->vring.log_guest_addr +
3026                   offsetof(struct vring_used, ring[i]),
3027                   sizeof(used->ring[i]));
3028  }
3029  
3030  
3031  static void
3032  vu_log_queue_fill(VuDev *dev, VuVirtq *vq,
3033                    const VuVirtqElement *elem,
3034                    unsigned int len)
3035  {
3036      struct vring_desc *desc = vq->vring.desc;
3037      unsigned int i, max, min, desc_len;
3038      uint64_t desc_addr, read_len;
3039      struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
3040      unsigned num_bufs = 0;
3041  
3042      max = vq->vring.num;
3043      i = elem->index;
3044  
3045      if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
3046          if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
3047              vu_panic(dev, "Invalid size for indirect buffer table");
3048              return;
3049          }
3050  
3051          /* loop over the indirect descriptor table */
3052          desc_addr = le64toh(desc[i].addr);
3053          desc_len = le32toh(desc[i].len);
3054          max = desc_len / sizeof(struct vring_desc);
3055          read_len = desc_len;
3056          desc = vu_gpa_to_va(dev, &read_len, desc_addr);
3057          if (unlikely(desc && read_len != desc_len)) {
3058              /* Failed to use zero copy */
3059              desc = NULL;
3060              if (!virtqueue_read_indirect_desc(dev, desc_buf,
3061                                                desc_addr,
3062                                                desc_len)) {
3063                  desc = desc_buf;
3064              }
3065          }
3066          if (!desc) {
3067              vu_panic(dev, "Invalid indirect buffer table");
3068              return;
3069          }
3070          i = 0;
3071      }
3072  
3073      do {
3074          if (++num_bufs > max) {
3075              vu_panic(dev, "Looped descriptor");
3076              return;
3077          }
3078  
3079          if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
3080              min = MIN(le32toh(desc[i].len), len);
3081              vu_log_write(dev, le64toh(desc[i].addr), min);
3082              len -= min;
3083          }
3084  
3085      } while (len > 0 &&
3086               (virtqueue_read_next_desc(dev, desc, i, max, &i)
3087                == VIRTQUEUE_READ_DESC_MORE));
3088  }
3089  
3090  void
3091  vu_queue_fill(VuDev *dev, VuVirtq *vq,
3092                const VuVirtqElement *elem,
3093                unsigned int len, unsigned int idx)
3094  {
3095      struct vring_used_elem uelem;
3096  
3097      if (!vu_is_vq_usable(dev, vq)) {
3098          return;
3099      }
3100  
3101      vu_log_queue_fill(dev, vq, elem, len);
3102  
3103      idx = (idx + vq->used_idx) % vq->vring.num;
3104  
3105      uelem.id = htole32(elem->index);
3106      uelem.len = htole32(len);
3107      vring_used_write(dev, vq, &uelem, idx);
3108  }
3109  
3110  static inline
3111  void vring_used_idx_set(VuDev *dev, VuVirtq *vq, uint16_t val)
3112  {
3113      vq->vring.used->idx = htole16(val);
3114      vu_log_write(dev,
3115                   vq->vring.log_guest_addr + offsetof(struct vring_used, idx),
3116                   sizeof(vq->vring.used->idx));
3117  
3118      vq->used_idx = val;
3119  }
3120  
3121  void
3122  vu_queue_flush(VuDev *dev, VuVirtq *vq, unsigned int count)
3123  {
3124      uint16_t old, new;
3125  
3126      if (!vu_is_vq_usable(dev, vq)) {
3127          return;
3128      }
3129  
3130      /* Make sure buffer is written before we update index. */
3131      smp_wmb();
3132  
3133      old = vq->used_idx;
3134      new = old + count;
3135      vring_used_idx_set(dev, vq, new);
3136      vq->inuse -= count;
3137      if (unlikely((int16_t)(new - vq->signalled_used) < (uint16_t)(new - old))) {
3138          vq->signalled_used_valid = false;
3139      }
3140  }
3141  
3142  void
3143  vu_queue_push(VuDev *dev, VuVirtq *vq,
3144                const VuVirtqElement *elem, unsigned int len)
3145  {
3146      vu_queue_fill(dev, vq, elem, len, 0);
3147      vu_queue_inflight_pre_put(dev, vq, elem->index);
3148      vu_queue_flush(dev, vq, 1);
3149      vu_queue_inflight_post_put(dev, vq, elem->index);
3150  }
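
/*
 * Illustrative sketch of the usual request-processing cycle built from the
 * helpers above (pop, use the scatter/gather lists, push, notify).  The
 * handler name and the zero response length are assumptions for the example.
 *
 *     static void
 *     example_queue_handler(VuDev *dev, int qidx)
 *     {
 *         VuVirtq *vq = vu_get_queue(dev, qidx);
 *         VuVirtqElement *elem;
 *
 *         while ((elem = vu_queue_pop(dev, vq, sizeof(*elem)))) {
 *             // elem->out_sg[0..out_num-1]: buffers written by the driver
 *             // elem->in_sg[0..in_num-1]:   buffers the device may write
 *
 *             // ... handle the request, filling in_sg as needed ...
 *
 *             vu_queue_push(dev, vq, elem, 0);   // 0: no bytes written back
 *             free(elem);                        // vu_queue_pop() malloc()s it
 *         }
 *         vu_queue_notify(dev, vq);
 *     }
 */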
3151