xref: /openbmc/qemu/subprojects/libvhost-user/libvhost-user.c (revision 732d548732edda90f3cfe0f7dceee64861f07b25)
1 /*
2  * Vhost User library
3  *
4  * Copyright IBM, Corp. 2007
5  * Copyright (c) 2016 Red Hat, Inc.
6  *
7  * Authors:
8  *  Anthony Liguori <aliguori@us.ibm.com>
9  *  Marc-André Lureau <mlureau@redhat.com>
10  *  Victor Kaplansky <victork@redhat.com>
11  *
12  * This work is licensed under the terms of the GNU GPL, version 2 or
13  * later.  See the COPYING file in the top-level directory.
14  */
15 
16 #ifndef _GNU_SOURCE
17 #define _GNU_SOURCE
18 #endif
19 
20 /* this code avoids GLib dependency */
21 #include <stdlib.h>
22 #include <stdio.h>
23 #include <unistd.h>
24 #include <stdarg.h>
25 #include <errno.h>
26 #include <string.h>
27 #include <assert.h>
28 #include <inttypes.h>
29 #include <sys/types.h>
30 #include <sys/socket.h>
31 #include <sys/eventfd.h>
32 #include <sys/mman.h>
33 #include <endian.h>
34 
35 /* Necessary to provide VIRTIO_F_VERSION_1 on system
36  * with older linux headers. Must appear before
37  * <linux/vhost.h> below.
38  */
39 #include "standard-headers/linux/virtio_config.h"
40 
41 #if defined(__linux__)
42 #include <sys/syscall.h>
43 #include <fcntl.h>
44 #include <sys/ioctl.h>
45 #include <linux/vhost.h>
46 
47 #ifdef __NR_userfaultfd
48 #include <linux/userfaultfd.h>
49 #endif
50 
51 #endif
52 
53 #include "include/atomic.h"
54 
55 #include "libvhost-user.h"
56 
57 /* usually provided by GLib */
58 #if     __GNUC__ > 2 || (__GNUC__ == 2 && __GNUC_MINOR__ > 4)
59 #if !defined(__clang__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 4)
60 #define G_GNUC_PRINTF(format_idx, arg_idx) \
61   __attribute__((__format__(gnu_printf, format_idx, arg_idx)))
62 #else
63 #define G_GNUC_PRINTF(format_idx, arg_idx) \
64   __attribute__((__format__(__printf__, format_idx, arg_idx)))
65 #endif
66 #else   /* !__GNUC__ */
67 #define G_GNUC_PRINTF(format_idx, arg_idx)
68 #endif  /* !__GNUC__ */
69 #ifndef MIN
70 #define MIN(x, y) ({                            \
71             __typeof__(x) _min1 = (x);          \
72             __typeof__(y) _min2 = (y);          \
73             (void) (&_min1 == &_min2);          \
74             _min1 < _min2 ? _min1 : _min2; })
75 #endif
76 
77 /* Round number down to multiple */
78 #define ALIGN_DOWN(n, m) ((n) / (m) * (m))
79 
80 /* Round number up to multiple */
81 #define ALIGN_UP(n, m) ALIGN_DOWN((n) + (m) - 1, (m))
82 
83 #ifndef unlikely
84 #define unlikely(x)   __builtin_expect(!!(x), 0)
85 #endif
86 
87 /* Align each region to cache line size in inflight buffer */
88 #define INFLIGHT_ALIGNMENT 64
89 
90 /* The version of inflight buffer */
91 #define INFLIGHT_VERSION 1
92 
93 /* The version of the protocol we support */
94 #define VHOST_USER_VERSION 1
95 #define LIBVHOST_USER_DEBUG 0
96 
97 #define DPRINT(...)                             \
98     do {                                        \
99         if (LIBVHOST_USER_DEBUG) {              \
100             fprintf(stderr, __VA_ARGS__);        \
101         }                                       \
102     } while (0)
103 
104 static inline
105 bool has_feature(uint64_t features, unsigned int fbit)
106 {
107     assert(fbit < 64);
108     return !!(features & (1ULL << fbit));
109 }
110 
111 static inline
112 bool vu_has_feature(VuDev *dev,
113                     unsigned int fbit)
114 {
115     return has_feature(dev->features, fbit);
116 }
117 
118 static inline bool vu_has_protocol_feature(VuDev *dev, unsigned int fbit)
119 {
120     return has_feature(dev->protocol_features, fbit);
121 }
122 
123 const char *
124 vu_request_to_string(unsigned int req)
125 {
126 #define REQ(req) [req] = #req
127     static const char *vu_request_str[] = {
128         REQ(VHOST_USER_NONE),
129         REQ(VHOST_USER_GET_FEATURES),
130         REQ(VHOST_USER_SET_FEATURES),
131         REQ(VHOST_USER_SET_OWNER),
132         REQ(VHOST_USER_RESET_OWNER),
133         REQ(VHOST_USER_SET_MEM_TABLE),
134         REQ(VHOST_USER_SET_LOG_BASE),
135         REQ(VHOST_USER_SET_LOG_FD),
136         REQ(VHOST_USER_SET_VRING_NUM),
137         REQ(VHOST_USER_SET_VRING_ADDR),
138         REQ(VHOST_USER_SET_VRING_BASE),
139         REQ(VHOST_USER_GET_VRING_BASE),
140         REQ(VHOST_USER_SET_VRING_KICK),
141         REQ(VHOST_USER_SET_VRING_CALL),
142         REQ(VHOST_USER_SET_VRING_ERR),
143         REQ(VHOST_USER_GET_PROTOCOL_FEATURES),
144         REQ(VHOST_USER_SET_PROTOCOL_FEATURES),
145         REQ(VHOST_USER_GET_QUEUE_NUM),
146         REQ(VHOST_USER_SET_VRING_ENABLE),
147         REQ(VHOST_USER_SEND_RARP),
148         REQ(VHOST_USER_NET_SET_MTU),
149         REQ(VHOST_USER_SET_BACKEND_REQ_FD),
150         REQ(VHOST_USER_IOTLB_MSG),
151         REQ(VHOST_USER_SET_VRING_ENDIAN),
152         REQ(VHOST_USER_GET_CONFIG),
153         REQ(VHOST_USER_SET_CONFIG),
154         REQ(VHOST_USER_POSTCOPY_ADVISE),
155         REQ(VHOST_USER_POSTCOPY_LISTEN),
156         REQ(VHOST_USER_POSTCOPY_END),
157         REQ(VHOST_USER_GET_INFLIGHT_FD),
158         REQ(VHOST_USER_SET_INFLIGHT_FD),
159         REQ(VHOST_USER_GPU_SET_SOCKET),
160         REQ(VHOST_USER_VRING_KICK),
161         REQ(VHOST_USER_GET_MAX_MEM_SLOTS),
162         REQ(VHOST_USER_ADD_MEM_REG),
163         REQ(VHOST_USER_REM_MEM_REG),
164         REQ(VHOST_USER_MAX),
165     };
166 #undef REQ
167 
168     if (req < VHOST_USER_MAX) {
169         return vu_request_str[req];
170     } else {
171         return "unknown";
172     }
173 }
174 
175 static void G_GNUC_PRINTF(2, 3)
176 vu_panic(VuDev *dev, const char *msg, ...)
177 {
178     char *buf = NULL;
179     va_list ap;
180 
181     va_start(ap, msg);
182     if (vasprintf(&buf, msg, ap) < 0) {
183         buf = NULL;
184     }
185     va_end(ap);
186 
187     dev->broken = true;
188     dev->panic(dev, buf);
189     free(buf);
190 
191     /*
192      * FIXME:
193      * find a way to call virtio_error, or perhaps close the connection?
194      */
195 }
196 
197 /* Translate guest physical address to our virtual address.  */
198 void *
199 vu_gpa_to_va(VuDev *dev, uint64_t *plen, uint64_t guest_addr)
200 {
201     unsigned int i;
202 
203     if (*plen == 0) {
204         return NULL;
205     }
206 
207     /* Find matching memory region.  */
208     for (i = 0; i < dev->nregions; i++) {
209         VuDevRegion *r = &dev->regions[i];
210 
211         if ((guest_addr >= r->gpa) && (guest_addr < (r->gpa + r->size))) {
212             if ((guest_addr + *plen) > (r->gpa + r->size)) {
213                 *plen = r->gpa + r->size - guest_addr;
214             }
215             return (void *)(uintptr_t)
216                 guest_addr - r->gpa + r->mmap_addr + r->mmap_offset;
217         }
218     }
219 
220     return NULL;
221 }
222 
223 /* Translate qemu virtual address to our virtual address.  */
224 static void *
225 qva_to_va(VuDev *dev, uint64_t qemu_addr)
226 {
227     unsigned int i;
228 
229     /* Find matching memory region.  */
230     for (i = 0; i < dev->nregions; i++) {
231         VuDevRegion *r = &dev->regions[i];
232 
233         if ((qemu_addr >= r->qva) && (qemu_addr < (r->qva + r->size))) {
234             return (void *)(uintptr_t)
235                 qemu_addr - r->qva + r->mmap_addr + r->mmap_offset;
236         }
237     }
238 
239     return NULL;
240 }
241 
242 static void
243 vmsg_close_fds(VhostUserMsg *vmsg)
244 {
245     int i;
246 
247     for (i = 0; i < vmsg->fd_num; i++) {
248         close(vmsg->fds[i]);
249     }
250 }
251 
252 /* Set reply payload.u64 and clear request flags and fd_num */
253 static void vmsg_set_reply_u64(VhostUserMsg *vmsg, uint64_t val)
254 {
255     vmsg->flags = 0; /* defaults will be set by vu_send_reply() */
256     vmsg->size = sizeof(vmsg->payload.u64);
257     vmsg->payload.u64 = val;
258     vmsg->fd_num = 0;
259 }
260 
261 /* A test to see if we have userfault available */
262 static bool
263 have_userfault(void)
264 {
265 #if defined(__linux__) && defined(__NR_userfaultfd) &&\
266         defined(UFFD_FEATURE_MISSING_SHMEM) &&\
267         defined(UFFD_FEATURE_MISSING_HUGETLBFS)
268     /* Now test the kernel we're running on really has the features */
269     int ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
270     struct uffdio_api api_struct;
271     if (ufd < 0) {
272         return false;
273     }
274 
275     api_struct.api = UFFD_API;
276     api_struct.features = UFFD_FEATURE_MISSING_SHMEM |
277                           UFFD_FEATURE_MISSING_HUGETLBFS;
278     if (ioctl(ufd, UFFDIO_API, &api_struct)) {
279         close(ufd);
280         return false;
281     }
282     close(ufd);
283     return true;
284 
285 #else
286     return false;
287 #endif
288 }
289 
290 static bool
291 vu_message_read_default(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
292 {
293     char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
294     struct iovec iov = {
295         .iov_base = (char *)vmsg,
296         .iov_len = VHOST_USER_HDR_SIZE,
297     };
298     struct msghdr msg = {
299         .msg_iov = &iov,
300         .msg_iovlen = 1,
301         .msg_control = control,
302         .msg_controllen = sizeof(control),
303     };
304     size_t fd_size;
305     struct cmsghdr *cmsg;
306     int rc;
307 
308     do {
309         rc = recvmsg(conn_fd, &msg, 0);
310     } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
311 
312     if (rc < 0) {
313         vu_panic(dev, "Error while recvmsg: %s", strerror(errno));
314         return false;
315     }
316 
317     vmsg->fd_num = 0;
318     for (cmsg = CMSG_FIRSTHDR(&msg);
319          cmsg != NULL;
320          cmsg = CMSG_NXTHDR(&msg, cmsg))
321     {
322         if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
323             fd_size = cmsg->cmsg_len - CMSG_LEN(0);
324             vmsg->fd_num = fd_size / sizeof(int);
325             memcpy(vmsg->fds, CMSG_DATA(cmsg), fd_size);
326             break;
327         }
328     }
329 
330     if (vmsg->size > sizeof(vmsg->payload)) {
331         vu_panic(dev,
332                  "Error: too big message request: %d, size: vmsg->size: %u, "
333                  "while sizeof(vmsg->payload) = %zu\n",
334                  vmsg->request, vmsg->size, sizeof(vmsg->payload));
335         goto fail;
336     }
337 
338     if (vmsg->size) {
339         do {
340             rc = read(conn_fd, &vmsg->payload, vmsg->size);
341         } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
342 
343         if (rc <= 0) {
344             vu_panic(dev, "Error while reading: %s", strerror(errno));
345             goto fail;
346         }
347 
348         assert((uint32_t)rc == vmsg->size);
349     }
350 
351     return true;
352 
353 fail:
354     vmsg_close_fds(vmsg);
355 
356     return false;
357 }
358 
359 static bool
360 vu_message_write(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
361 {
362     int rc;
363     uint8_t *p = (uint8_t *)vmsg;
364     char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
365     struct iovec iov = {
366         .iov_base = (char *)vmsg,
367         .iov_len = VHOST_USER_HDR_SIZE,
368     };
369     struct msghdr msg = {
370         .msg_iov = &iov,
371         .msg_iovlen = 1,
372         .msg_control = control,
373     };
374     struct cmsghdr *cmsg;
375 
376     memset(control, 0, sizeof(control));
377     assert(vmsg->fd_num <= VHOST_MEMORY_BASELINE_NREGIONS);
378     if (vmsg->fd_num > 0) {
379         size_t fdsize = vmsg->fd_num * sizeof(int);
380         msg.msg_controllen = CMSG_SPACE(fdsize);
381         cmsg = CMSG_FIRSTHDR(&msg);
382         cmsg->cmsg_len = CMSG_LEN(fdsize);
383         cmsg->cmsg_level = SOL_SOCKET;
384         cmsg->cmsg_type = SCM_RIGHTS;
385         memcpy(CMSG_DATA(cmsg), vmsg->fds, fdsize);
386     } else {
387         msg.msg_controllen = 0;
388     }
389 
390     do {
391         rc = sendmsg(conn_fd, &msg, 0);
392     } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
393 
394     if (vmsg->size) {
395         do {
396             if (vmsg->data) {
397                 rc = write(conn_fd, vmsg->data, vmsg->size);
398             } else {
399                 rc = write(conn_fd, p + VHOST_USER_HDR_SIZE, vmsg->size);
400             }
401         } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
402     }
403 
404     if (rc <= 0) {
405         vu_panic(dev, "Error while writing: %s", strerror(errno));
406         return false;
407     }
408 
409     return true;
410 }
411 
412 static bool
413 vu_send_reply(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
414 {
415     /* Set the version in the flags when sending the reply */
416     vmsg->flags &= ~VHOST_USER_VERSION_MASK;
417     vmsg->flags |= VHOST_USER_VERSION;
418     vmsg->flags |= VHOST_USER_REPLY_MASK;
419 
420     return vu_message_write(dev, conn_fd, vmsg);
421 }
422 
423 /*
424  * Processes a reply on the backend channel.
425  * Entered with backend_mutex held and releases it before exit.
426  * Returns true on success.
427  */
428 static bool
429 vu_process_message_reply(VuDev *dev, const VhostUserMsg *vmsg)
430 {
431     VhostUserMsg msg_reply;
432     bool result = false;
433 
434     if ((vmsg->flags & VHOST_USER_NEED_REPLY_MASK) == 0) {
435         result = true;
436         goto out;
437     }
438 
439     if (!vu_message_read_default(dev, dev->backend_fd, &msg_reply)) {
440         goto out;
441     }
442 
443     if (msg_reply.request != vmsg->request) {
444         DPRINT("Received unexpected msg type. Expected %d received %d",
445                vmsg->request, msg_reply.request);
446         goto out;
447     }
448 
449     result = msg_reply.payload.u64 == 0;
450 
451 out:
452     pthread_mutex_unlock(&dev->backend_mutex);
453     return result;
454 }
455 
456 /* Kick the log_call_fd if required. */
457 static void
458 vu_log_kick(VuDev *dev)
459 {
460     if (dev->log_call_fd != -1) {
461         DPRINT("Kicking the QEMU's log...\n");
462         if (eventfd_write(dev->log_call_fd, 1) < 0) {
463             vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
464         }
465     }
466 }
467 
468 static void
469 vu_log_page(uint8_t *log_table, uint64_t page)
470 {
471     DPRINT("Logged dirty guest page: %"PRId64"\n", page);
472     qatomic_or(&log_table[page / 8], 1 << (page % 8));
473 }
474 
475 static void
476 vu_log_write(VuDev *dev, uint64_t address, uint64_t length)
477 {
478     uint64_t page;
479 
480     if (!(dev->features & (1ULL << VHOST_F_LOG_ALL)) ||
481         !dev->log_table || !length) {
482         return;
483     }
484 
485     assert(dev->log_size > ((address + length - 1) / VHOST_LOG_PAGE / 8));
486 
487     page = address / VHOST_LOG_PAGE;
488     while (page * VHOST_LOG_PAGE < address + length) {
489         vu_log_page(dev->log_table, page);
490         page += 1;
491     }
492 
493     vu_log_kick(dev);
494 }
495 
496 static void
497 vu_kick_cb(VuDev *dev, int condition, void *data)
498 {
499     int index = (intptr_t)data;
500     VuVirtq *vq = &dev->vq[index];
501     int sock = vq->kick_fd;
502     eventfd_t kick_data;
503     ssize_t rc;
504 
505     rc = eventfd_read(sock, &kick_data);
506     if (rc == -1) {
507         vu_panic(dev, "kick eventfd_read(): %s", strerror(errno));
508         dev->remove_watch(dev, dev->vq[index].kick_fd);
509     } else {
510         DPRINT("Got kick_data: %016"PRIx64" handler:%p idx:%d\n",
511                kick_data, vq->handler, index);
512         if (vq->handler) {
513             vq->handler(dev, index);
514         }
515     }
516 }
517 
518 static bool
519 vu_get_features_exec(VuDev *dev, VhostUserMsg *vmsg)
520 {
521     vmsg->payload.u64 =
522         /*
523          * The following VIRTIO feature bits are supported by our virtqueue
524          * implementation:
525          */
526         1ULL << VIRTIO_F_NOTIFY_ON_EMPTY |
527         1ULL << VIRTIO_RING_F_INDIRECT_DESC |
528         1ULL << VIRTIO_RING_F_EVENT_IDX |
529         1ULL << VIRTIO_F_VERSION_1 |
530 
531         /* vhost-user feature bits */
532         1ULL << VHOST_F_LOG_ALL |
533         1ULL << VHOST_USER_F_PROTOCOL_FEATURES;
534 
535     if (dev->iface->get_features) {
536         vmsg->payload.u64 |= dev->iface->get_features(dev);
537     }
538 
539     vmsg->size = sizeof(vmsg->payload.u64);
540     vmsg->fd_num = 0;
541 
542     DPRINT("Sending back to guest u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
543 
544     return true;
545 }
546 
547 static void
548 vu_set_enable_all_rings(VuDev *dev, bool enabled)
549 {
550     uint16_t i;
551 
552     for (i = 0; i < dev->max_queues; i++) {
553         dev->vq[i].enable = enabled;
554     }
555 }
556 
557 static bool
558 vu_set_features_exec(VuDev *dev, VhostUserMsg *vmsg)
559 {
560     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
561 
562     dev->features = vmsg->payload.u64;
563     if (!vu_has_feature(dev, VIRTIO_F_VERSION_1)) {
564         /*
565          * We only support devices conforming to VIRTIO 1.0 or
566          * later
567          */
568         vu_panic(dev, "virtio legacy devices aren't supported by libvhost-user");
569         return false;
570     }
571 
572     if (!(dev->features & VHOST_USER_F_PROTOCOL_FEATURES)) {
573         vu_set_enable_all_rings(dev, true);
574     }
575 
576     if (dev->iface->set_features) {
577         dev->iface->set_features(dev, dev->features);
578     }
579 
580     return false;
581 }
582 
583 static bool
584 vu_set_owner_exec(VuDev *dev, VhostUserMsg *vmsg)
585 {
586     return false;
587 }
588 
589 static void
590 vu_close_log(VuDev *dev)
591 {
592     if (dev->log_table) {
593         if (munmap(dev->log_table, dev->log_size) != 0) {
594             perror("close log munmap() error");
595         }
596 
597         dev->log_table = NULL;
598     }
599     if (dev->log_call_fd != -1) {
600         close(dev->log_call_fd);
601         dev->log_call_fd = -1;
602     }
603 }
604 
605 static bool
606 vu_reset_device_exec(VuDev *dev, VhostUserMsg *vmsg)
607 {
608     vu_set_enable_all_rings(dev, false);
609 
610     return false;
611 }
612 
613 static bool
614 map_ring(VuDev *dev, VuVirtq *vq)
615 {
616     vq->vring.desc = qva_to_va(dev, vq->vra.desc_user_addr);
617     vq->vring.used = qva_to_va(dev, vq->vra.used_user_addr);
618     vq->vring.avail = qva_to_va(dev, vq->vra.avail_user_addr);
619 
620     DPRINT("Setting virtq addresses:\n");
621     DPRINT("    vring_desc  at %p\n", vq->vring.desc);
622     DPRINT("    vring_used  at %p\n", vq->vring.used);
623     DPRINT("    vring_avail at %p\n", vq->vring.avail);
624 
625     return !(vq->vring.desc && vq->vring.used && vq->vring.avail);
626 }
627 
628 static bool
629 generate_faults(VuDev *dev) {
630     unsigned int i;
631     for (i = 0; i < dev->nregions; i++) {
632         VuDevRegion *dev_region = &dev->regions[i];
633         int ret;
634 #ifdef UFFDIO_REGISTER
635         struct uffdio_register reg_struct;
636 
637         /*
638          * We should already have an open ufd. Mark each memory
639          * range as ufd.
640          * Discard any mapping we have here; note I can't use MADV_REMOVE
641          * or fallocate to make the hole since I don't want to lose
642          * data that's already arrived in the shared process.
643          * TODO: How to do hugepage
644          */
645         ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
646                       dev_region->size + dev_region->mmap_offset,
647                       MADV_DONTNEED);
648         if (ret) {
649             fprintf(stderr,
650                     "%s: Failed to madvise(DONTNEED) region %d: %s\n",
651                     __func__, i, strerror(errno));
652         }
653         /*
654          * Turn off transparent hugepages so we dont get lose wakeups
655          * in neighbouring pages.
656          * TODO: Turn this backon later.
657          */
658         ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
659                       dev_region->size + dev_region->mmap_offset,
660                       MADV_NOHUGEPAGE);
661         if (ret) {
662             /*
663              * Note: This can happen legally on kernels that are configured
664              * without madvise'able hugepages
665              */
666             fprintf(stderr,
667                     "%s: Failed to madvise(NOHUGEPAGE) region %d: %s\n",
668                     __func__, i, strerror(errno));
669         }
670 
671         reg_struct.range.start = (uintptr_t)dev_region->mmap_addr;
672         reg_struct.range.len = dev_region->size + dev_region->mmap_offset;
673         reg_struct.mode = UFFDIO_REGISTER_MODE_MISSING;
674 
675         if (ioctl(dev->postcopy_ufd, UFFDIO_REGISTER, &reg_struct)) {
676             vu_panic(dev, "%s: Failed to userfault region %d "
677                           "@%" PRIx64 " + size:%" PRIx64 " offset: %" PRIx64
678                           ": (ufd=%d)%s\n",
679                      __func__, i,
680                      dev_region->mmap_addr,
681                      dev_region->size, dev_region->mmap_offset,
682                      dev->postcopy_ufd, strerror(errno));
683             return false;
684         }
685         if (!(reg_struct.ioctls & ((__u64)1 << _UFFDIO_COPY))) {
686             vu_panic(dev, "%s Region (%d) doesn't support COPY",
687                      __func__, i);
688             return false;
689         }
690         DPRINT("%s: region %d: Registered userfault for %"
691                PRIx64 " + %" PRIx64 "\n", __func__, i,
692                (uint64_t)reg_struct.range.start,
693                (uint64_t)reg_struct.range.len);
694         /* Now it's registered we can let the client at it */
695         if (mprotect((void *)(uintptr_t)dev_region->mmap_addr,
696                      dev_region->size + dev_region->mmap_offset,
697                      PROT_READ | PROT_WRITE)) {
698             vu_panic(dev, "failed to mprotect region %d for postcopy (%s)",
699                      i, strerror(errno));
700             return false;
701         }
702         /* TODO: Stash 'zero' support flags somewhere */
703 #endif
704     }
705 
706     return true;
707 }
708 
709 static bool
710 vu_add_mem_reg(VuDev *dev, VhostUserMsg *vmsg) {
711     int i;
712     bool track_ramblocks = dev->postcopy_listening;
713     VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
714     VuDevRegion *dev_region = &dev->regions[dev->nregions];
715     void *mmap_addr;
716 
717     if (vmsg->fd_num != 1) {
718         vmsg_close_fds(vmsg);
719         vu_panic(dev, "VHOST_USER_ADD_MEM_REG received %d fds - only 1 fd "
720                       "should be sent for this message type", vmsg->fd_num);
721         return false;
722     }
723 
724     if (vmsg->size < VHOST_USER_MEM_REG_SIZE) {
725         close(vmsg->fds[0]);
726         vu_panic(dev, "VHOST_USER_ADD_MEM_REG requires a message size of at "
727                       "least %zu bytes and only %d bytes were received",
728                       VHOST_USER_MEM_REG_SIZE, vmsg->size);
729         return false;
730     }
731 
732     if (dev->nregions == VHOST_USER_MAX_RAM_SLOTS) {
733         close(vmsg->fds[0]);
734         vu_panic(dev, "failing attempt to hot add memory via "
735                       "VHOST_USER_ADD_MEM_REG message because the backend has "
736                       "no free ram slots available");
737         return false;
738     }
739 
740     /*
741      * If we are in postcopy mode and we receive a u64 payload with a 0 value
742      * we know all the postcopy client bases have been received, and we
743      * should start generating faults.
744      */
745     if (track_ramblocks &&
746         vmsg->size == sizeof(vmsg->payload.u64) &&
747         vmsg->payload.u64 == 0) {
748         (void)generate_faults(dev);
749         return false;
750     }
751 
752     DPRINT("Adding region: %u\n", dev->nregions);
753     DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
754            msg_region->guest_phys_addr);
755     DPRINT("    memory_size:     0x%016"PRIx64"\n",
756            msg_region->memory_size);
757     DPRINT("    userspace_addr   0x%016"PRIx64"\n",
758            msg_region->userspace_addr);
759     DPRINT("    mmap_offset      0x%016"PRIx64"\n",
760            msg_region->mmap_offset);
761 
762     dev_region->gpa = msg_region->guest_phys_addr;
763     dev_region->size = msg_region->memory_size;
764     dev_region->qva = msg_region->userspace_addr;
765     dev_region->mmap_offset = msg_region->mmap_offset;
766 
767     /*
768      * We don't use offset argument of mmap() since the
769      * mapped address has to be page aligned, and we use huge
770      * pages.
771      */
772     if (track_ramblocks) {
773         /*
774          * In postcopy we're using PROT_NONE here to catch anyone
775          * accessing it before we userfault.
776          */
777         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
778                          PROT_NONE, MAP_SHARED | MAP_NORESERVE,
779                          vmsg->fds[0], 0);
780     } else {
781         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
782                          PROT_READ | PROT_WRITE, MAP_SHARED | MAP_NORESERVE,
783                          vmsg->fds[0], 0);
784     }
785 
786     if (mmap_addr == MAP_FAILED) {
787         vu_panic(dev, "region mmap error: %s", strerror(errno));
788     } else {
789         dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
790         DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
791                dev_region->mmap_addr);
792     }
793 
794     close(vmsg->fds[0]);
795 
796     if (track_ramblocks) {
797         /*
798          * Return the address to QEMU so that it can translate the ufd
799          * fault addresses back.
800          */
801         msg_region->userspace_addr = (uintptr_t)(mmap_addr +
802                                                  dev_region->mmap_offset);
803 
804         /* Send the message back to qemu with the addresses filled in. */
805         vmsg->fd_num = 0;
806         DPRINT("Successfully added new region in postcopy\n");
807         dev->nregions++;
808         return true;
809     } else {
810         for (i = 0; i < dev->max_queues; i++) {
811             if (dev->vq[i].vring.desc) {
812                 if (map_ring(dev, &dev->vq[i])) {
813                     vu_panic(dev, "remapping queue %d for new memory region",
814                              i);
815                 }
816             }
817         }
818 
819         DPRINT("Successfully added new region\n");
820         dev->nregions++;
821         return false;
822     }
823 }
824 
825 static inline bool reg_equal(VuDevRegion *vudev_reg,
826                              VhostUserMemoryRegion *msg_reg)
827 {
828     if (vudev_reg->gpa == msg_reg->guest_phys_addr &&
829         vudev_reg->qva == msg_reg->userspace_addr &&
830         vudev_reg->size == msg_reg->memory_size) {
831         return true;
832     }
833 
834     return false;
835 }
836 
837 static bool
838 vu_rem_mem_reg(VuDev *dev, VhostUserMsg *vmsg) {
839     VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
840     unsigned int i;
841     bool found = false;
842 
843     if (vmsg->fd_num > 1) {
844         vmsg_close_fds(vmsg);
845         vu_panic(dev, "VHOST_USER_REM_MEM_REG received %d fds - at most 1 fd "
846                       "should be sent for this message type", vmsg->fd_num);
847         return false;
848     }
849 
850     if (vmsg->size < VHOST_USER_MEM_REG_SIZE) {
851         vmsg_close_fds(vmsg);
852         vu_panic(dev, "VHOST_USER_REM_MEM_REG requires a message size of at "
853                       "least %zu bytes and only %d bytes were received",
854                       VHOST_USER_MEM_REG_SIZE, vmsg->size);
855         return false;
856     }
857 
858     DPRINT("Removing region:\n");
859     DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
860            msg_region->guest_phys_addr);
861     DPRINT("    memory_size:     0x%016"PRIx64"\n",
862            msg_region->memory_size);
863     DPRINT("    userspace_addr   0x%016"PRIx64"\n",
864            msg_region->userspace_addr);
865     DPRINT("    mmap_offset      0x%016"PRIx64"\n",
866            msg_region->mmap_offset);
867 
868     for (i = 0; i < dev->nregions; i++) {
869         if (reg_equal(&dev->regions[i], msg_region)) {
870             VuDevRegion *r = &dev->regions[i];
871             void *m = (void *) (uintptr_t) r->mmap_addr;
872 
873             if (m) {
874                 munmap(m, r->size + r->mmap_offset);
875             }
876 
877             /*
878              * Shift all affected entries by 1 to close the hole at index i and
879              * zero out the last entry.
880              */
881             memmove(dev->regions + i, dev->regions + i + 1,
882                     sizeof(VuDevRegion) * (dev->nregions - i - 1));
883             memset(dev->regions + dev->nregions - 1, 0, sizeof(VuDevRegion));
884             DPRINT("Successfully removed a region\n");
885             dev->nregions--;
886             i--;
887 
888             found = true;
889 
890             /* Continue the search for eventual duplicates. */
891         }
892     }
893 
894     if (!found) {
895         vu_panic(dev, "Specified region not found\n");
896     }
897 
898     vmsg_close_fds(vmsg);
899 
900     return false;
901 }
902 
903 static bool
904 vu_set_mem_table_exec_postcopy(VuDev *dev, VhostUserMsg *vmsg)
905 {
906     unsigned int i;
907     VhostUserMemory m = vmsg->payload.memory, *memory = &m;
908     dev->nregions = memory->nregions;
909 
910     DPRINT("Nregions: %u\n", memory->nregions);
911     for (i = 0; i < dev->nregions; i++) {
912         void *mmap_addr;
913         VhostUserMemoryRegion *msg_region = &memory->regions[i];
914         VuDevRegion *dev_region = &dev->regions[i];
915 
916         DPRINT("Region %d\n", i);
917         DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
918                msg_region->guest_phys_addr);
919         DPRINT("    memory_size:     0x%016"PRIx64"\n",
920                msg_region->memory_size);
921         DPRINT("    userspace_addr   0x%016"PRIx64"\n",
922                msg_region->userspace_addr);
923         DPRINT("    mmap_offset      0x%016"PRIx64"\n",
924                msg_region->mmap_offset);
925 
926         dev_region->gpa = msg_region->guest_phys_addr;
927         dev_region->size = msg_region->memory_size;
928         dev_region->qva = msg_region->userspace_addr;
929         dev_region->mmap_offset = msg_region->mmap_offset;
930 
931         /* We don't use offset argument of mmap() since the
932          * mapped address has to be page aligned, and we use huge
933          * pages.
934          * In postcopy we're using PROT_NONE here to catch anyone
935          * accessing it before we userfault
936          */
937         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
938                          PROT_NONE, MAP_SHARED | MAP_NORESERVE,
939                          vmsg->fds[i], 0);
940 
941         if (mmap_addr == MAP_FAILED) {
942             vu_panic(dev, "region mmap error: %s", strerror(errno));
943         } else {
944             dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
945             DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
946                    dev_region->mmap_addr);
947         }
948 
949         /* Return the address to QEMU so that it can translate the ufd
950          * fault addresses back.
951          */
952         msg_region->userspace_addr = (uintptr_t)(mmap_addr +
953                                                  dev_region->mmap_offset);
954         close(vmsg->fds[i]);
955     }
956 
957     /* Send the message back to qemu with the addresses filled in */
958     vmsg->fd_num = 0;
959     if (!vu_send_reply(dev, dev->sock, vmsg)) {
960         vu_panic(dev, "failed to respond to set-mem-table for postcopy");
961         return false;
962     }
963 
964     /* Wait for QEMU to confirm that it's registered the handler for the
965      * faults.
966      */
967     if (!dev->read_msg(dev, dev->sock, vmsg) ||
968         vmsg->size != sizeof(vmsg->payload.u64) ||
969         vmsg->payload.u64 != 0) {
970         vu_panic(dev, "failed to receive valid ack for postcopy set-mem-table");
971         return false;
972     }
973 
974     /* OK, now we can go and register the memory and generate faults */
975     (void)generate_faults(dev);
976 
977     return false;
978 }
979 
980 static bool
981 vu_set_mem_table_exec(VuDev *dev, VhostUserMsg *vmsg)
982 {
983     unsigned int i;
984     VhostUserMemory m = vmsg->payload.memory, *memory = &m;
985 
986     for (i = 0; i < dev->nregions; i++) {
987         VuDevRegion *r = &dev->regions[i];
988         void *m = (void *) (uintptr_t) r->mmap_addr;
989 
990         if (m) {
991             munmap(m, r->size + r->mmap_offset);
992         }
993     }
994     dev->nregions = memory->nregions;
995 
996     if (dev->postcopy_listening) {
997         return vu_set_mem_table_exec_postcopy(dev, vmsg);
998     }
999 
1000     DPRINT("Nregions: %u\n", memory->nregions);
1001     for (i = 0; i < dev->nregions; i++) {
1002         void *mmap_addr;
1003         VhostUserMemoryRegion *msg_region = &memory->regions[i];
1004         VuDevRegion *dev_region = &dev->regions[i];
1005 
1006         DPRINT("Region %d\n", i);
1007         DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
1008                msg_region->guest_phys_addr);
1009         DPRINT("    memory_size:     0x%016"PRIx64"\n",
1010                msg_region->memory_size);
1011         DPRINT("    userspace_addr   0x%016"PRIx64"\n",
1012                msg_region->userspace_addr);
1013         DPRINT("    mmap_offset      0x%016"PRIx64"\n",
1014                msg_region->mmap_offset);
1015 
1016         dev_region->gpa = msg_region->guest_phys_addr;
1017         dev_region->size = msg_region->memory_size;
1018         dev_region->qva = msg_region->userspace_addr;
1019         dev_region->mmap_offset = msg_region->mmap_offset;
1020 
1021         /* We don't use offset argument of mmap() since the
1022          * mapped address has to be page aligned, and we use huge
1023          * pages.  */
1024         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
1025                          PROT_READ | PROT_WRITE, MAP_SHARED | MAP_NORESERVE,
1026                          vmsg->fds[i], 0);
1027 
1028         if (mmap_addr == MAP_FAILED) {
1029             vu_panic(dev, "region mmap error: %s", strerror(errno));
1030         } else {
1031             dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
1032             DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
1033                    dev_region->mmap_addr);
1034         }
1035 
1036         close(vmsg->fds[i]);
1037     }
1038 
1039     for (i = 0; i < dev->max_queues; i++) {
1040         if (dev->vq[i].vring.desc) {
1041             if (map_ring(dev, &dev->vq[i])) {
1042                 vu_panic(dev, "remapping queue %d during setmemtable", i);
1043             }
1044         }
1045     }
1046 
1047     return false;
1048 }
1049 
1050 static bool
1051 vu_set_log_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1052 {
1053     int fd;
1054     uint64_t log_mmap_size, log_mmap_offset;
1055     void *rc;
1056 
1057     if (vmsg->fd_num != 1 ||
1058         vmsg->size != sizeof(vmsg->payload.log)) {
1059         vu_panic(dev, "Invalid log_base message");
1060         return true;
1061     }
1062 
1063     fd = vmsg->fds[0];
1064     log_mmap_offset = vmsg->payload.log.mmap_offset;
1065     log_mmap_size = vmsg->payload.log.mmap_size;
1066     DPRINT("Log mmap_offset: %"PRId64"\n", log_mmap_offset);
1067     DPRINT("Log mmap_size:   %"PRId64"\n", log_mmap_size);
1068 
1069     rc = mmap(0, log_mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd,
1070               log_mmap_offset);
1071     close(fd);
1072     if (rc == MAP_FAILED) {
1073         perror("log mmap error");
1074     }
1075 
1076     if (dev->log_table) {
1077         munmap(dev->log_table, dev->log_size);
1078     }
1079     dev->log_table = rc;
1080     dev->log_size = log_mmap_size;
1081 
1082     vmsg->size = sizeof(vmsg->payload.u64);
1083     vmsg->fd_num = 0;
1084 
1085     return true;
1086 }
1087 
1088 static bool
1089 vu_set_log_fd_exec(VuDev *dev, VhostUserMsg *vmsg)
1090 {
1091     if (vmsg->fd_num != 1) {
1092         vu_panic(dev, "Invalid log_fd message");
1093         return false;
1094     }
1095 
1096     if (dev->log_call_fd != -1) {
1097         close(dev->log_call_fd);
1098     }
1099     dev->log_call_fd = vmsg->fds[0];
1100     DPRINT("Got log_call_fd: %d\n", vmsg->fds[0]);
1101 
1102     return false;
1103 }
1104 
1105 static bool
1106 vu_set_vring_num_exec(VuDev *dev, VhostUserMsg *vmsg)
1107 {
1108     unsigned int index = vmsg->payload.state.index;
1109     unsigned int num = vmsg->payload.state.num;
1110 
1111     DPRINT("State.index: %u\n", index);
1112     DPRINT("State.num:   %u\n", num);
1113     dev->vq[index].vring.num = num;
1114 
1115     return false;
1116 }
1117 
1118 static bool
1119 vu_set_vring_addr_exec(VuDev *dev, VhostUserMsg *vmsg)
1120 {
1121     struct vhost_vring_addr addr = vmsg->payload.addr, *vra = &addr;
1122     unsigned int index = vra->index;
1123     VuVirtq *vq = &dev->vq[index];
1124 
1125     DPRINT("vhost_vring_addr:\n");
1126     DPRINT("    index:  %d\n", vra->index);
1127     DPRINT("    flags:  %d\n", vra->flags);
1128     DPRINT("    desc_user_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->desc_user_addr);
1129     DPRINT("    used_user_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->used_user_addr);
1130     DPRINT("    avail_user_addr:  0x%016" PRIx64 "\n", (uint64_t)vra->avail_user_addr);
1131     DPRINT("    log_guest_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->log_guest_addr);
1132 
1133     vq->vra = *vra;
1134     vq->vring.flags = vra->flags;
1135     vq->vring.log_guest_addr = vra->log_guest_addr;
1136 
1137 
1138     if (map_ring(dev, vq)) {
1139         vu_panic(dev, "Invalid vring_addr message");
1140         return false;
1141     }
1142 
1143     vq->used_idx = le16toh(vq->vring.used->idx);
1144 
1145     if (vq->last_avail_idx != vq->used_idx) {
1146         bool resume = dev->iface->queue_is_processed_in_order &&
1147             dev->iface->queue_is_processed_in_order(dev, index);
1148 
1149         DPRINT("Last avail index != used index: %u != %u%s\n",
1150                vq->last_avail_idx, vq->used_idx,
1151                resume ? ", resuming" : "");
1152 
1153         if (resume) {
1154             vq->shadow_avail_idx = vq->last_avail_idx = vq->used_idx;
1155         }
1156     }
1157 
1158     return false;
1159 }
1160 
1161 static bool
1162 vu_set_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1163 {
1164     unsigned int index = vmsg->payload.state.index;
1165     unsigned int num = vmsg->payload.state.num;
1166 
1167     DPRINT("State.index: %u\n", index);
1168     DPRINT("State.num:   %u\n", num);
1169     dev->vq[index].shadow_avail_idx = dev->vq[index].last_avail_idx = num;
1170 
1171     return false;
1172 }
1173 
1174 static bool
1175 vu_get_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1176 {
1177     unsigned int index = vmsg->payload.state.index;
1178 
1179     DPRINT("State.index: %u\n", index);
1180     vmsg->payload.state.num = dev->vq[index].last_avail_idx;
1181     vmsg->size = sizeof(vmsg->payload.state);
1182 
1183     dev->vq[index].started = false;
1184     if (dev->iface->queue_set_started) {
1185         dev->iface->queue_set_started(dev, index, false);
1186     }
1187 
1188     if (dev->vq[index].call_fd != -1) {
1189         close(dev->vq[index].call_fd);
1190         dev->vq[index].call_fd = -1;
1191     }
1192     if (dev->vq[index].kick_fd != -1) {
1193         dev->remove_watch(dev, dev->vq[index].kick_fd);
1194         close(dev->vq[index].kick_fd);
1195         dev->vq[index].kick_fd = -1;
1196     }
1197 
1198     return true;
1199 }
1200 
1201 static bool
1202 vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
1203 {
1204     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1205     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1206 
1207     if (index >= dev->max_queues) {
1208         vmsg_close_fds(vmsg);
1209         vu_panic(dev, "Invalid queue index: %u", index);
1210         return false;
1211     }
1212 
1213     if (nofd) {
1214         vmsg_close_fds(vmsg);
1215         return true;
1216     }
1217 
1218     if (vmsg->fd_num != 1) {
1219         vmsg_close_fds(vmsg);
1220         vu_panic(dev, "Invalid fds in request: %d", vmsg->request);
1221         return false;
1222     }
1223 
1224     return true;
1225 }
1226 
1227 static int
1228 inflight_desc_compare(const void *a, const void *b)
1229 {
1230     VuVirtqInflightDesc *desc0 = (VuVirtqInflightDesc *)a,
1231                         *desc1 = (VuVirtqInflightDesc *)b;
1232 
1233     if (desc1->counter > desc0->counter &&
1234         (desc1->counter - desc0->counter) < VIRTQUEUE_MAX_SIZE * 2) {
1235         return 1;
1236     }
1237 
1238     return -1;
1239 }
1240 
1241 static int
1242 vu_check_queue_inflights(VuDev *dev, VuVirtq *vq)
1243 {
1244     int i = 0;
1245 
1246     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
1247         return 0;
1248     }
1249 
1250     if (unlikely(!vq->inflight)) {
1251         return -1;
1252     }
1253 
1254     if (unlikely(!vq->inflight->version)) {
1255         /* initialize the buffer */
1256         vq->inflight->version = INFLIGHT_VERSION;
1257         return 0;
1258     }
1259 
1260     vq->used_idx = le16toh(vq->vring.used->idx);
1261     vq->resubmit_num = 0;
1262     vq->resubmit_list = NULL;
1263     vq->counter = 0;
1264 
1265     if (unlikely(vq->inflight->used_idx != vq->used_idx)) {
1266         vq->inflight->desc[vq->inflight->last_batch_head].inflight = 0;
1267 
1268         barrier();
1269 
1270         vq->inflight->used_idx = vq->used_idx;
1271     }
1272 
1273     for (i = 0; i < vq->inflight->desc_num; i++) {
1274         if (vq->inflight->desc[i].inflight == 1) {
1275             vq->inuse++;
1276         }
1277     }
1278 
1279     vq->shadow_avail_idx = vq->last_avail_idx = vq->inuse + vq->used_idx;
1280 
1281     if (vq->inuse) {
1282         vq->resubmit_list = calloc(vq->inuse, sizeof(VuVirtqInflightDesc));
1283         if (!vq->resubmit_list) {
1284             return -1;
1285         }
1286 
1287         for (i = 0; i < vq->inflight->desc_num; i++) {
1288             if (vq->inflight->desc[i].inflight) {
1289                 vq->resubmit_list[vq->resubmit_num].index = i;
1290                 vq->resubmit_list[vq->resubmit_num].counter =
1291                                         vq->inflight->desc[i].counter;
1292                 vq->resubmit_num++;
1293             }
1294         }
1295 
1296         if (vq->resubmit_num > 1) {
1297             qsort(vq->resubmit_list, vq->resubmit_num,
1298                   sizeof(VuVirtqInflightDesc), inflight_desc_compare);
1299         }
1300         vq->counter = vq->resubmit_list[0].counter + 1;
1301     }
1302 
1303     /* in case of I/O hang after reconnecting */
1304     if (eventfd_write(vq->kick_fd, 1)) {
1305         return -1;
1306     }
1307 
1308     return 0;
1309 }
1310 
1311 static bool
1312 vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
1313 {
1314     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1315     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1316 
1317     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1318 
1319     if (!vu_check_queue_msg_file(dev, vmsg)) {
1320         return false;
1321     }
1322 
1323     if (dev->vq[index].kick_fd != -1) {
1324         dev->remove_watch(dev, dev->vq[index].kick_fd);
1325         close(dev->vq[index].kick_fd);
1326         dev->vq[index].kick_fd = -1;
1327     }
1328 
1329     dev->vq[index].kick_fd = nofd ? -1 : vmsg->fds[0];
1330     DPRINT("Got kick_fd: %d for vq: %d\n", dev->vq[index].kick_fd, index);
1331 
1332     dev->vq[index].started = true;
1333     if (dev->iface->queue_set_started) {
1334         dev->iface->queue_set_started(dev, index, true);
1335     }
1336 
1337     if (dev->vq[index].kick_fd != -1 && dev->vq[index].handler) {
1338         dev->set_watch(dev, dev->vq[index].kick_fd, VU_WATCH_IN,
1339                        vu_kick_cb, (void *)(long)index);
1340 
1341         DPRINT("Waiting for kicks on fd: %d for vq: %d\n",
1342                dev->vq[index].kick_fd, index);
1343     }
1344 
1345     if (vu_check_queue_inflights(dev, &dev->vq[index])) {
1346         vu_panic(dev, "Failed to check inflights for vq: %d\n", index);
1347     }
1348 
1349     return false;
1350 }
1351 
1352 void vu_set_queue_handler(VuDev *dev, VuVirtq *vq,
1353                           vu_queue_handler_cb handler)
1354 {
1355     int qidx = vq - dev->vq;
1356 
1357     vq->handler = handler;
1358     if (vq->kick_fd >= 0) {
1359         if (handler) {
1360             dev->set_watch(dev, vq->kick_fd, VU_WATCH_IN,
1361                            vu_kick_cb, (void *)(long)qidx);
1362         } else {
1363             dev->remove_watch(dev, vq->kick_fd);
1364         }
1365     }
1366 }
1367 
1368 bool vu_set_queue_host_notifier(VuDev *dev, VuVirtq *vq, int fd,
1369                                 int size, int offset)
1370 {
1371     int qidx = vq - dev->vq;
1372     int fd_num = 0;
1373     VhostUserMsg vmsg = {
1374         .request = VHOST_USER_BACKEND_VRING_HOST_NOTIFIER_MSG,
1375         .flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
1376         .size = sizeof(vmsg.payload.area),
1377         .payload.area = {
1378             .u64 = qidx & VHOST_USER_VRING_IDX_MASK,
1379             .size = size,
1380             .offset = offset,
1381         },
1382     };
1383 
1384     if (fd == -1) {
1385         vmsg.payload.area.u64 |= VHOST_USER_VRING_NOFD_MASK;
1386     } else {
1387         vmsg.fds[fd_num++] = fd;
1388     }
1389 
1390     vmsg.fd_num = fd_num;
1391 
1392     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_BACKEND_SEND_FD)) {
1393         return false;
1394     }
1395 
1396     pthread_mutex_lock(&dev->backend_mutex);
1397     if (!vu_message_write(dev, dev->backend_fd, &vmsg)) {
1398         pthread_mutex_unlock(&dev->backend_mutex);
1399         return false;
1400     }
1401 
1402     /* Also unlocks the backend_mutex */
1403     return vu_process_message_reply(dev, &vmsg);
1404 }
1405 
1406 static bool
1407 vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
1408 {
1409     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1410     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1411 
1412     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1413 
1414     if (!vu_check_queue_msg_file(dev, vmsg)) {
1415         return false;
1416     }
1417 
1418     if (dev->vq[index].call_fd != -1) {
1419         close(dev->vq[index].call_fd);
1420         dev->vq[index].call_fd = -1;
1421     }
1422 
1423     dev->vq[index].call_fd = nofd ? -1 : vmsg->fds[0];
1424 
1425     /* in case of I/O hang after reconnecting */
1426     if (dev->vq[index].call_fd != -1 && eventfd_write(vmsg->fds[0], 1)) {
1427         return -1;
1428     }
1429 
1430     DPRINT("Got call_fd: %d for vq: %d\n", dev->vq[index].call_fd, index);
1431 
1432     return false;
1433 }
1434 
1435 static bool
1436 vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
1437 {
1438     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1439     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1440 
1441     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1442 
1443     if (!vu_check_queue_msg_file(dev, vmsg)) {
1444         return false;
1445     }
1446 
1447     if (dev->vq[index].err_fd != -1) {
1448         close(dev->vq[index].err_fd);
1449         dev->vq[index].err_fd = -1;
1450     }
1451 
1452     dev->vq[index].err_fd = nofd ? -1 : vmsg->fds[0];
1453 
1454     return false;
1455 }
1456 
1457 static bool
1458 vu_get_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
1459 {
1460     /*
1461      * Note that we support, but intentionally do not set,
1462      * VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS. This means that
1463      * a device implementation can return it in its callback
1464      * (get_protocol_features) if it wants to use this for
1465      * simulation, but it is otherwise not desirable (if even
1466      * implemented by the frontend.)
1467      */
1468     uint64_t features = 1ULL << VHOST_USER_PROTOCOL_F_MQ |
1469                         1ULL << VHOST_USER_PROTOCOL_F_LOG_SHMFD |
1470                         1ULL << VHOST_USER_PROTOCOL_F_BACKEND_REQ |
1471                         1ULL << VHOST_USER_PROTOCOL_F_HOST_NOTIFIER |
1472                         1ULL << VHOST_USER_PROTOCOL_F_BACKEND_SEND_FD |
1473                         1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK |
1474                         1ULL << VHOST_USER_PROTOCOL_F_CONFIGURE_MEM_SLOTS;
1475 
1476     if (have_userfault()) {
1477         features |= 1ULL << VHOST_USER_PROTOCOL_F_PAGEFAULT;
1478     }
1479 
1480     if (dev->iface->get_config && dev->iface->set_config) {
1481         features |= 1ULL << VHOST_USER_PROTOCOL_F_CONFIG;
1482     }
1483 
1484     if (dev->iface->get_protocol_features) {
1485         features |= dev->iface->get_protocol_features(dev);
1486     }
1487 
1488     vmsg_set_reply_u64(vmsg, features);
1489     return true;
1490 }
1491 
1492 static bool
1493 vu_set_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
1494 {
1495     uint64_t features = vmsg->payload.u64;
1496 
1497     DPRINT("u64: 0x%016"PRIx64"\n", features);
1498 
1499     dev->protocol_features = vmsg->payload.u64;
1500 
1501     if (vu_has_protocol_feature(dev,
1502                                 VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
1503         (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_BACKEND_REQ) ||
1504          !vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_REPLY_ACK))) {
1505         /*
1506          * The use case for using messages for kick/call is simulation, to make
1507          * the kick and call synchronous. To actually get that behaviour, both
1508          * of the other features are required.
1509          * Theoretically, one could use only kick messages, or do them without
1510          * having F_REPLY_ACK, but too many (possibly pending) messages on the
1511          * socket will eventually cause the frontend to hang, to avoid this in
1512          * scenarios where not desired enforce that the settings are in a way
1513          * that actually enables the simulation case.
1514          */
1515         vu_panic(dev,
1516                  "F_IN_BAND_NOTIFICATIONS requires F_BACKEND_REQ && F_REPLY_ACK");
1517         return false;
1518     }
1519 
1520     if (dev->iface->set_protocol_features) {
1521         dev->iface->set_protocol_features(dev, features);
1522     }
1523 
1524     return false;
1525 }
1526 
1527 static bool
1528 vu_get_queue_num_exec(VuDev *dev, VhostUserMsg *vmsg)
1529 {
1530     vmsg_set_reply_u64(vmsg, dev->max_queues);
1531     return true;
1532 }
1533 
1534 static bool
1535 vu_set_vring_enable_exec(VuDev *dev, VhostUserMsg *vmsg)
1536 {
1537     unsigned int index = vmsg->payload.state.index;
1538     unsigned int enable = vmsg->payload.state.num;
1539 
1540     DPRINT("State.index: %u\n", index);
1541     DPRINT("State.enable:   %u\n", enable);
1542 
1543     if (index >= dev->max_queues) {
1544         vu_panic(dev, "Invalid vring_enable index: %u", index);
1545         return false;
1546     }
1547 
1548     dev->vq[index].enable = enable;
1549     return false;
1550 }
1551 
1552 static bool
1553 vu_set_backend_req_fd(VuDev *dev, VhostUserMsg *vmsg)
1554 {
1555     if (vmsg->fd_num != 1) {
1556         vu_panic(dev, "Invalid backend_req_fd message (%d fd's)", vmsg->fd_num);
1557         return false;
1558     }
1559 
1560     if (dev->backend_fd != -1) {
1561         close(dev->backend_fd);
1562     }
1563     dev->backend_fd = vmsg->fds[0];
1564     DPRINT("Got backend_fd: %d\n", vmsg->fds[0]);
1565 
1566     return false;
1567 }
1568 
1569 static bool
1570 vu_get_config(VuDev *dev, VhostUserMsg *vmsg)
1571 {
1572     int ret = -1;
1573 
1574     if (dev->iface->get_config) {
1575         ret = dev->iface->get_config(dev, vmsg->payload.config.region,
1576                                      vmsg->payload.config.size);
1577     }
1578 
1579     if (ret) {
1580         /* resize to zero to indicate an error to frontend */
1581         vmsg->size = 0;
1582     }
1583 
1584     return true;
1585 }
1586 
1587 static bool
1588 vu_set_config(VuDev *dev, VhostUserMsg *vmsg)
1589 {
1590     int ret = -1;
1591 
1592     if (dev->iface->set_config) {
1593         ret = dev->iface->set_config(dev, vmsg->payload.config.region,
1594                                      vmsg->payload.config.offset,
1595                                      vmsg->payload.config.size,
1596                                      vmsg->payload.config.flags);
1597         if (ret) {
1598             vu_panic(dev, "Set virtio configuration space failed");
1599         }
1600     }
1601 
1602     return false;
1603 }
1604 
1605 static bool
1606 vu_set_postcopy_advise(VuDev *dev, VhostUserMsg *vmsg)
1607 {
1608 #ifdef UFFDIO_API
1609     struct uffdio_api api_struct;
1610 
1611     dev->postcopy_ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
1612     vmsg->size = 0;
1613 #else
1614     dev->postcopy_ufd = -1;
1615 #endif
1616 
1617     if (dev->postcopy_ufd == -1) {
1618         vu_panic(dev, "Userfaultfd not available: %s", strerror(errno));
1619         goto out;
1620     }
1621 
1622 #ifdef UFFDIO_API
1623     api_struct.api = UFFD_API;
1624     api_struct.features = 0;
1625     if (ioctl(dev->postcopy_ufd, UFFDIO_API, &api_struct)) {
1626         vu_panic(dev, "Failed UFFDIO_API: %s", strerror(errno));
1627         close(dev->postcopy_ufd);
1628         dev->postcopy_ufd = -1;
1629         goto out;
1630     }
1631     /* TODO: Stash feature flags somewhere */
1632 #endif
1633 
1634 out:
1635     /* Return a ufd to the QEMU */
1636     vmsg->fd_num = 1;
1637     vmsg->fds[0] = dev->postcopy_ufd;
1638     return true; /* = send a reply */
1639 }
1640 
1641 static bool
1642 vu_set_postcopy_listen(VuDev *dev, VhostUserMsg *vmsg)
1643 {
1644     if (dev->nregions) {
1645         vu_panic(dev, "Regions already registered at postcopy-listen");
1646         vmsg_set_reply_u64(vmsg, -1);
1647         return true;
1648     }
1649     dev->postcopy_listening = true;
1650 
1651     vmsg_set_reply_u64(vmsg, 0);
1652     return true;
1653 }
1654 
1655 static bool
1656 vu_set_postcopy_end(VuDev *dev, VhostUserMsg *vmsg)
1657 {
1658     DPRINT("%s: Entry\n", __func__);
1659     dev->postcopy_listening = false;
1660     if (dev->postcopy_ufd > 0) {
1661         close(dev->postcopy_ufd);
1662         dev->postcopy_ufd = -1;
1663         DPRINT("%s: Done close\n", __func__);
1664     }
1665 
1666     vmsg_set_reply_u64(vmsg, 0);
1667     DPRINT("%s: exit\n", __func__);
1668     return true;
1669 }
1670 
1671 static inline uint64_t
1672 vu_inflight_queue_size(uint16_t queue_size)
1673 {
1674     return ALIGN_UP(sizeof(VuDescStateSplit) * queue_size +
1675            sizeof(uint16_t), INFLIGHT_ALIGNMENT);
1676 }
1677 
#ifdef MFD_ALLOW_SEALING
/*
 * Create a sealable memfd called @name, grow it to @size bytes, apply the
 * seals in @flags (F_ADD_SEALS) and map it shared read/write.
 *
 * On success returns the mapping and stores the memfd in *fd.
 * On failure returns NULL and sets *fd to -1.  Any memfd that was created
 * is closed, so the caller is never left holding a stale descriptor number.
 */
static void *
memfd_alloc(const char *name, size_t size, unsigned int flags, int *fd)
{
    void *ptr;

    *fd = memfd_create(name, MFD_ALLOW_SEALING);
    if (*fd < 0) {
        *fd = -1;
        return NULL;
    }

    if (ftruncate(*fd, size) < 0) {
        goto err_close;
    }

    if (fcntl(*fd, F_ADD_SEALS, flags) < 0) {
        goto err_close;
    }

    ptr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, *fd, 0);
    if (ptr == MAP_FAILED) {
        goto err_close;
    }

    return ptr;

err_close:
    close(*fd);
    *fd = -1;
    return NULL;
}
#endif
1711 
1712 static bool
1713 vu_get_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
1714 {
1715     int fd = -1;
1716     void *addr = NULL;
1717     uint64_t mmap_size;
1718     uint16_t num_queues, queue_size;
1719 
1720     if (vmsg->size != sizeof(vmsg->payload.inflight)) {
1721         vu_panic(dev, "Invalid get_inflight_fd message:%d", vmsg->size);
1722         vmsg->payload.inflight.mmap_size = 0;
1723         return true;
1724     }
1725 
1726     num_queues = vmsg->payload.inflight.num_queues;
1727     queue_size = vmsg->payload.inflight.queue_size;
1728 
1729     DPRINT("set_inflight_fd num_queues: %"PRId16"\n", num_queues);
1730     DPRINT("set_inflight_fd queue_size: %"PRId16"\n", queue_size);
1731 
1732     mmap_size = vu_inflight_queue_size(queue_size) * num_queues;
1733 
1734 #ifdef MFD_ALLOW_SEALING
1735     addr = memfd_alloc("vhost-inflight", mmap_size,
1736                        F_SEAL_GROW | F_SEAL_SHRINK | F_SEAL_SEAL,
1737                        &fd);
1738 #else
1739     vu_panic(dev, "Not implemented: memfd support is missing");
1740 #endif
1741 
1742     if (!addr) {
1743         vu_panic(dev, "Failed to alloc vhost inflight area");
1744         vmsg->payload.inflight.mmap_size = 0;
1745         return true;
1746     }
1747 
1748     memset(addr, 0, mmap_size);
1749 
1750     dev->inflight_info.addr = addr;
1751     dev->inflight_info.size = vmsg->payload.inflight.mmap_size = mmap_size;
1752     dev->inflight_info.fd = vmsg->fds[0] = fd;
1753     vmsg->fd_num = 1;
1754     vmsg->payload.inflight.mmap_offset = 0;
1755 
1756     DPRINT("send inflight mmap_size: %"PRId64"\n",
1757            vmsg->payload.inflight.mmap_size);
1758     DPRINT("send inflight mmap offset: %"PRId64"\n",
1759            vmsg->payload.inflight.mmap_offset);
1760 
1761     return true;
1762 }
1763 
1764 static bool
1765 vu_set_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
1766 {
1767     int fd, i;
1768     uint64_t mmap_size, mmap_offset;
1769     uint16_t num_queues, queue_size;
1770     void *rc;
1771 
1772     if (vmsg->fd_num != 1 ||
1773         vmsg->size != sizeof(vmsg->payload.inflight)) {
1774         vu_panic(dev, "Invalid set_inflight_fd message size:%d fds:%d",
1775                  vmsg->size, vmsg->fd_num);
1776         return false;
1777     }
1778 
1779     fd = vmsg->fds[0];
1780     mmap_size = vmsg->payload.inflight.mmap_size;
1781     mmap_offset = vmsg->payload.inflight.mmap_offset;
1782     num_queues = vmsg->payload.inflight.num_queues;
1783     queue_size = vmsg->payload.inflight.queue_size;
1784 
1785     DPRINT("set_inflight_fd mmap_size: %"PRId64"\n", mmap_size);
1786     DPRINT("set_inflight_fd mmap_offset: %"PRId64"\n", mmap_offset);
1787     DPRINT("set_inflight_fd num_queues: %"PRId16"\n", num_queues);
1788     DPRINT("set_inflight_fd queue_size: %"PRId16"\n", queue_size);
1789 
1790     rc = mmap(0, mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED,
1791               fd, mmap_offset);
1792 
1793     if (rc == MAP_FAILED) {
1794         vu_panic(dev, "set_inflight_fd mmap error: %s", strerror(errno));
1795         return false;
1796     }
1797 
1798     if (dev->inflight_info.fd) {
1799         close(dev->inflight_info.fd);
1800     }
1801 
1802     if (dev->inflight_info.addr) {
1803         munmap(dev->inflight_info.addr, dev->inflight_info.size);
1804     }
1805 
1806     dev->inflight_info.fd = fd;
1807     dev->inflight_info.addr = rc;
1808     dev->inflight_info.size = mmap_size;
1809 
1810     for (i = 0; i < num_queues; i++) {
1811         dev->vq[i].inflight = (VuVirtqInflight *)rc;
1812         dev->vq[i].inflight->desc_num = queue_size;
1813         rc = (void *)((char *)rc + vu_inflight_queue_size(queue_size));
1814     }
1815 
1816     return false;
1817 }
1818 
1819 static bool
1820 vu_handle_vring_kick(VuDev *dev, VhostUserMsg *vmsg)
1821 {
1822     unsigned int index = vmsg->payload.state.index;
1823 
1824     if (index >= dev->max_queues) {
1825         vu_panic(dev, "Invalid queue index: %u", index);
1826         return false;
1827     }
1828 
1829     DPRINT("Got kick message: handler:%p idx:%u\n",
1830            dev->vq[index].handler, index);
1831 
1832     if (!dev->vq[index].started) {
1833         dev->vq[index].started = true;
1834 
1835         if (dev->iface->queue_set_started) {
1836             dev->iface->queue_set_started(dev, index, true);
1837         }
1838     }
1839 
1840     if (dev->vq[index].handler) {
1841         dev->vq[index].handler(dev, index);
1842     }
1843 
1844     return false;
1845 }
1846 
1847 static bool vu_handle_get_max_memslots(VuDev *dev, VhostUserMsg *vmsg)
1848 {
1849     vmsg_set_reply_u64(vmsg, VHOST_USER_MAX_RAM_SLOTS);
1850 
1851     DPRINT("u64: 0x%016"PRIx64"\n", (uint64_t) VHOST_USER_MAX_RAM_SLOTS);
1852 
1853     return true;
1854 }
1855 
1856 static bool
1857 vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
1858 {
1859     int do_reply = 0;
1860 
1861     /* Print out generic part of the request. */
1862     DPRINT("================ Vhost user message ================\n");
1863     DPRINT("Request: %s (%d)\n", vu_request_to_string(vmsg->request),
1864            vmsg->request);
1865     DPRINT("Flags:   0x%x\n", vmsg->flags);
1866     DPRINT("Size:    %u\n", vmsg->size);
1867 
1868     if (vmsg->fd_num) {
1869         int i;
1870         DPRINT("Fds:");
1871         for (i = 0; i < vmsg->fd_num; i++) {
1872             DPRINT(" %d", vmsg->fds[i]);
1873         }
1874         DPRINT("\n");
1875     }
1876 
1877     if (dev->iface->process_msg &&
1878         dev->iface->process_msg(dev, vmsg, &do_reply)) {
1879         return do_reply;
1880     }
1881 
1882     switch (vmsg->request) {
1883     case VHOST_USER_GET_FEATURES:
1884         return vu_get_features_exec(dev, vmsg);
1885     case VHOST_USER_SET_FEATURES:
1886         return vu_set_features_exec(dev, vmsg);
1887     case VHOST_USER_GET_PROTOCOL_FEATURES:
1888         return vu_get_protocol_features_exec(dev, vmsg);
1889     case VHOST_USER_SET_PROTOCOL_FEATURES:
1890         return vu_set_protocol_features_exec(dev, vmsg);
1891     case VHOST_USER_SET_OWNER:
1892         return vu_set_owner_exec(dev, vmsg);
1893     case VHOST_USER_RESET_OWNER:
1894         return vu_reset_device_exec(dev, vmsg);
1895     case VHOST_USER_SET_MEM_TABLE:
1896         return vu_set_mem_table_exec(dev, vmsg);
1897     case VHOST_USER_SET_LOG_BASE:
1898         return vu_set_log_base_exec(dev, vmsg);
1899     case VHOST_USER_SET_LOG_FD:
1900         return vu_set_log_fd_exec(dev, vmsg);
1901     case VHOST_USER_SET_VRING_NUM:
1902         return vu_set_vring_num_exec(dev, vmsg);
1903     case VHOST_USER_SET_VRING_ADDR:
1904         return vu_set_vring_addr_exec(dev, vmsg);
1905     case VHOST_USER_SET_VRING_BASE:
1906         return vu_set_vring_base_exec(dev, vmsg);
1907     case VHOST_USER_GET_VRING_BASE:
1908         return vu_get_vring_base_exec(dev, vmsg);
1909     case VHOST_USER_SET_VRING_KICK:
1910         return vu_set_vring_kick_exec(dev, vmsg);
1911     case VHOST_USER_SET_VRING_CALL:
1912         return vu_set_vring_call_exec(dev, vmsg);
1913     case VHOST_USER_SET_VRING_ERR:
1914         return vu_set_vring_err_exec(dev, vmsg);
1915     case VHOST_USER_GET_QUEUE_NUM:
1916         return vu_get_queue_num_exec(dev, vmsg);
1917     case VHOST_USER_SET_VRING_ENABLE:
1918         return vu_set_vring_enable_exec(dev, vmsg);
1919     case VHOST_USER_SET_BACKEND_REQ_FD:
1920         return vu_set_backend_req_fd(dev, vmsg);
1921     case VHOST_USER_GET_CONFIG:
1922         return vu_get_config(dev, vmsg);
1923     case VHOST_USER_SET_CONFIG:
1924         return vu_set_config(dev, vmsg);
1925     case VHOST_USER_NONE:
1926         /* if you need processing before exit, override iface->process_msg */
1927         exit(0);
1928     case VHOST_USER_POSTCOPY_ADVISE:
1929         return vu_set_postcopy_advise(dev, vmsg);
1930     case VHOST_USER_POSTCOPY_LISTEN:
1931         return vu_set_postcopy_listen(dev, vmsg);
1932     case VHOST_USER_POSTCOPY_END:
1933         return vu_set_postcopy_end(dev, vmsg);
1934     case VHOST_USER_GET_INFLIGHT_FD:
1935         return vu_get_inflight_fd(dev, vmsg);
1936     case VHOST_USER_SET_INFLIGHT_FD:
1937         return vu_set_inflight_fd(dev, vmsg);
1938     case VHOST_USER_VRING_KICK:
1939         return vu_handle_vring_kick(dev, vmsg);
1940     case VHOST_USER_GET_MAX_MEM_SLOTS:
1941         return vu_handle_get_max_memslots(dev, vmsg);
1942     case VHOST_USER_ADD_MEM_REG:
1943         return vu_add_mem_reg(dev, vmsg);
1944     case VHOST_USER_REM_MEM_REG:
1945         return vu_rem_mem_reg(dev, vmsg);
1946     default:
1947         vmsg_close_fds(vmsg);
1948         vu_panic(dev, "Unhandled request: %d", vmsg->request);
1949     }
1950 
1951     return false;
1952 }
1953 
1954 bool
1955 vu_dispatch(VuDev *dev)
1956 {
1957     VhostUserMsg vmsg = { 0, };
1958     int reply_requested;
1959     bool need_reply, success = false;
1960 
1961     if (!dev->read_msg(dev, dev->sock, &vmsg)) {
1962         goto end;
1963     }
1964 
1965     need_reply = vmsg.flags & VHOST_USER_NEED_REPLY_MASK;
1966 
1967     reply_requested = vu_process_message(dev, &vmsg);
1968     if (!reply_requested && need_reply) {
1969         vmsg_set_reply_u64(&vmsg, 0);
1970         reply_requested = 1;
1971     }
1972 
1973     if (!reply_requested) {
1974         success = true;
1975         goto end;
1976     }
1977 
1978     if (!vu_send_reply(dev, dev->sock, &vmsg)) {
1979         goto end;
1980     }
1981 
1982     success = true;
1983 
1984 end:
1985     free(vmsg.data);
1986     return success;
1987 }
1988 
1989 void
1990 vu_deinit(VuDev *dev)
1991 {
1992     unsigned int i;
1993 
1994     for (i = 0; i < dev->nregions; i++) {
1995         VuDevRegion *r = &dev->regions[i];
1996         void *m = (void *) (uintptr_t) r->mmap_addr;
1997         if (m != MAP_FAILED) {
1998             munmap(m, r->size + r->mmap_offset);
1999         }
2000     }
2001     dev->nregions = 0;
2002 
2003     for (i = 0; i < dev->max_queues; i++) {
2004         VuVirtq *vq = &dev->vq[i];
2005 
2006         if (vq->call_fd != -1) {
2007             close(vq->call_fd);
2008             vq->call_fd = -1;
2009         }
2010 
2011         if (vq->kick_fd != -1) {
2012             dev->remove_watch(dev, vq->kick_fd);
2013             close(vq->kick_fd);
2014             vq->kick_fd = -1;
2015         }
2016 
2017         if (vq->err_fd != -1) {
2018             close(vq->err_fd);
2019             vq->err_fd = -1;
2020         }
2021 
2022         if (vq->resubmit_list) {
2023             free(vq->resubmit_list);
2024             vq->resubmit_list = NULL;
2025         }
2026 
2027         vq->inflight = NULL;
2028     }
2029 
2030     if (dev->inflight_info.addr) {
2031         munmap(dev->inflight_info.addr, dev->inflight_info.size);
2032         dev->inflight_info.addr = NULL;
2033     }
2034 
2035     if (dev->inflight_info.fd > 0) {
2036         close(dev->inflight_info.fd);
2037         dev->inflight_info.fd = -1;
2038     }
2039 
2040     vu_close_log(dev);
2041     if (dev->backend_fd != -1) {
2042         close(dev->backend_fd);
2043         dev->backend_fd = -1;
2044     }
2045     pthread_mutex_destroy(&dev->backend_mutex);
2046 
2047     if (dev->sock != -1) {
2048         close(dev->sock);
2049     }
2050 
2051     free(dev->vq);
2052     dev->vq = NULL;
2053 }
2054 
2055 bool
2056 vu_init(VuDev *dev,
2057         uint16_t max_queues,
2058         int socket,
2059         vu_panic_cb panic,
2060         vu_read_msg_cb read_msg,
2061         vu_set_watch_cb set_watch,
2062         vu_remove_watch_cb remove_watch,
2063         const VuDevIface *iface)
2064 {
2065     uint16_t i;
2066 
2067     assert(max_queues > 0);
2068     assert(socket >= 0);
2069     assert(set_watch);
2070     assert(remove_watch);
2071     assert(iface);
2072     assert(panic);
2073 
2074     memset(dev, 0, sizeof(*dev));
2075 
2076     dev->sock = socket;
2077     dev->panic = panic;
2078     dev->read_msg = read_msg ? read_msg : vu_message_read_default;
2079     dev->set_watch = set_watch;
2080     dev->remove_watch = remove_watch;
2081     dev->iface = iface;
2082     dev->log_call_fd = -1;
2083     pthread_mutex_init(&dev->backend_mutex, NULL);
2084     dev->backend_fd = -1;
2085     dev->max_queues = max_queues;
2086 
2087     dev->vq = malloc(max_queues * sizeof(dev->vq[0]));
2088     if (!dev->vq) {
2089         DPRINT("%s: failed to malloc virtqueues\n", __func__);
2090         return false;
2091     }
2092 
2093     for (i = 0; i < max_queues; i++) {
2094         dev->vq[i] = (VuVirtq) {
2095             .call_fd = -1, .kick_fd = -1, .err_fd = -1,
2096             .notification = true,
2097         };
2098     }
2099 
2100     return true;
2101 }
2102 
2103 VuVirtq *
2104 vu_get_queue(VuDev *dev, int qidx)
2105 {
2106     assert(qidx < dev->max_queues);
2107     return &dev->vq[qidx];
2108 }
2109 
2110 bool
2111 vu_queue_enabled(VuDev *dev, VuVirtq *vq)
2112 {
2113     return vq->enable;
2114 }
2115 
2116 bool
2117 vu_queue_started(const VuDev *dev, const VuVirtq *vq)
2118 {
2119     return vq->started;
2120 }
2121 
2122 static inline uint16_t
2123 vring_avail_flags(VuVirtq *vq)
2124 {
2125     return le16toh(vq->vring.avail->flags);
2126 }
2127 
2128 static inline uint16_t
2129 vring_avail_idx(VuVirtq *vq)
2130 {
2131     vq->shadow_avail_idx = le16toh(vq->vring.avail->idx);
2132 
2133     return vq->shadow_avail_idx;
2134 }
2135 
2136 static inline uint16_t
2137 vring_avail_ring(VuVirtq *vq, int i)
2138 {
2139     return le16toh(vq->vring.avail->ring[i]);
2140 }
2141 
2142 static inline uint16_t
2143 vring_get_used_event(VuVirtq *vq)
2144 {
2145     return vring_avail_ring(vq, vq->vring.num);
2146 }
2147 
2148 static int
2149 virtqueue_num_heads(VuDev *dev, VuVirtq *vq, unsigned int idx)
2150 {
2151     uint16_t num_heads = vring_avail_idx(vq) - idx;
2152 
2153     /* Check it isn't doing very strange things with descriptor numbers. */
2154     if (num_heads > vq->vring.num) {
2155         vu_panic(dev, "Guest moved used index from %u to %u",
2156                  idx, vq->shadow_avail_idx);
2157         return -1;
2158     }
2159     if (num_heads) {
2160         /* On success, callers read a descriptor at vq->last_avail_idx.
2161          * Make sure descriptor read does not bypass avail index read. */
2162         smp_rmb();
2163     }
2164 
2165     return num_heads;
2166 }
2167 
2168 static bool
2169 virtqueue_get_head(VuDev *dev, VuVirtq *vq,
2170                    unsigned int idx, unsigned int *head)
2171 {
2172     /* Grab the next descriptor number they're advertising, and increment
2173      * the index we've seen. */
2174     *head = vring_avail_ring(vq, idx % vq->vring.num);
2175 
2176     /* If their number is silly, that's a fatal mistake. */
2177     if (*head >= vq->vring.num) {
2178         vu_panic(dev, "Guest says index %u is available", *head);
2179         return false;
2180     }
2181 
2182     return true;
2183 }
2184 
2185 static int
2186 virtqueue_read_indirect_desc(VuDev *dev, struct vring_desc *desc,
2187                              uint64_t addr, size_t len)
2188 {
2189     struct vring_desc *ori_desc;
2190     uint64_t read_len;
2191 
2192     if (len > (VIRTQUEUE_MAX_SIZE * sizeof(struct vring_desc))) {
2193         return -1;
2194     }
2195 
2196     if (len == 0) {
2197         return -1;
2198     }
2199 
2200     while (len) {
2201         read_len = len;
2202         ori_desc = vu_gpa_to_va(dev, &read_len, addr);
2203         if (!ori_desc) {
2204             return -1;
2205         }
2206 
2207         memcpy(desc, ori_desc, read_len);
2208         len -= read_len;
2209         addr += read_len;
2210         desc += read_len;
2211     }
2212 
2213     return 0;
2214 }
2215 
/* Result codes of virtqueue_read_next_desc(). */
enum {
    VIRTQUEUE_READ_DESC_MORE = 1,   /* chain continues; *next is valid */
    VIRTQUEUE_READ_DESC_DONE = 0,   /* last descriptor of the chain */
    VIRTQUEUE_READ_DESC_ERROR = -1, /* next index out of range */
};
2221 
2222 static int
2223 virtqueue_read_next_desc(VuDev *dev, struct vring_desc *desc,
2224                          int i, unsigned int max, unsigned int *next)
2225 {
2226     /* If this descriptor says it doesn't chain, we're done. */
2227     if (!(le16toh(desc[i].flags) & VRING_DESC_F_NEXT)) {
2228         return VIRTQUEUE_READ_DESC_DONE;
2229     }
2230 
2231     /* Check they're not leading us off end of descriptors. */
2232     *next = le16toh(desc[i].next);
2233     /* Make sure compiler knows to grab that: we don't want it changing! */
2234     smp_wmb();
2235 
2236     if (*next >= max) {
2237         vu_panic(dev, "Desc next is %u", *next);
2238         return VIRTQUEUE_READ_DESC_ERROR;
2239     }
2240 
2241     return VIRTQUEUE_READ_DESC_MORE;
2242 }
2243 
/*
 * Sum the sizes of the buffers the guest has made available on @vq, without
 * consuming them.  Device-writable (VRING_DESC_F_WRITE) bytes go to
 * *@in_bytes and all other bytes to *@out_bytes.  Either pointer may be NULL.
 *
 * The scan starts from a local copy of vq->last_avail_idx, so the queue is
 * not advanced.  It stops early once both totals have reached @max_in_bytes
 * and @max_out_bytes.
 *
 * Both totals are 0 in these cases:
 *  - the device is broken;
 *  - the ring is not set up;
 *  - a malformed avail index, head or chain is found (reported through
 *    vu_panic()).
 *
 * NOTE(review): the unsigned totals are not checked for overflow.
 */
void
vu_queue_get_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int *in_bytes,
                         unsigned int *out_bytes,
                         unsigned max_in_bytes, unsigned max_out_bytes)
{
    unsigned int idx;
    unsigned int total_bufs, in_total, out_total;
    int rc;

    /* Local cursor: scanning must not move last_avail_idx. */
    idx = vq->last_avail_idx;

    total_bufs = in_total = out_total = 0;
    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        goto done;
    }

    /* One iteration per available chain head. */
    while ((rc = virtqueue_num_heads(dev, vq, idx)) > 0) {
        unsigned int max, desc_len, num_bufs, indirect = 0;
        uint64_t desc_addr, read_len;
        struct vring_desc *desc;
        /* Copy of an indirect table that is not contiguous in our mapping. */
        struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
        unsigned int i;

        max = vq->vring.num;
        /*
         * Direct chains share one budget of vring.num descriptors across
         * all heads scanned so far, so the counter resumes from total_bufs.
         */
        num_bufs = total_bufs;
        if (!virtqueue_get_head(dev, vq, idx++, &i)) {
            goto err;
        }
        desc = vq->vring.desc;

        if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
            if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
                vu_panic(dev, "Invalid size for indirect buffer table");
                goto err;
            }

            /* If we've got too many, that implies a descriptor loop. */
            if (num_bufs >= max) {
                vu_panic(dev, "Looped descriptor");
                goto err;
            }

            /* loop over the indirect descriptor table */
            indirect = 1;
            desc_addr = le64toh(desc[i].addr);
            desc_len = le32toh(desc[i].len);
            max = desc_len / sizeof(struct vring_desc);
            read_len = desc_len;
            desc = vu_gpa_to_va(dev, &read_len, desc_addr);
            if (unlikely(desc && read_len != desc_len)) {
                /* Failed to use zero copy */
                desc = NULL;
                if (!virtqueue_read_indirect_desc(dev, desc_buf,
                                                  desc_addr,
                                                  desc_len)) {
                    desc = desc_buf;
                }
            }
            if (!desc) {
                vu_panic(dev, "Invalid indirect buffer table");
                goto err;
            }
            /* The indirect table has its own budget of max entries. */
            num_bufs = i = 0;
        }

        do {
            /* If we've got too many, that implies a descriptor loop. */
            if (++num_bufs > max) {
                vu_panic(dev, "Looped descriptor");
                goto err;
            }

            if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
                in_total += le32toh(desc[i].len);
            } else {
                out_total += le32toh(desc[i].len);
            }
            /* Stop as soon as both requested limits are reached. */
            if (in_total >= max_in_bytes && out_total >= max_out_bytes) {
                goto done;
            }
            rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
        } while (rc == VIRTQUEUE_READ_DESC_MORE);

        if (rc == VIRTQUEUE_READ_DESC_ERROR) {
            goto err;
        }

        /* An indirect chain uses a single slot of the outer ring budget. */
        if (!indirect) {
            total_bufs = num_bufs;
        } else {
            total_bufs++;
        }
    }
    /* rc < 0: virtqueue_num_heads() rejected the avail index. */
    if (rc < 0) {
        goto err;
    }
done:
    if (in_bytes) {
        *in_bytes = in_total;
    }
    if (out_bytes) {
        *out_bytes = out_total;
    }
    return;

err:
    in_total = out_total = 0;
    goto done;
}
2354 
2355 bool
2356 vu_queue_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int in_bytes,
2357                      unsigned int out_bytes)
2358 {
2359     unsigned int in_total, out_total;
2360 
2361     vu_queue_get_avail_bytes(dev, vq, &in_total, &out_total,
2362                              in_bytes, out_bytes);
2363 
2364     return in_bytes <= in_total && out_bytes <= out_total;
2365 }
2366 
2367 /* Fetch avail_idx from VQ memory only when we really need to know if
2368  * guest has added some buffers. */
2369 bool
2370 vu_queue_empty(VuDev *dev, VuVirtq *vq)
2371 {
2372     if (unlikely(dev->broken) ||
2373         unlikely(!vq->vring.avail)) {
2374         return true;
2375     }
2376 
2377     if (vq->shadow_avail_idx != vq->last_avail_idx) {
2378         return false;
2379     }
2380 
2381     return vring_avail_idx(vq) == vq->last_avail_idx;
2382 }
2383 
/*
 * Decide whether the guest must be notified about new used buffers on @vq.
 *
 * - With VIRTIO_F_NOTIFY_ON_EMPTY negotiated, always notify when nothing is
 *   in flight and the available ring is empty.  The vu_queue_empty() call
 *   may refresh vq->shadow_avail_idx.
 * - Without VIRTIO_RING_F_EVENT_IDX, notify unless the guest set
 *   VRING_AVAIL_F_NO_INTERRUPT.
 * - With VIRTIO_RING_F_EVENT_IDX, vring_need_event() is asked whether the
 *   guest's used_event index lies in the range published since the last
 *   signalled index.  If signalled_used was invalid, the call notifies
 *   unconditionally.  In this case vq->used_idx is recorded as the new
 *   signalled index.
 */
static bool
vring_notify(VuDev *dev, VuVirtq *vq)
{
    uint16_t old, new;
    bool v;

    /* We need to expose used array entries before checking used event. */
    smp_mb();

    /* Always notify when queue is empty (when the feature is negotiated) */
    if (vu_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
        !vq->inuse && vu_queue_empty(dev, vq)) {
        return true;
    }

    if (!vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        return !(vring_avail_flags(vq) & VRING_AVAIL_F_NO_INTERRUPT);
    }

    /* Snapshot the previous signalled state, then record the current one. */
    v = vq->signalled_used_valid;
    vq->signalled_used_valid = true;
    old = vq->signalled_used;
    new = vq->signalled_used = vq->used_idx;
    return !v || vring_need_event(vring_get_used_event(vq), new, old);
}
2409 
2410 static void _vu_queue_notify(VuDev *dev, VuVirtq *vq, bool sync)
2411 {
2412     if (unlikely(dev->broken) ||
2413         unlikely(!vq->vring.avail)) {
2414         return;
2415     }
2416 
2417     if (!vring_notify(dev, vq)) {
2418         DPRINT("skipped notify...\n");
2419         return;
2420     }
2421 
2422     if (vq->call_fd < 0 &&
2423         vu_has_protocol_feature(dev,
2424                                 VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
2425         vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_BACKEND_REQ)) {
2426         VhostUserMsg vmsg = {
2427             .request = VHOST_USER_BACKEND_VRING_CALL,
2428             .flags = VHOST_USER_VERSION,
2429             .size = sizeof(vmsg.payload.state),
2430             .payload.state = {
2431                 .index = vq - dev->vq,
2432             },
2433         };
2434         bool ack = sync &&
2435                    vu_has_protocol_feature(dev,
2436                                            VHOST_USER_PROTOCOL_F_REPLY_ACK);
2437 
2438         if (ack) {
2439             vmsg.flags |= VHOST_USER_NEED_REPLY_MASK;
2440         }
2441 
2442         vu_message_write(dev, dev->backend_fd, &vmsg);
2443         if (ack) {
2444             vu_message_read_default(dev, dev->backend_fd, &vmsg);
2445         }
2446         return;
2447     }
2448 
2449     if (eventfd_write(vq->call_fd, 1) < 0) {
2450         vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
2451     }
2452 }
2453 
2454 void vu_queue_notify(VuDev *dev, VuVirtq *vq)
2455 {
2456     _vu_queue_notify(dev, vq, false);
2457 }
2458 
2459 void vu_queue_notify_sync(VuDev *dev, VuVirtq *vq)
2460 {
2461     _vu_queue_notify(dev, vq, true);
2462 }
2463 
2464 void vu_config_change_msg(VuDev *dev)
2465 {
2466     VhostUserMsg vmsg = {
2467         .request = VHOST_USER_BACKEND_CONFIG_CHANGE_MSG,
2468         .flags = VHOST_USER_VERSION,
2469     };
2470 
2471     vu_message_write(dev, dev->backend_fd, &vmsg);
2472 }
2473 
2474 static inline void
2475 vring_used_flags_set_bit(VuVirtq *vq, int mask)
2476 {
2477     uint16_t *flags;
2478 
2479     flags = (uint16_t *)((char*)vq->vring.used +
2480                          offsetof(struct vring_used, flags));
2481     *flags = htole16(le16toh(*flags) | mask);
2482 }
2483 
2484 static inline void
2485 vring_used_flags_unset_bit(VuVirtq *vq, int mask)
2486 {
2487     uint16_t *flags;
2488 
2489     flags = (uint16_t *)((char*)vq->vring.used +
2490                          offsetof(struct vring_used, flags));
2491     *flags = htole16(le16toh(*flags) & ~mask);
2492 }
2493 
2494 static inline void
2495 vring_set_avail_event(VuVirtq *vq, uint16_t val)
2496 {
2497     uint16_t val_le = htole16(val);
2498 
2499     if (!vq->notification) {
2500         return;
2501     }
2502 
2503     memcpy(&vq->vring.used->ring[vq->vring.num], &val_le, sizeof(uint16_t));
2504 }
2505 
2506 void
2507 vu_queue_set_notification(VuDev *dev, VuVirtq *vq, int enable)
2508 {
2509     vq->notification = enable;
2510     if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
2511         vring_set_avail_event(vq, vring_avail_idx(vq));
2512     } else if (enable) {
2513         vring_used_flags_unset_bit(vq, VRING_USED_F_NO_NOTIFY);
2514     } else {
2515         vring_used_flags_set_bit(vq, VRING_USED_F_NO_NOTIFY);
2516     }
2517     if (enable) {
2518         /* Expose avail event/used flags before caller checks the avail idx. */
2519         smp_mb();
2520     }
2521 }
2522 
2523 static bool
2524 virtqueue_map_desc(VuDev *dev,
2525                    unsigned int *p_num_sg, struct iovec *iov,
2526                    unsigned int max_num_sg, bool is_write,
2527                    uint64_t pa, size_t sz)
2528 {
2529     unsigned num_sg = *p_num_sg;
2530 
2531     assert(num_sg <= max_num_sg);
2532 
2533     if (!sz) {
2534         vu_panic(dev, "virtio: zero sized buffers are not allowed");
2535         return false;
2536     }
2537 
2538     while (sz) {
2539         uint64_t len = sz;
2540 
2541         if (num_sg == max_num_sg) {
2542             vu_panic(dev, "virtio: too many descriptors in indirect table");
2543             return false;
2544         }
2545 
2546         iov[num_sg].iov_base = vu_gpa_to_va(dev, &len, pa);
2547         if (iov[num_sg].iov_base == NULL) {
2548             vu_panic(dev, "virtio: invalid address for buffers");
2549             return false;
2550         }
2551         iov[num_sg].iov_len = len;
2552         num_sg++;
2553         sz -= len;
2554         pa += len;
2555     }
2556 
2557     *p_num_sg = num_sg;
2558     return true;
2559 }
2560 
2561 static void *
2562 virtqueue_alloc_element(size_t sz,
2563                                      unsigned out_num, unsigned in_num)
2564 {
2565     VuVirtqElement *elem;
2566     size_t in_sg_ofs = ALIGN_UP(sz, __alignof__(elem->in_sg[0]));
2567     size_t out_sg_ofs = in_sg_ofs + in_num * sizeof(elem->in_sg[0]);
2568     size_t out_sg_end = out_sg_ofs + out_num * sizeof(elem->out_sg[0]);
2569 
2570     assert(sz >= sizeof(VuVirtqElement));
2571     elem = malloc(out_sg_end);
2572     if (!elem) {
2573         DPRINT("%s: failed to malloc virtqueue element\n", __func__);
2574         return NULL;
2575     }
2576     elem->out_num = out_num;
2577     elem->in_num = in_num;
2578     elem->in_sg = (void *)elem + in_sg_ofs;
2579     elem->out_sg = (void *)elem + out_sg_ofs;
2580     return elem;
2581 }
2582 
2583 static void *
2584 vu_queue_map_desc(VuDev *dev, VuVirtq *vq, unsigned int idx, size_t sz)
2585 {
2586     struct vring_desc *desc = vq->vring.desc;
2587     uint64_t desc_addr, read_len;
2588     unsigned int desc_len;
2589     unsigned int max = vq->vring.num;
2590     unsigned int i = idx;
2591     VuVirtqElement *elem;
2592     unsigned int out_num = 0, in_num = 0;
2593     struct iovec iov[VIRTQUEUE_MAX_SIZE];
2594     struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
2595     int rc;
2596 
2597     if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
2598         if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
2599             vu_panic(dev, "Invalid size for indirect buffer table");
2600             return NULL;
2601         }
2602 
2603         /* loop over the indirect descriptor table */
2604         desc_addr = le64toh(desc[i].addr);
2605         desc_len = le32toh(desc[i].len);
2606         max = desc_len / sizeof(struct vring_desc);
2607         read_len = desc_len;
2608         desc = vu_gpa_to_va(dev, &read_len, desc_addr);
2609         if (unlikely(desc && read_len != desc_len)) {
2610             /* Failed to use zero copy */
2611             desc = NULL;
2612             if (!virtqueue_read_indirect_desc(dev, desc_buf,
2613                                               desc_addr,
2614                                               desc_len)) {
2615                 desc = desc_buf;
2616             }
2617         }
2618         if (!desc) {
2619             vu_panic(dev, "Invalid indirect buffer table");
2620             return NULL;
2621         }
2622         i = 0;
2623     }
2624 
2625     /* Collect all the descriptors */
2626     do {
2627         if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
2628             if (!virtqueue_map_desc(dev, &in_num, iov + out_num,
2629                                VIRTQUEUE_MAX_SIZE - out_num, true,
2630                                le64toh(desc[i].addr),
2631                                le32toh(desc[i].len))) {
2632                 return NULL;
2633             }
2634         } else {
2635             if (in_num) {
2636                 vu_panic(dev, "Incorrect order for descriptors");
2637                 return NULL;
2638             }
2639             if (!virtqueue_map_desc(dev, &out_num, iov,
2640                                VIRTQUEUE_MAX_SIZE, false,
2641                                le64toh(desc[i].addr),
2642                                le32toh(desc[i].len))) {
2643                 return NULL;
2644             }
2645         }
2646 
2647         /* If we've got too many, that implies a descriptor loop. */
2648         if ((in_num + out_num) > max) {
2649             vu_panic(dev, "Looped descriptor");
2650             return NULL;
2651         }
2652         rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
2653     } while (rc == VIRTQUEUE_READ_DESC_MORE);
2654 
2655     if (rc == VIRTQUEUE_READ_DESC_ERROR) {
2656         vu_panic(dev, "read descriptor error");
2657         return NULL;
2658     }
2659 
2660     /* Now copy what we have collected and mapped */
2661     elem = virtqueue_alloc_element(sz, out_num, in_num);
2662     if (!elem) {
2663         return NULL;
2664     }
2665     elem->index = idx;
2666     for (i = 0; i < out_num; i++) {
2667         elem->out_sg[i] = iov[i];
2668     }
2669     for (i = 0; i < in_num; i++) {
2670         elem->in_sg[i] = iov[out_num + i];
2671     }
2672 
2673     return elem;
2674 }
2675 
2676 static int
2677 vu_queue_inflight_get(VuDev *dev, VuVirtq *vq, int desc_idx)
2678 {
2679     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2680         return 0;
2681     }
2682 
2683     if (unlikely(!vq->inflight)) {
2684         return -1;
2685     }
2686 
2687     vq->inflight->desc[desc_idx].counter = vq->counter++;
2688     vq->inflight->desc[desc_idx].inflight = 1;
2689 
2690     return 0;
2691 }
2692 
2693 static int
2694 vu_queue_inflight_pre_put(VuDev *dev, VuVirtq *vq, int desc_idx)
2695 {
2696     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2697         return 0;
2698     }
2699 
2700     if (unlikely(!vq->inflight)) {
2701         return -1;
2702     }
2703 
2704     vq->inflight->last_batch_head = desc_idx;
2705 
2706     return 0;
2707 }
2708 
2709 static int
2710 vu_queue_inflight_post_put(VuDev *dev, VuVirtq *vq, int desc_idx)
2711 {
2712     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2713         return 0;
2714     }
2715 
2716     if (unlikely(!vq->inflight)) {
2717         return -1;
2718     }
2719 
2720     barrier();
2721 
2722     vq->inflight->desc[desc_idx].inflight = 0;
2723 
2724     barrier();
2725 
2726     vq->inflight->used_idx = vq->used_idx;
2727 
2728     return 0;
2729 }
2730 
/*
 * Take the next available element off @vq and map its descriptors.
 *
 * @sz is the allocation size of the returned element.  It must be at least
 * sizeof(VuVirtqElement); presumably callers pass a larger size to embed
 * the element in their own struct.
 *
 * Returns a malloc()ed element, or NULL in these cases:
 *  - the device is broken;
 *  - the ring is not set up;
 *  - the queue is empty;
 *  - an error occurred.
 *
 * Elements pending in vq->resubmit_list are handed out first (last entry
 * first), and the list is freed once it is drained.
 */
void *
vu_queue_pop(VuDev *dev, VuVirtq *vq, size_t sz)
{
    int i;
    unsigned int head;
    VuVirtqElement *elem;

    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        return NULL;
    }

    /*
     * NOTE(review): this path neither bumps vq->inuse nor calls
     * vu_queue_inflight_get(); presumably both were accounted for when the
     * resubmit list was built - confirm.
     */
    if (unlikely(vq->resubmit_list && vq->resubmit_num > 0)) {
        i = (--vq->resubmit_num);
        elem = vu_queue_map_desc(dev, vq, vq->resubmit_list[i].index, sz);

        if (!vq->resubmit_num) {
            free(vq->resubmit_list);
            vq->resubmit_list = NULL;
        }

        return elem;
    }

    if (vu_queue_empty(dev, vq)) {
        return NULL;
    }
    /*
     * Needed after vu_queue_empty(), see comment in
     * virtqueue_num_heads().
     */
    smp_rmb();

    if (vq->inuse >= vq->vring.num) {
        vu_panic(dev, "Virtqueue size exceeded");
        return NULL;
    }

    /* last_avail_idx advances even if mapping the chain below fails. */
    if (!virtqueue_get_head(dev, vq, vq->last_avail_idx++, &head)) {
        return NULL;
    }

    if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        vring_set_avail_event(vq, vq->last_avail_idx);
    }

    elem = vu_queue_map_desc(dev, vq, head, sz);

    if (!elem) {
        return NULL;
    }

    vq->inuse++;

    vu_queue_inflight_get(dev, vq, head);

    return elem;
}
2789 
2790 static void
2791 vu_queue_detach_element(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
2792                         size_t len)
2793 {
2794     vq->inuse--;
2795     /* unmap, when DMA support is added */
2796 }
2797 
2798 void
2799 vu_queue_unpop(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
2800                size_t len)
2801 {
2802     vq->last_avail_idx--;
2803     vu_queue_detach_element(dev, vq, elem, len);
2804 }
2805 
2806 bool
2807 vu_queue_rewind(VuDev *dev, VuVirtq *vq, unsigned int num)
2808 {
2809     if (num > vq->inuse) {
2810         return false;
2811     }
2812     vq->last_avail_idx -= num;
2813     vq->inuse -= num;
2814     return true;
2815 }
2816 
2817 static inline
2818 void vring_used_write(VuDev *dev, VuVirtq *vq,
2819                       struct vring_used_elem *uelem, int i)
2820 {
2821     struct vring_used *used = vq->vring.used;
2822 
2823     used->ring[i] = *uelem;
2824     vu_log_write(dev, vq->vring.log_guest_addr +
2825                  offsetof(struct vring_used, ring[i]),
2826                  sizeof(used->ring[i]));
2827 }
2828 
2829 
/*
 * Report to vu_log_write() the guest memory written for @elem.
 *
 * The descriptor chain starting at elem->index is walked, following an
 * indirect table if present.  Device-writable (VRING_DESC_F_WRITE) buffers
 * are logged in chain order until @len bytes have been covered.
 * Malformed tables or chains are reported through vu_panic().
 */
static void
vu_log_queue_fill(VuDev *dev, VuVirtq *vq,
                  const VuVirtqElement *elem,
                  unsigned int len)
{
    struct vring_desc *desc = vq->vring.desc;
    unsigned int i, max, min, desc_len;
    uint64_t desc_addr, read_len;
    /* Copy of an indirect table that is not contiguous in our mapping. */
    struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
    unsigned num_bufs = 0;

    max = vq->vring.num;
    i = elem->index;

    if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
        if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
            vu_panic(dev, "Invalid size for indirect buffer table");
            return;
        }

        /* loop over the indirect descriptor table */
        desc_addr = le64toh(desc[i].addr);
        desc_len = le32toh(desc[i].len);
        max = desc_len / sizeof(struct vring_desc);
        read_len = desc_len;
        desc = vu_gpa_to_va(dev, &read_len, desc_addr);
        if (unlikely(desc && read_len != desc_len)) {
            /* Failed to use zero copy */
            desc = NULL;
            if (!virtqueue_read_indirect_desc(dev, desc_buf,
                                              desc_addr,
                                              desc_len)) {
                desc = desc_buf;
            }
        }
        if (!desc) {
            vu_panic(dev, "Invalid indirect buffer table");
            return;
        }
        i = 0;
    }

    do {
        /* More entries than the table holds implies a descriptor loop. */
        if (++num_bufs > max) {
            vu_panic(dev, "Looped descriptor");
            return;
        }

        /* Only device-writable buffers can have been dirtied. */
        if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
            min = MIN(le32toh(desc[i].len), len);
            vu_log_write(dev, le64toh(desc[i].addr), min);
            len -= min;
        }

    } while (len > 0 &&
             (virtqueue_read_next_desc(dev, desc, i, max, &i)
              == VIRTQUEUE_READ_DESC_MORE));
}
2888 
2889 void
2890 vu_queue_fill(VuDev *dev, VuVirtq *vq,
2891               const VuVirtqElement *elem,
2892               unsigned int len, unsigned int idx)
2893 {
2894     struct vring_used_elem uelem;
2895 
2896     if (unlikely(dev->broken) ||
2897         unlikely(!vq->vring.avail)) {
2898         return;
2899     }
2900 
2901     vu_log_queue_fill(dev, vq, elem, len);
2902 
2903     idx = (idx + vq->used_idx) % vq->vring.num;
2904 
2905     uelem.id = htole32(elem->index);
2906     uelem.len = htole32(len);
2907     vring_used_write(dev, vq, &uelem, idx);
2908 }
2909 
2910 static inline
2911 void vring_used_idx_set(VuDev *dev, VuVirtq *vq, uint16_t val)
2912 {
2913     vq->vring.used->idx = htole16(val);
2914     vu_log_write(dev,
2915                  vq->vring.log_guest_addr + offsetof(struct vring_used, idx),
2916                  sizeof(vq->vring.used->idx));
2917 
2918     vq->used_idx = val;
2919 }
2920 
2921 void
2922 vu_queue_flush(VuDev *dev, VuVirtq *vq, unsigned int count)
2923 {
2924     uint16_t old, new;
2925 
2926     if (unlikely(dev->broken) ||
2927         unlikely(!vq->vring.avail)) {
2928         return;
2929     }
2930 
2931     /* Make sure buffer is written before we update index. */
2932     smp_wmb();
2933 
2934     old = vq->used_idx;
2935     new = old + count;
2936     vring_used_idx_set(dev, vq, new);
2937     vq->inuse -= count;
2938     if (unlikely((int16_t)(new - vq->signalled_used) < (uint16_t)(new - old))) {
2939         vq->signalled_used_valid = false;
2940     }
2941 }
2942 
2943 void
2944 vu_queue_push(VuDev *dev, VuVirtq *vq,
2945               const VuVirtqElement *elem, unsigned int len)
2946 {
2947     vu_queue_fill(dev, vq, elem, len, 0);
2948     vu_queue_inflight_pre_put(dev, vq, elem->index);
2949     vu_queue_flush(dev, vq, 1);
2950     vu_queue_inflight_post_put(dev, vq, elem->index);
2951 }
2952