1 /*
2  * Vhost User library
3  *
4  * Copyright IBM, Corp. 2007
5  * Copyright (c) 2016 Red Hat, Inc.
6  *
7  * Authors:
8  *  Anthony Liguori <aliguori@us.ibm.com>
9  *  Marc-André Lureau <mlureau@redhat.com>
10  *  Victor Kaplansky <victork@redhat.com>
11  *
12  * This work is licensed under the terms of the GNU GPL, version 2 or
13  * later.  See the COPYING file in the top-level directory.
14  */
15 
16 #ifndef _GNU_SOURCE
17 #define _GNU_SOURCE
18 #endif
19 
20 /* this code avoids GLib dependency */
21 #include <stdlib.h>
22 #include <stdio.h>
23 #include <unistd.h>
24 #include <stdarg.h>
25 #include <errno.h>
26 #include <string.h>
27 #include <assert.h>
28 #include <inttypes.h>
29 #include <sys/types.h>
30 #include <sys/socket.h>
31 #include <sys/eventfd.h>
32 #include <sys/mman.h>
33 #include <endian.h>
34 
35 #if defined(__linux__)
36 #include <sys/syscall.h>
37 #include <fcntl.h>
38 #include <sys/ioctl.h>
39 #include <linux/vhost.h>
40 
41 #ifdef __NR_userfaultfd
42 #include <linux/userfaultfd.h>
43 #endif
44 
45 #endif
46 
47 #include "include/atomic.h"
48 
49 #include "libvhost-user.h"
50 
/* usually provided by GLib */
/*
 * G_GNUC_PRINTF(format_idx, arg_idx): mark a function as printf-like so the
 * compiler type-checks its format string (argument format_idx) against the
 * variadic arguments starting at arg_idx. Non-clang GCC 4.4 gets the
 * gnu_printf archetype; other GCC > 2.4 (and compilers defining __GNUC__ in
 * that range) get __printf__; anything else gets an empty expansion.
 */
#if     __GNUC__ > 2 || (__GNUC__ == 2 && __GNUC_MINOR__ > 4)
#if !defined(__clang__) && (__GNUC__ == 4 && __GNUC_MINOR__ == 4)
#define G_GNUC_PRINTF(format_idx, arg_idx) \
  __attribute__((__format__(gnu_printf, format_idx, arg_idx)))
#else
#define G_GNUC_PRINTF(format_idx, arg_idx) \
  __attribute__((__format__(__printf__, format_idx, arg_idx)))
#endif
#else   /* not GCC > 2.4: no format attribute available */
#define G_GNUC_PRINTF(format_idx, arg_idx)
#endif  /* not GCC > 2.4 */
/*
 * MIN(x, y): GNU statement expression that evaluates each argument exactly
 * once. The pointer comparison is never used at runtime; it only makes the
 * compiler warn when x and y have different types.
 */
#ifndef MIN
#define MIN(x, y) ({                            \
            __typeof__(x) _min1 = (x);          \
            __typeof__(y) _min2 = (y);          \
            (void) (&_min1 == &_min2);          \
            _min1 < _min2 ? _min1 : _min2; })
#endif

/* Round number down to multiple */
/* Note: m is evaluated twice; intended for integer arguments. */
#define ALIGN_DOWN(n, m) ((n) / (m) * (m))

/* Round number up to multiple */
/* Note: m is evaluated more than once; intended for integer arguments. */
#define ALIGN_UP(n, m) ALIGN_DOWN((n) + (m) - 1, (m))

/* Branch-prediction hint: x is expected to be false. */
#ifndef unlikely
#define unlikely(x)   __builtin_expect(!!(x), 0)
#endif

/* Align each region to cache line size in inflight buffer */
#define INFLIGHT_ALIGNMENT 64

/* The version of inflight buffer */
#define INFLIGHT_VERSION 1

/* The version of the protocol we support */
#define VHOST_USER_VERSION 1
/* Set to 1 to make DPRINT() write to stderr. */
#define LIBVHOST_USER_DEBUG 0

/*
 * Debug print to stderr. The if() on a constant lets the compiler drop the
 * call when LIBVHOST_USER_DEBUG is 0 while still type-checking the
 * arguments against the format.
 */
#define DPRINT(...)                             \
    do {                                        \
        if (LIBVHOST_USER_DEBUG) {              \
            fprintf(stderr, __VA_ARGS__);        \
        }                                       \
    } while (0)
97 
/*
 * Report whether bit @fbit is set in the feature word @features.
 * @fbit must name one of the 64 bits of the word (asserted).
 */
static inline
bool has_feature(uint64_t features, unsigned int fbit)
{
    uint64_t mask;

    assert(fbit < 64);
    mask = UINT64_C(1) << fbit;
    return (features & mask) != 0;
}
104 
105 static inline
106 bool vu_has_feature(VuDev *dev,
107                     unsigned int fbit)
108 {
109     return has_feature(dev->features, fbit);
110 }
111 
112 static inline bool vu_has_protocol_feature(VuDev *dev, unsigned int fbit)
113 {
114     return has_feature(dev->protocol_features, fbit);
115 }
116 
117 const char *
118 vu_request_to_string(unsigned int req)
119 {
120 #define REQ(req) [req] = #req
121     static const char *vu_request_str[] = {
122         REQ(VHOST_USER_NONE),
123         REQ(VHOST_USER_GET_FEATURES),
124         REQ(VHOST_USER_SET_FEATURES),
125         REQ(VHOST_USER_SET_OWNER),
126         REQ(VHOST_USER_RESET_OWNER),
127         REQ(VHOST_USER_SET_MEM_TABLE),
128         REQ(VHOST_USER_SET_LOG_BASE),
129         REQ(VHOST_USER_SET_LOG_FD),
130         REQ(VHOST_USER_SET_VRING_NUM),
131         REQ(VHOST_USER_SET_VRING_ADDR),
132         REQ(VHOST_USER_SET_VRING_BASE),
133         REQ(VHOST_USER_GET_VRING_BASE),
134         REQ(VHOST_USER_SET_VRING_KICK),
135         REQ(VHOST_USER_SET_VRING_CALL),
136         REQ(VHOST_USER_SET_VRING_ERR),
137         REQ(VHOST_USER_GET_PROTOCOL_FEATURES),
138         REQ(VHOST_USER_SET_PROTOCOL_FEATURES),
139         REQ(VHOST_USER_GET_QUEUE_NUM),
140         REQ(VHOST_USER_SET_VRING_ENABLE),
141         REQ(VHOST_USER_SEND_RARP),
142         REQ(VHOST_USER_NET_SET_MTU),
143         REQ(VHOST_USER_SET_SLAVE_REQ_FD),
144         REQ(VHOST_USER_IOTLB_MSG),
145         REQ(VHOST_USER_SET_VRING_ENDIAN),
146         REQ(VHOST_USER_GET_CONFIG),
147         REQ(VHOST_USER_SET_CONFIG),
148         REQ(VHOST_USER_POSTCOPY_ADVISE),
149         REQ(VHOST_USER_POSTCOPY_LISTEN),
150         REQ(VHOST_USER_POSTCOPY_END),
151         REQ(VHOST_USER_GET_INFLIGHT_FD),
152         REQ(VHOST_USER_SET_INFLIGHT_FD),
153         REQ(VHOST_USER_GPU_SET_SOCKET),
154         REQ(VHOST_USER_VRING_KICK),
155         REQ(VHOST_USER_GET_MAX_MEM_SLOTS),
156         REQ(VHOST_USER_ADD_MEM_REG),
157         REQ(VHOST_USER_REM_MEM_REG),
158         REQ(VHOST_USER_MAX),
159     };
160 #undef REQ
161 
162     if (req < VHOST_USER_MAX) {
163         return vu_request_str[req];
164     } else {
165         return "unknown";
166     }
167 }
168 
169 static void G_GNUC_PRINTF(2, 3)
170 vu_panic(VuDev *dev, const char *msg, ...)
171 {
172     char *buf = NULL;
173     va_list ap;
174 
175     va_start(ap, msg);
176     if (vasprintf(&buf, msg, ap) < 0) {
177         buf = NULL;
178     }
179     va_end(ap);
180 
181     dev->broken = true;
182     dev->panic(dev, buf);
183     free(buf);
184 
185     /*
186      * FIXME:
187      * find a way to call virtio_error, or perhaps close the connection?
188      */
189 }
190 
191 /* Translate guest physical address to our virtual address.  */
192 void *
193 vu_gpa_to_va(VuDev *dev, uint64_t *plen, uint64_t guest_addr)
194 {
195     unsigned int i;
196 
197     if (*plen == 0) {
198         return NULL;
199     }
200 
201     /* Find matching memory region.  */
202     for (i = 0; i < dev->nregions; i++) {
203         VuDevRegion *r = &dev->regions[i];
204 
205         if ((guest_addr >= r->gpa) && (guest_addr < (r->gpa + r->size))) {
206             if ((guest_addr + *plen) > (r->gpa + r->size)) {
207                 *plen = r->gpa + r->size - guest_addr;
208             }
209             return (void *)(uintptr_t)
210                 guest_addr - r->gpa + r->mmap_addr + r->mmap_offset;
211         }
212     }
213 
214     return NULL;
215 }
216 
217 /* Translate qemu virtual address to our virtual address.  */
218 static void *
219 qva_to_va(VuDev *dev, uint64_t qemu_addr)
220 {
221     unsigned int i;
222 
223     /* Find matching memory region.  */
224     for (i = 0; i < dev->nregions; i++) {
225         VuDevRegion *r = &dev->regions[i];
226 
227         if ((qemu_addr >= r->qva) && (qemu_addr < (r->qva + r->size))) {
228             return (void *)(uintptr_t)
229                 qemu_addr - r->qva + r->mmap_addr + r->mmap_offset;
230         }
231     }
232 
233     return NULL;
234 }
235 
236 static void
237 vmsg_close_fds(VhostUserMsg *vmsg)
238 {
239     int i;
240 
241     for (i = 0; i < vmsg->fd_num; i++) {
242         close(vmsg->fds[i]);
243     }
244 }
245 
246 /* Set reply payload.u64 and clear request flags and fd_num */
247 static void vmsg_set_reply_u64(VhostUserMsg *vmsg, uint64_t val)
248 {
249     vmsg->flags = 0; /* defaults will be set by vu_send_reply() */
250     vmsg->size = sizeof(vmsg->payload.u64);
251     vmsg->payload.u64 = val;
252     vmsg->fd_num = 0;
253 }
254 
/*
 * A test to see if we have userfault available.
 *
 * Requires build-time support for userfaultfd with shmem and hugetlbfs
 * missing-page handling, and then probes the running kernel by opening a
 * userfaultfd and negotiating those features with UFFDIO_API.
 */
static bool
have_userfault(void)
{
#if defined(__linux__) && defined(__NR_userfaultfd) &&\
        defined(UFFD_FEATURE_MISSING_SHMEM) &&\
        defined(UFFD_FEATURE_MISSING_HUGETLBFS)
    struct uffdio_api api_struct = {
        .api = UFFD_API,
        .features = UFFD_FEATURE_MISSING_SHMEM |
                    UFFD_FEATURE_MISSING_HUGETLBFS,
    };
    bool supported;
    int ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);

    if (ufd < 0) {
        return false;
    }

    /* The probe fd is only needed for the handshake. */
    supported = ioctl(ufd, UFFDIO_API, &api_struct) == 0;
    close(ufd);
    return supported;
#else
    return false;
#endif
}
283 
284 static bool
285 vu_message_read_default(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
286 {
287     char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
288     struct iovec iov = {
289         .iov_base = (char *)vmsg,
290         .iov_len = VHOST_USER_HDR_SIZE,
291     };
292     struct msghdr msg = {
293         .msg_iov = &iov,
294         .msg_iovlen = 1,
295         .msg_control = control,
296         .msg_controllen = sizeof(control),
297     };
298     size_t fd_size;
299     struct cmsghdr *cmsg;
300     int rc;
301 
302     do {
303         rc = recvmsg(conn_fd, &msg, 0);
304     } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
305 
306     if (rc < 0) {
307         vu_panic(dev, "Error while recvmsg: %s", strerror(errno));
308         return false;
309     }
310 
311     vmsg->fd_num = 0;
312     for (cmsg = CMSG_FIRSTHDR(&msg);
313          cmsg != NULL;
314          cmsg = CMSG_NXTHDR(&msg, cmsg))
315     {
316         if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) {
317             fd_size = cmsg->cmsg_len - CMSG_LEN(0);
318             vmsg->fd_num = fd_size / sizeof(int);
319             memcpy(vmsg->fds, CMSG_DATA(cmsg), fd_size);
320             break;
321         }
322     }
323 
324     if (vmsg->size > sizeof(vmsg->payload)) {
325         vu_panic(dev,
326                  "Error: too big message request: %d, size: vmsg->size: %u, "
327                  "while sizeof(vmsg->payload) = %zu\n",
328                  vmsg->request, vmsg->size, sizeof(vmsg->payload));
329         goto fail;
330     }
331 
332     if (vmsg->size) {
333         do {
334             rc = read(conn_fd, &vmsg->payload, vmsg->size);
335         } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
336 
337         if (rc <= 0) {
338             vu_panic(dev, "Error while reading: %s", strerror(errno));
339             goto fail;
340         }
341 
342         assert((uint32_t)rc == vmsg->size);
343     }
344 
345     return true;
346 
347 fail:
348     vmsg_close_fds(vmsg);
349 
350     return false;
351 }
352 
353 static bool
354 vu_message_write(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
355 {
356     int rc;
357     uint8_t *p = (uint8_t *)vmsg;
358     char control[CMSG_SPACE(VHOST_MEMORY_BASELINE_NREGIONS * sizeof(int))] = {};
359     struct iovec iov = {
360         .iov_base = (char *)vmsg,
361         .iov_len = VHOST_USER_HDR_SIZE,
362     };
363     struct msghdr msg = {
364         .msg_iov = &iov,
365         .msg_iovlen = 1,
366         .msg_control = control,
367     };
368     struct cmsghdr *cmsg;
369 
370     memset(control, 0, sizeof(control));
371     assert(vmsg->fd_num <= VHOST_MEMORY_BASELINE_NREGIONS);
372     if (vmsg->fd_num > 0) {
373         size_t fdsize = vmsg->fd_num * sizeof(int);
374         msg.msg_controllen = CMSG_SPACE(fdsize);
375         cmsg = CMSG_FIRSTHDR(&msg);
376         cmsg->cmsg_len = CMSG_LEN(fdsize);
377         cmsg->cmsg_level = SOL_SOCKET;
378         cmsg->cmsg_type = SCM_RIGHTS;
379         memcpy(CMSG_DATA(cmsg), vmsg->fds, fdsize);
380     } else {
381         msg.msg_controllen = 0;
382     }
383 
384     do {
385         rc = sendmsg(conn_fd, &msg, 0);
386     } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
387 
388     if (vmsg->size) {
389         do {
390             if (vmsg->data) {
391                 rc = write(conn_fd, vmsg->data, vmsg->size);
392             } else {
393                 rc = write(conn_fd, p + VHOST_USER_HDR_SIZE, vmsg->size);
394             }
395         } while (rc < 0 && (errno == EINTR || errno == EAGAIN));
396     }
397 
398     if (rc <= 0) {
399         vu_panic(dev, "Error while writing: %s", strerror(errno));
400         return false;
401     }
402 
403     return true;
404 }
405 
406 static bool
407 vu_send_reply(VuDev *dev, int conn_fd, VhostUserMsg *vmsg)
408 {
409     /* Set the version in the flags when sending the reply */
410     vmsg->flags &= ~VHOST_USER_VERSION_MASK;
411     vmsg->flags |= VHOST_USER_VERSION;
412     vmsg->flags |= VHOST_USER_REPLY_MASK;
413 
414     return vu_message_write(dev, conn_fd, vmsg);
415 }
416 
417 /*
418  * Processes a reply on the slave channel.
419  * Entered with slave_mutex held and releases it before exit.
420  * Returns true on success.
421  */
422 static bool
423 vu_process_message_reply(VuDev *dev, const VhostUserMsg *vmsg)
424 {
425     VhostUserMsg msg_reply;
426     bool result = false;
427 
428     if ((vmsg->flags & VHOST_USER_NEED_REPLY_MASK) == 0) {
429         result = true;
430         goto out;
431     }
432 
433     if (!vu_message_read_default(dev, dev->slave_fd, &msg_reply)) {
434         goto out;
435     }
436 
437     if (msg_reply.request != vmsg->request) {
438         DPRINT("Received unexpected msg type. Expected %d received %d",
439                vmsg->request, msg_reply.request);
440         goto out;
441     }
442 
443     result = msg_reply.payload.u64 == 0;
444 
445 out:
446     pthread_mutex_unlock(&dev->slave_mutex);
447     return result;
448 }
449 
450 /* Kick the log_call_fd if required. */
451 static void
452 vu_log_kick(VuDev *dev)
453 {
454     if (dev->log_call_fd != -1) {
455         DPRINT("Kicking the QEMU's log...\n");
456         if (eventfd_write(dev->log_call_fd, 1) < 0) {
457             vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
458         }
459     }
460 }
461 
/*
 * Mark guest page @page dirty in the log bitmap @log_table: bit
 * (page % 8) of byte (page / 8). The bit is set with qatomic_or(),
 * presumably so concurrent setters of other bits in the same byte are not
 * lost.
 */
static void
vu_log_page(uint8_t *log_table, uint64_t page)
{
    /* page is unsigned: PRIu64, not PRId64. */
    DPRINT("Logged dirty guest page: %"PRIu64"\n", page);
    qatomic_or(&log_table[page / 8], 1 << (page % 8));
}
468 
469 static void
470 vu_log_write(VuDev *dev, uint64_t address, uint64_t length)
471 {
472     uint64_t page;
473 
474     if (!(dev->features & (1ULL << VHOST_F_LOG_ALL)) ||
475         !dev->log_table || !length) {
476         return;
477     }
478 
479     assert(dev->log_size > ((address + length - 1) / VHOST_LOG_PAGE / 8));
480 
481     page = address / VHOST_LOG_PAGE;
482     while (page * VHOST_LOG_PAGE < address + length) {
483         vu_log_page(dev->log_table, page);
484         page += 1;
485     }
486 
487     vu_log_kick(dev);
488 }
489 
490 static void
491 vu_kick_cb(VuDev *dev, int condition, void *data)
492 {
493     int index = (intptr_t)data;
494     VuVirtq *vq = &dev->vq[index];
495     int sock = vq->kick_fd;
496     eventfd_t kick_data;
497     ssize_t rc;
498 
499     rc = eventfd_read(sock, &kick_data);
500     if (rc == -1) {
501         vu_panic(dev, "kick eventfd_read(): %s", strerror(errno));
502         dev->remove_watch(dev, dev->vq[index].kick_fd);
503     } else {
504         DPRINT("Got kick_data: %016"PRIx64" handler:%p idx:%d\n",
505                kick_data, vq->handler, index);
506         if (vq->handler) {
507             vq->handler(dev, index);
508         }
509     }
510 }
511 
512 static bool
513 vu_get_features_exec(VuDev *dev, VhostUserMsg *vmsg)
514 {
515     vmsg->payload.u64 =
516         /*
517          * The following VIRTIO feature bits are supported by our virtqueue
518          * implementation:
519          */
520         1ULL << VIRTIO_F_NOTIFY_ON_EMPTY |
521         1ULL << VIRTIO_RING_F_INDIRECT_DESC |
522         1ULL << VIRTIO_RING_F_EVENT_IDX |
523         1ULL << VIRTIO_F_VERSION_1 |
524 
525         /* vhost-user feature bits */
526         1ULL << VHOST_F_LOG_ALL |
527         1ULL << VHOST_USER_F_PROTOCOL_FEATURES;
528 
529     if (dev->iface->get_features) {
530         vmsg->payload.u64 |= dev->iface->get_features(dev);
531     }
532 
533     vmsg->size = sizeof(vmsg->payload.u64);
534     vmsg->fd_num = 0;
535 
536     DPRINT("Sending back to guest u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
537 
538     return true;
539 }
540 
541 static void
542 vu_set_enable_all_rings(VuDev *dev, bool enabled)
543 {
544     uint16_t i;
545 
546     for (i = 0; i < dev->max_queues; i++) {
547         dev->vq[i].enable = enabled;
548     }
549 }
550 
551 static bool
552 vu_set_features_exec(VuDev *dev, VhostUserMsg *vmsg)
553 {
554     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
555 
556     dev->features = vmsg->payload.u64;
557     if (!vu_has_feature(dev, VIRTIO_F_VERSION_1)) {
558         /*
559          * We only support devices conforming to VIRTIO 1.0 or
560          * later
561          */
562         vu_panic(dev, "virtio legacy devices aren't supported by libvhost-user");
563         return false;
564     }
565 
566     if (!(dev->features & VHOST_USER_F_PROTOCOL_FEATURES)) {
567         vu_set_enable_all_rings(dev, true);
568     }
569 
570     if (dev->iface->set_features) {
571         dev->iface->set_features(dev, dev->features);
572     }
573 
574     return false;
575 }
576 
577 static bool
578 vu_set_owner_exec(VuDev *dev, VhostUserMsg *vmsg)
579 {
580     return false;
581 }
582 
583 static void
584 vu_close_log(VuDev *dev)
585 {
586     if (dev->log_table) {
587         if (munmap(dev->log_table, dev->log_size) != 0) {
588             perror("close log munmap() error");
589         }
590 
591         dev->log_table = NULL;
592     }
593     if (dev->log_call_fd != -1) {
594         close(dev->log_call_fd);
595         dev->log_call_fd = -1;
596     }
597 }
598 
599 static bool
600 vu_reset_device_exec(VuDev *dev, VhostUserMsg *vmsg)
601 {
602     vu_set_enable_all_rings(dev, false);
603 
604     return false;
605 }
606 
607 static bool
608 map_ring(VuDev *dev, VuVirtq *vq)
609 {
610     vq->vring.desc = qva_to_va(dev, vq->vra.desc_user_addr);
611     vq->vring.used = qva_to_va(dev, vq->vra.used_user_addr);
612     vq->vring.avail = qva_to_va(dev, vq->vra.avail_user_addr);
613 
614     DPRINT("Setting virtq addresses:\n");
615     DPRINT("    vring_desc  at %p\n", vq->vring.desc);
616     DPRINT("    vring_used  at %p\n", vq->vring.used);
617     DPRINT("    vring_avail at %p\n", vq->vring.avail);
618 
619     return !(vq->vring.desc && vq->vring.used && vq->vring.avail);
620 }
621 
622 static bool
623 generate_faults(VuDev *dev) {
624     unsigned int i;
625     for (i = 0; i < dev->nregions; i++) {
626         VuDevRegion *dev_region = &dev->regions[i];
627         int ret;
628 #ifdef UFFDIO_REGISTER
629         struct uffdio_register reg_struct;
630 
631         /*
632          * We should already have an open ufd. Mark each memory
633          * range as ufd.
634          * Discard any mapping we have here; note I can't use MADV_REMOVE
635          * or fallocate to make the hole since I don't want to lose
636          * data that's already arrived in the shared process.
637          * TODO: How to do hugepage
638          */
639         ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
640                       dev_region->size + dev_region->mmap_offset,
641                       MADV_DONTNEED);
642         if (ret) {
643             fprintf(stderr,
644                     "%s: Failed to madvise(DONTNEED) region %d: %s\n",
645                     __func__, i, strerror(errno));
646         }
647         /*
648          * Turn off transparent hugepages so we dont get lose wakeups
649          * in neighbouring pages.
650          * TODO: Turn this backon later.
651          */
652         ret = madvise((void *)(uintptr_t)dev_region->mmap_addr,
653                       dev_region->size + dev_region->mmap_offset,
654                       MADV_NOHUGEPAGE);
655         if (ret) {
656             /*
657              * Note: This can happen legally on kernels that are configured
658              * without madvise'able hugepages
659              */
660             fprintf(stderr,
661                     "%s: Failed to madvise(NOHUGEPAGE) region %d: %s\n",
662                     __func__, i, strerror(errno));
663         }
664 
665         reg_struct.range.start = (uintptr_t)dev_region->mmap_addr;
666         reg_struct.range.len = dev_region->size + dev_region->mmap_offset;
667         reg_struct.mode = UFFDIO_REGISTER_MODE_MISSING;
668 
669         if (ioctl(dev->postcopy_ufd, UFFDIO_REGISTER, &reg_struct)) {
670             vu_panic(dev, "%s: Failed to userfault region %d "
671                           "@%" PRIx64 " + size:%" PRIx64 " offset: %" PRIx64
672                           ": (ufd=%d)%s\n",
673                      __func__, i,
674                      dev_region->mmap_addr,
675                      dev_region->size, dev_region->mmap_offset,
676                      dev->postcopy_ufd, strerror(errno));
677             return false;
678         }
679         if (!(reg_struct.ioctls & ((__u64)1 << _UFFDIO_COPY))) {
680             vu_panic(dev, "%s Region (%d) doesn't support COPY",
681                      __func__, i);
682             return false;
683         }
684         DPRINT("%s: region %d: Registered userfault for %"
685                PRIx64 " + %" PRIx64 "\n", __func__, i,
686                (uint64_t)reg_struct.range.start,
687                (uint64_t)reg_struct.range.len);
688         /* Now it's registered we can let the client at it */
689         if (mprotect((void *)(uintptr_t)dev_region->mmap_addr,
690                      dev_region->size + dev_region->mmap_offset,
691                      PROT_READ | PROT_WRITE)) {
692             vu_panic(dev, "failed to mprotect region %d for postcopy (%s)",
693                      i, strerror(errno));
694             return false;
695         }
696         /* TODO: Stash 'zero' support flags somewhere */
697 #endif
698     }
699 
700     return true;
701 }
702 
703 static bool
704 vu_add_mem_reg(VuDev *dev, VhostUserMsg *vmsg) {
705     int i;
706     bool track_ramblocks = dev->postcopy_listening;
707     VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
708     VuDevRegion *dev_region = &dev->regions[dev->nregions];
709     void *mmap_addr;
710 
711     if (vmsg->fd_num != 1) {
712         vmsg_close_fds(vmsg);
713         vu_panic(dev, "VHOST_USER_ADD_MEM_REG received %d fds - only 1 fd "
714                       "should be sent for this message type", vmsg->fd_num);
715         return false;
716     }
717 
718     if (vmsg->size < VHOST_USER_MEM_REG_SIZE) {
719         close(vmsg->fds[0]);
720         vu_panic(dev, "VHOST_USER_ADD_MEM_REG requires a message size of at "
721                       "least %zu bytes and only %d bytes were received",
722                       VHOST_USER_MEM_REG_SIZE, vmsg->size);
723         return false;
724     }
725 
726     if (dev->nregions == VHOST_USER_MAX_RAM_SLOTS) {
727         close(vmsg->fds[0]);
728         vu_panic(dev, "failing attempt to hot add memory via "
729                       "VHOST_USER_ADD_MEM_REG message because the backend has "
730                       "no free ram slots available");
731         return false;
732     }
733 
734     /*
735      * If we are in postcopy mode and we receive a u64 payload with a 0 value
736      * we know all the postcopy client bases have been received, and we
737      * should start generating faults.
738      */
739     if (track_ramblocks &&
740         vmsg->size == sizeof(vmsg->payload.u64) &&
741         vmsg->payload.u64 == 0) {
742         (void)generate_faults(dev);
743         return false;
744     }
745 
746     DPRINT("Adding region: %u\n", dev->nregions);
747     DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
748            msg_region->guest_phys_addr);
749     DPRINT("    memory_size:     0x%016"PRIx64"\n",
750            msg_region->memory_size);
751     DPRINT("    userspace_addr   0x%016"PRIx64"\n",
752            msg_region->userspace_addr);
753     DPRINT("    mmap_offset      0x%016"PRIx64"\n",
754            msg_region->mmap_offset);
755 
756     dev_region->gpa = msg_region->guest_phys_addr;
757     dev_region->size = msg_region->memory_size;
758     dev_region->qva = msg_region->userspace_addr;
759     dev_region->mmap_offset = msg_region->mmap_offset;
760 
761     /*
762      * We don't use offset argument of mmap() since the
763      * mapped address has to be page aligned, and we use huge
764      * pages.
765      */
766     if (track_ramblocks) {
767         /*
768          * In postcopy we're using PROT_NONE here to catch anyone
769          * accessing it before we userfault.
770          */
771         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
772                          PROT_NONE, MAP_SHARED | MAP_NORESERVE,
773                          vmsg->fds[0], 0);
774     } else {
775         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
776                          PROT_READ | PROT_WRITE, MAP_SHARED | MAP_NORESERVE,
777                          vmsg->fds[0], 0);
778     }
779 
780     if (mmap_addr == MAP_FAILED) {
781         vu_panic(dev, "region mmap error: %s", strerror(errno));
782     } else {
783         dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
784         DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
785                dev_region->mmap_addr);
786     }
787 
788     close(vmsg->fds[0]);
789 
790     if (track_ramblocks) {
791         /*
792          * Return the address to QEMU so that it can translate the ufd
793          * fault addresses back.
794          */
795         msg_region->userspace_addr = (uintptr_t)(mmap_addr +
796                                                  dev_region->mmap_offset);
797 
798         /* Send the message back to qemu with the addresses filled in. */
799         vmsg->fd_num = 0;
800         DPRINT("Successfully added new region in postcopy\n");
801         dev->nregions++;
802         return true;
803     } else {
804         for (i = 0; i < dev->max_queues; i++) {
805             if (dev->vq[i].vring.desc) {
806                 if (map_ring(dev, &dev->vq[i])) {
807                     vu_panic(dev, "remapping queue %d for new memory region",
808                              i);
809                 }
810             }
811         }
812 
813         DPRINT("Successfully added new region\n");
814         dev->nregions++;
815         return false;
816     }
817 }
818 
819 static inline bool reg_equal(VuDevRegion *vudev_reg,
820                              VhostUserMemoryRegion *msg_reg)
821 {
822     if (vudev_reg->gpa == msg_reg->guest_phys_addr &&
823         vudev_reg->qva == msg_reg->userspace_addr &&
824         vudev_reg->size == msg_reg->memory_size) {
825         return true;
826     }
827 
828     return false;
829 }
830 
831 static bool
832 vu_rem_mem_reg(VuDev *dev, VhostUserMsg *vmsg) {
833     VhostUserMemoryRegion m = vmsg->payload.memreg.region, *msg_region = &m;
834     unsigned int i;
835     bool found = false;
836 
837     if (vmsg->fd_num > 1) {
838         vmsg_close_fds(vmsg);
839         vu_panic(dev, "VHOST_USER_REM_MEM_REG received %d fds - at most 1 fd "
840                       "should be sent for this message type", vmsg->fd_num);
841         return false;
842     }
843 
844     if (vmsg->size < VHOST_USER_MEM_REG_SIZE) {
845         vmsg_close_fds(vmsg);
846         vu_panic(dev, "VHOST_USER_REM_MEM_REG requires a message size of at "
847                       "least %zu bytes and only %d bytes were received",
848                       VHOST_USER_MEM_REG_SIZE, vmsg->size);
849         return false;
850     }
851 
852     DPRINT("Removing region:\n");
853     DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
854            msg_region->guest_phys_addr);
855     DPRINT("    memory_size:     0x%016"PRIx64"\n",
856            msg_region->memory_size);
857     DPRINT("    userspace_addr   0x%016"PRIx64"\n",
858            msg_region->userspace_addr);
859     DPRINT("    mmap_offset      0x%016"PRIx64"\n",
860            msg_region->mmap_offset);
861 
862     for (i = 0; i < dev->nregions; i++) {
863         if (reg_equal(&dev->regions[i], msg_region)) {
864             VuDevRegion *r = &dev->regions[i];
865             void *m = (void *) (uintptr_t) r->mmap_addr;
866 
867             if (m) {
868                 munmap(m, r->size + r->mmap_offset);
869             }
870 
871             /*
872              * Shift all affected entries by 1 to close the hole at index i and
873              * zero out the last entry.
874              */
875             memmove(dev->regions + i, dev->regions + i + 1,
876                     sizeof(VuDevRegion) * (dev->nregions - i - 1));
877             memset(dev->regions + dev->nregions - 1, 0, sizeof(VuDevRegion));
878             DPRINT("Successfully removed a region\n");
879             dev->nregions--;
880             i--;
881 
882             found = true;
883 
884             /* Continue the search for eventual duplicates. */
885         }
886     }
887 
888     if (!found) {
889         vu_panic(dev, "Specified region not found\n");
890     }
891 
892     vmsg_close_fds(vmsg);
893 
894     return false;
895 }
896 
897 static bool
898 vu_set_mem_table_exec_postcopy(VuDev *dev, VhostUserMsg *vmsg)
899 {
900     unsigned int i;
901     VhostUserMemory m = vmsg->payload.memory, *memory = &m;
902     dev->nregions = memory->nregions;
903 
904     DPRINT("Nregions: %u\n", memory->nregions);
905     for (i = 0; i < dev->nregions; i++) {
906         void *mmap_addr;
907         VhostUserMemoryRegion *msg_region = &memory->regions[i];
908         VuDevRegion *dev_region = &dev->regions[i];
909 
910         DPRINT("Region %d\n", i);
911         DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
912                msg_region->guest_phys_addr);
913         DPRINT("    memory_size:     0x%016"PRIx64"\n",
914                msg_region->memory_size);
915         DPRINT("    userspace_addr   0x%016"PRIx64"\n",
916                msg_region->userspace_addr);
917         DPRINT("    mmap_offset      0x%016"PRIx64"\n",
918                msg_region->mmap_offset);
919 
920         dev_region->gpa = msg_region->guest_phys_addr;
921         dev_region->size = msg_region->memory_size;
922         dev_region->qva = msg_region->userspace_addr;
923         dev_region->mmap_offset = msg_region->mmap_offset;
924 
925         /* We don't use offset argument of mmap() since the
926          * mapped address has to be page aligned, and we use huge
927          * pages.
928          * In postcopy we're using PROT_NONE here to catch anyone
929          * accessing it before we userfault
930          */
931         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
932                          PROT_NONE, MAP_SHARED | MAP_NORESERVE,
933                          vmsg->fds[i], 0);
934 
935         if (mmap_addr == MAP_FAILED) {
936             vu_panic(dev, "region mmap error: %s", strerror(errno));
937         } else {
938             dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
939             DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
940                    dev_region->mmap_addr);
941         }
942 
943         /* Return the address to QEMU so that it can translate the ufd
944          * fault addresses back.
945          */
946         msg_region->userspace_addr = (uintptr_t)(mmap_addr +
947                                                  dev_region->mmap_offset);
948         close(vmsg->fds[i]);
949     }
950 
951     /* Send the message back to qemu with the addresses filled in */
952     vmsg->fd_num = 0;
953     if (!vu_send_reply(dev, dev->sock, vmsg)) {
954         vu_panic(dev, "failed to respond to set-mem-table for postcopy");
955         return false;
956     }
957 
958     /* Wait for QEMU to confirm that it's registered the handler for the
959      * faults.
960      */
961     if (!dev->read_msg(dev, dev->sock, vmsg) ||
962         vmsg->size != sizeof(vmsg->payload.u64) ||
963         vmsg->payload.u64 != 0) {
964         vu_panic(dev, "failed to receive valid ack for postcopy set-mem-table");
965         return false;
966     }
967 
968     /* OK, now we can go and register the memory and generate faults */
969     (void)generate_faults(dev);
970 
971     return false;
972 }
973 
974 static bool
975 vu_set_mem_table_exec(VuDev *dev, VhostUserMsg *vmsg)
976 {
977     unsigned int i;
978     VhostUserMemory m = vmsg->payload.memory, *memory = &m;
979 
980     for (i = 0; i < dev->nregions; i++) {
981         VuDevRegion *r = &dev->regions[i];
982         void *m = (void *) (uintptr_t) r->mmap_addr;
983 
984         if (m) {
985             munmap(m, r->size + r->mmap_offset);
986         }
987     }
988     dev->nregions = memory->nregions;
989 
990     if (dev->postcopy_listening) {
991         return vu_set_mem_table_exec_postcopy(dev, vmsg);
992     }
993 
994     DPRINT("Nregions: %u\n", memory->nregions);
995     for (i = 0; i < dev->nregions; i++) {
996         void *mmap_addr;
997         VhostUserMemoryRegion *msg_region = &memory->regions[i];
998         VuDevRegion *dev_region = &dev->regions[i];
999 
1000         DPRINT("Region %d\n", i);
1001         DPRINT("    guest_phys_addr: 0x%016"PRIx64"\n",
1002                msg_region->guest_phys_addr);
1003         DPRINT("    memory_size:     0x%016"PRIx64"\n",
1004                msg_region->memory_size);
1005         DPRINT("    userspace_addr   0x%016"PRIx64"\n",
1006                msg_region->userspace_addr);
1007         DPRINT("    mmap_offset      0x%016"PRIx64"\n",
1008                msg_region->mmap_offset);
1009 
1010         dev_region->gpa = msg_region->guest_phys_addr;
1011         dev_region->size = msg_region->memory_size;
1012         dev_region->qva = msg_region->userspace_addr;
1013         dev_region->mmap_offset = msg_region->mmap_offset;
1014 
1015         /* We don't use offset argument of mmap() since the
1016          * mapped address has to be page aligned, and we use huge
1017          * pages.  */
1018         mmap_addr = mmap(0, dev_region->size + dev_region->mmap_offset,
1019                          PROT_READ | PROT_WRITE, MAP_SHARED | MAP_NORESERVE,
1020                          vmsg->fds[i], 0);
1021 
1022         if (mmap_addr == MAP_FAILED) {
1023             vu_panic(dev, "region mmap error: %s", strerror(errno));
1024         } else {
1025             dev_region->mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
1026             DPRINT("    mmap_addr:       0x%016"PRIx64"\n",
1027                    dev_region->mmap_addr);
1028         }
1029 
1030         close(vmsg->fds[i]);
1031     }
1032 
1033     for (i = 0; i < dev->max_queues; i++) {
1034         if (dev->vq[i].vring.desc) {
1035             if (map_ring(dev, &dev->vq[i])) {
1036                 vu_panic(dev, "remapping queue %d during setmemtable", i);
1037             }
1038         }
1039     }
1040 
1041     return false;
1042 }
1043 
1044 static bool
1045 vu_set_log_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1046 {
1047     int fd;
1048     uint64_t log_mmap_size, log_mmap_offset;
1049     void *rc;
1050 
1051     if (vmsg->fd_num != 1 ||
1052         vmsg->size != sizeof(vmsg->payload.log)) {
1053         vu_panic(dev, "Invalid log_base message");
1054         return true;
1055     }
1056 
1057     fd = vmsg->fds[0];
1058     log_mmap_offset = vmsg->payload.log.mmap_offset;
1059     log_mmap_size = vmsg->payload.log.mmap_size;
1060     DPRINT("Log mmap_offset: %"PRId64"\n", log_mmap_offset);
1061     DPRINT("Log mmap_size:   %"PRId64"\n", log_mmap_size);
1062 
1063     rc = mmap(0, log_mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED, fd,
1064               log_mmap_offset);
1065     close(fd);
1066     if (rc == MAP_FAILED) {
1067         perror("log mmap error");
1068     }
1069 
1070     if (dev->log_table) {
1071         munmap(dev->log_table, dev->log_size);
1072     }
1073     dev->log_table = rc;
1074     dev->log_size = log_mmap_size;
1075 
1076     vmsg->size = sizeof(vmsg->payload.u64);
1077     vmsg->fd_num = 0;
1078 
1079     return true;
1080 }
1081 
1082 static bool
1083 vu_set_log_fd_exec(VuDev *dev, VhostUserMsg *vmsg)
1084 {
1085     if (vmsg->fd_num != 1) {
1086         vu_panic(dev, "Invalid log_fd message");
1087         return false;
1088     }
1089 
1090     if (dev->log_call_fd != -1) {
1091         close(dev->log_call_fd);
1092     }
1093     dev->log_call_fd = vmsg->fds[0];
1094     DPRINT("Got log_call_fd: %d\n", vmsg->fds[0]);
1095 
1096     return false;
1097 }
1098 
1099 static bool
1100 vu_set_vring_num_exec(VuDev *dev, VhostUserMsg *vmsg)
1101 {
1102     unsigned int index = vmsg->payload.state.index;
1103     unsigned int num = vmsg->payload.state.num;
1104 
1105     DPRINT("State.index: %u\n", index);
1106     DPRINT("State.num:   %u\n", num);
1107     dev->vq[index].vring.num = num;
1108 
1109     return false;
1110 }
1111 
1112 static bool
1113 vu_set_vring_addr_exec(VuDev *dev, VhostUserMsg *vmsg)
1114 {
1115     struct vhost_vring_addr addr = vmsg->payload.addr, *vra = &addr;
1116     unsigned int index = vra->index;
1117     VuVirtq *vq = &dev->vq[index];
1118 
1119     DPRINT("vhost_vring_addr:\n");
1120     DPRINT("    index:  %d\n", vra->index);
1121     DPRINT("    flags:  %d\n", vra->flags);
1122     DPRINT("    desc_user_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->desc_user_addr);
1123     DPRINT("    used_user_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->used_user_addr);
1124     DPRINT("    avail_user_addr:  0x%016" PRIx64 "\n", (uint64_t)vra->avail_user_addr);
1125     DPRINT("    log_guest_addr:   0x%016" PRIx64 "\n", (uint64_t)vra->log_guest_addr);
1126 
1127     vq->vra = *vra;
1128     vq->vring.flags = vra->flags;
1129     vq->vring.log_guest_addr = vra->log_guest_addr;
1130 
1131 
1132     if (map_ring(dev, vq)) {
1133         vu_panic(dev, "Invalid vring_addr message");
1134         return false;
1135     }
1136 
1137     vq->used_idx = le16toh(vq->vring.used->idx);
1138 
1139     if (vq->last_avail_idx != vq->used_idx) {
1140         bool resume = dev->iface->queue_is_processed_in_order &&
1141             dev->iface->queue_is_processed_in_order(dev, index);
1142 
1143         DPRINT("Last avail index != used index: %u != %u%s\n",
1144                vq->last_avail_idx, vq->used_idx,
1145                resume ? ", resuming" : "");
1146 
1147         if (resume) {
1148             vq->shadow_avail_idx = vq->last_avail_idx = vq->used_idx;
1149         }
1150     }
1151 
1152     return false;
1153 }
1154 
1155 static bool
1156 vu_set_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1157 {
1158     unsigned int index = vmsg->payload.state.index;
1159     unsigned int num = vmsg->payload.state.num;
1160 
1161     DPRINT("State.index: %u\n", index);
1162     DPRINT("State.num:   %u\n", num);
1163     dev->vq[index].shadow_avail_idx = dev->vq[index].last_avail_idx = num;
1164 
1165     return false;
1166 }
1167 
1168 static bool
1169 vu_get_vring_base_exec(VuDev *dev, VhostUserMsg *vmsg)
1170 {
1171     unsigned int index = vmsg->payload.state.index;
1172 
1173     DPRINT("State.index: %u\n", index);
1174     vmsg->payload.state.num = dev->vq[index].last_avail_idx;
1175     vmsg->size = sizeof(vmsg->payload.state);
1176 
1177     dev->vq[index].started = false;
1178     if (dev->iface->queue_set_started) {
1179         dev->iface->queue_set_started(dev, index, false);
1180     }
1181 
1182     if (dev->vq[index].call_fd != -1) {
1183         close(dev->vq[index].call_fd);
1184         dev->vq[index].call_fd = -1;
1185     }
1186     if (dev->vq[index].kick_fd != -1) {
1187         dev->remove_watch(dev, dev->vq[index].kick_fd);
1188         close(dev->vq[index].kick_fd);
1189         dev->vq[index].kick_fd = -1;
1190     }
1191 
1192     return true;
1193 }
1194 
1195 static bool
1196 vu_check_queue_msg_file(VuDev *dev, VhostUserMsg *vmsg)
1197 {
1198     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1199     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1200 
1201     if (index >= dev->max_queues) {
1202         vmsg_close_fds(vmsg);
1203         vu_panic(dev, "Invalid queue index: %u", index);
1204         return false;
1205     }
1206 
1207     if (nofd) {
1208         vmsg_close_fds(vmsg);
1209         return true;
1210     }
1211 
1212     if (vmsg->fd_num != 1) {
1213         vmsg_close_fds(vmsg);
1214         vu_panic(dev, "Invalid fds in request: %d", vmsg->request);
1215         return false;
1216     }
1217 
1218     return true;
1219 }
1220 
1221 static int
1222 inflight_desc_compare(const void *a, const void *b)
1223 {
1224     VuVirtqInflightDesc *desc0 = (VuVirtqInflightDesc *)a,
1225                         *desc1 = (VuVirtqInflightDesc *)b;
1226 
1227     if (desc1->counter > desc0->counter &&
1228         (desc1->counter - desc0->counter) < VIRTQUEUE_MAX_SIZE * 2) {
1229         return 1;
1230     }
1231 
1232     return -1;
1233 }
1234 
1235 static int
1236 vu_check_queue_inflights(VuDev *dev, VuVirtq *vq)
1237 {
1238     int i = 0;
1239 
1240     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
1241         return 0;
1242     }
1243 
1244     if (unlikely(!vq->inflight)) {
1245         return -1;
1246     }
1247 
1248     if (unlikely(!vq->inflight->version)) {
1249         /* initialize the buffer */
1250         vq->inflight->version = INFLIGHT_VERSION;
1251         return 0;
1252     }
1253 
1254     vq->used_idx = le16toh(vq->vring.used->idx);
1255     vq->resubmit_num = 0;
1256     vq->resubmit_list = NULL;
1257     vq->counter = 0;
1258 
1259     if (unlikely(vq->inflight->used_idx != vq->used_idx)) {
1260         vq->inflight->desc[vq->inflight->last_batch_head].inflight = 0;
1261 
1262         barrier();
1263 
1264         vq->inflight->used_idx = vq->used_idx;
1265     }
1266 
1267     for (i = 0; i < vq->inflight->desc_num; i++) {
1268         if (vq->inflight->desc[i].inflight == 1) {
1269             vq->inuse++;
1270         }
1271     }
1272 
1273     vq->shadow_avail_idx = vq->last_avail_idx = vq->inuse + vq->used_idx;
1274 
1275     if (vq->inuse) {
1276         vq->resubmit_list = calloc(vq->inuse, sizeof(VuVirtqInflightDesc));
1277         if (!vq->resubmit_list) {
1278             return -1;
1279         }
1280 
1281         for (i = 0; i < vq->inflight->desc_num; i++) {
1282             if (vq->inflight->desc[i].inflight) {
1283                 vq->resubmit_list[vq->resubmit_num].index = i;
1284                 vq->resubmit_list[vq->resubmit_num].counter =
1285                                         vq->inflight->desc[i].counter;
1286                 vq->resubmit_num++;
1287             }
1288         }
1289 
1290         if (vq->resubmit_num > 1) {
1291             qsort(vq->resubmit_list, vq->resubmit_num,
1292                   sizeof(VuVirtqInflightDesc), inflight_desc_compare);
1293         }
1294         vq->counter = vq->resubmit_list[0].counter + 1;
1295     }
1296 
1297     /* in case of I/O hang after reconnecting */
1298     if (eventfd_write(vq->kick_fd, 1)) {
1299         return -1;
1300     }
1301 
1302     return 0;
1303 }
1304 
1305 static bool
1306 vu_set_vring_kick_exec(VuDev *dev, VhostUserMsg *vmsg)
1307 {
1308     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1309     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1310 
1311     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1312 
1313     if (!vu_check_queue_msg_file(dev, vmsg)) {
1314         return false;
1315     }
1316 
1317     if (dev->vq[index].kick_fd != -1) {
1318         dev->remove_watch(dev, dev->vq[index].kick_fd);
1319         close(dev->vq[index].kick_fd);
1320         dev->vq[index].kick_fd = -1;
1321     }
1322 
1323     dev->vq[index].kick_fd = nofd ? -1 : vmsg->fds[0];
1324     DPRINT("Got kick_fd: %d for vq: %d\n", dev->vq[index].kick_fd, index);
1325 
1326     dev->vq[index].started = true;
1327     if (dev->iface->queue_set_started) {
1328         dev->iface->queue_set_started(dev, index, true);
1329     }
1330 
1331     if (dev->vq[index].kick_fd != -1 && dev->vq[index].handler) {
1332         dev->set_watch(dev, dev->vq[index].kick_fd, VU_WATCH_IN,
1333                        vu_kick_cb, (void *)(long)index);
1334 
1335         DPRINT("Waiting for kicks on fd: %d for vq: %d\n",
1336                dev->vq[index].kick_fd, index);
1337     }
1338 
1339     if (vu_check_queue_inflights(dev, &dev->vq[index])) {
1340         vu_panic(dev, "Failed to check inflights for vq: %d\n", index);
1341     }
1342 
1343     return false;
1344 }
1345 
1346 void vu_set_queue_handler(VuDev *dev, VuVirtq *vq,
1347                           vu_queue_handler_cb handler)
1348 {
1349     int qidx = vq - dev->vq;
1350 
1351     vq->handler = handler;
1352     if (vq->kick_fd >= 0) {
1353         if (handler) {
1354             dev->set_watch(dev, vq->kick_fd, VU_WATCH_IN,
1355                            vu_kick_cb, (void *)(long)qidx);
1356         } else {
1357             dev->remove_watch(dev, vq->kick_fd);
1358         }
1359     }
1360 }
1361 
1362 bool vu_set_queue_host_notifier(VuDev *dev, VuVirtq *vq, int fd,
1363                                 int size, int offset)
1364 {
1365     int qidx = vq - dev->vq;
1366     int fd_num = 0;
1367     VhostUserMsg vmsg = {
1368         .request = VHOST_USER_SLAVE_VRING_HOST_NOTIFIER_MSG,
1369         .flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
1370         .size = sizeof(vmsg.payload.area),
1371         .payload.area = {
1372             .u64 = qidx & VHOST_USER_VRING_IDX_MASK,
1373             .size = size,
1374             .offset = offset,
1375         },
1376     };
1377 
1378     if (fd == -1) {
1379         vmsg.payload.area.u64 |= VHOST_USER_VRING_NOFD_MASK;
1380     } else {
1381         vmsg.fds[fd_num++] = fd;
1382     }
1383 
1384     vmsg.fd_num = fd_num;
1385 
1386     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD)) {
1387         return false;
1388     }
1389 
1390     pthread_mutex_lock(&dev->slave_mutex);
1391     if (!vu_message_write(dev, dev->slave_fd, &vmsg)) {
1392         pthread_mutex_unlock(&dev->slave_mutex);
1393         return false;
1394     }
1395 
1396     /* Also unlocks the slave_mutex */
1397     return vu_process_message_reply(dev, &vmsg);
1398 }
1399 
1400 static bool
1401 vu_set_vring_call_exec(VuDev *dev, VhostUserMsg *vmsg)
1402 {
1403     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1404     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1405 
1406     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1407 
1408     if (!vu_check_queue_msg_file(dev, vmsg)) {
1409         return false;
1410     }
1411 
1412     if (dev->vq[index].call_fd != -1) {
1413         close(dev->vq[index].call_fd);
1414         dev->vq[index].call_fd = -1;
1415     }
1416 
1417     dev->vq[index].call_fd = nofd ? -1 : vmsg->fds[0];
1418 
1419     /* in case of I/O hang after reconnecting */
1420     if (dev->vq[index].call_fd != -1 && eventfd_write(vmsg->fds[0], 1)) {
1421         return -1;
1422     }
1423 
1424     DPRINT("Got call_fd: %d for vq: %d\n", dev->vq[index].call_fd, index);
1425 
1426     return false;
1427 }
1428 
1429 static bool
1430 vu_set_vring_err_exec(VuDev *dev, VhostUserMsg *vmsg)
1431 {
1432     int index = vmsg->payload.u64 & VHOST_USER_VRING_IDX_MASK;
1433     bool nofd = vmsg->payload.u64 & VHOST_USER_VRING_NOFD_MASK;
1434 
1435     DPRINT("u64: 0x%016"PRIx64"\n", vmsg->payload.u64);
1436 
1437     if (!vu_check_queue_msg_file(dev, vmsg)) {
1438         return false;
1439     }
1440 
1441     if (dev->vq[index].err_fd != -1) {
1442         close(dev->vq[index].err_fd);
1443         dev->vq[index].err_fd = -1;
1444     }
1445 
1446     dev->vq[index].err_fd = nofd ? -1 : vmsg->fds[0];
1447 
1448     return false;
1449 }
1450 
1451 static bool
1452 vu_get_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
1453 {
1454     /*
1455      * Note that we support, but intentionally do not set,
1456      * VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS. This means that
1457      * a device implementation can return it in its callback
1458      * (get_protocol_features) if it wants to use this for
1459      * simulation, but it is otherwise not desirable (if even
1460      * implemented by the master.)
1461      */
1462     uint64_t features = 1ULL << VHOST_USER_PROTOCOL_F_MQ |
1463                         1ULL << VHOST_USER_PROTOCOL_F_LOG_SHMFD |
1464                         1ULL << VHOST_USER_PROTOCOL_F_SLAVE_REQ |
1465                         1ULL << VHOST_USER_PROTOCOL_F_HOST_NOTIFIER |
1466                         1ULL << VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD |
1467                         1ULL << VHOST_USER_PROTOCOL_F_REPLY_ACK |
1468                         1ULL << VHOST_USER_PROTOCOL_F_CONFIGURE_MEM_SLOTS;
1469 
1470     if (have_userfault()) {
1471         features |= 1ULL << VHOST_USER_PROTOCOL_F_PAGEFAULT;
1472     }
1473 
1474     if (dev->iface->get_config && dev->iface->set_config) {
1475         features |= 1ULL << VHOST_USER_PROTOCOL_F_CONFIG;
1476     }
1477 
1478     if (dev->iface->get_protocol_features) {
1479         features |= dev->iface->get_protocol_features(dev);
1480     }
1481 
1482     vmsg_set_reply_u64(vmsg, features);
1483     return true;
1484 }
1485 
1486 static bool
1487 vu_set_protocol_features_exec(VuDev *dev, VhostUserMsg *vmsg)
1488 {
1489     uint64_t features = vmsg->payload.u64;
1490 
1491     DPRINT("u64: 0x%016"PRIx64"\n", features);
1492 
1493     dev->protocol_features = vmsg->payload.u64;
1494 
1495     if (vu_has_protocol_feature(dev,
1496                                 VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
1497         (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_REQ) ||
1498          !vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_REPLY_ACK))) {
1499         /*
1500          * The use case for using messages for kick/call is simulation, to make
1501          * the kick and call synchronous. To actually get that behaviour, both
1502          * of the other features are required.
1503          * Theoretically, one could use only kick messages, or do them without
1504          * having F_REPLY_ACK, but too many (possibly pending) messages on the
1505          * socket will eventually cause the master to hang, to avoid this in
1506          * scenarios where not desired enforce that the settings are in a way
1507          * that actually enables the simulation case.
1508          */
1509         vu_panic(dev,
1510                  "F_IN_BAND_NOTIFICATIONS requires F_SLAVE_REQ && F_REPLY_ACK");
1511         return false;
1512     }
1513 
1514     if (dev->iface->set_protocol_features) {
1515         dev->iface->set_protocol_features(dev, features);
1516     }
1517 
1518     return false;
1519 }
1520 
1521 static bool
1522 vu_get_queue_num_exec(VuDev *dev, VhostUserMsg *vmsg)
1523 {
1524     vmsg_set_reply_u64(vmsg, dev->max_queues);
1525     return true;
1526 }
1527 
1528 static bool
1529 vu_set_vring_enable_exec(VuDev *dev, VhostUserMsg *vmsg)
1530 {
1531     unsigned int index = vmsg->payload.state.index;
1532     unsigned int enable = vmsg->payload.state.num;
1533 
1534     DPRINT("State.index: %u\n", index);
1535     DPRINT("State.enable:   %u\n", enable);
1536 
1537     if (index >= dev->max_queues) {
1538         vu_panic(dev, "Invalid vring_enable index: %u", index);
1539         return false;
1540     }
1541 
1542     dev->vq[index].enable = enable;
1543     return false;
1544 }
1545 
1546 static bool
1547 vu_set_slave_req_fd(VuDev *dev, VhostUserMsg *vmsg)
1548 {
1549     if (vmsg->fd_num != 1) {
1550         vu_panic(dev, "Invalid slave_req_fd message (%d fd's)", vmsg->fd_num);
1551         return false;
1552     }
1553 
1554     if (dev->slave_fd != -1) {
1555         close(dev->slave_fd);
1556     }
1557     dev->slave_fd = vmsg->fds[0];
1558     DPRINT("Got slave_fd: %d\n", vmsg->fds[0]);
1559 
1560     return false;
1561 }
1562 
1563 static bool
1564 vu_get_config(VuDev *dev, VhostUserMsg *vmsg)
1565 {
1566     int ret = -1;
1567 
1568     if (dev->iface->get_config) {
1569         ret = dev->iface->get_config(dev, vmsg->payload.config.region,
1570                                      vmsg->payload.config.size);
1571     }
1572 
1573     if (ret) {
1574         /* resize to zero to indicate an error to master */
1575         vmsg->size = 0;
1576     }
1577 
1578     return true;
1579 }
1580 
1581 static bool
1582 vu_set_config(VuDev *dev, VhostUserMsg *vmsg)
1583 {
1584     int ret = -1;
1585 
1586     if (dev->iface->set_config) {
1587         ret = dev->iface->set_config(dev, vmsg->payload.config.region,
1588                                      vmsg->payload.config.offset,
1589                                      vmsg->payload.config.size,
1590                                      vmsg->payload.config.flags);
1591         if (ret) {
1592             vu_panic(dev, "Set virtio configuration space failed");
1593         }
1594     }
1595 
1596     return false;
1597 }
1598 
1599 static bool
1600 vu_set_postcopy_advise(VuDev *dev, VhostUserMsg *vmsg)
1601 {
1602 #ifdef UFFDIO_API
1603     struct uffdio_api api_struct;
1604 
1605     dev->postcopy_ufd = syscall(__NR_userfaultfd, O_CLOEXEC | O_NONBLOCK);
1606     vmsg->size = 0;
1607 #else
1608     dev->postcopy_ufd = -1;
1609 #endif
1610 
1611     if (dev->postcopy_ufd == -1) {
1612         vu_panic(dev, "Userfaultfd not available: %s", strerror(errno));
1613         goto out;
1614     }
1615 
1616 #ifdef UFFDIO_API
1617     api_struct.api = UFFD_API;
1618     api_struct.features = 0;
1619     if (ioctl(dev->postcopy_ufd, UFFDIO_API, &api_struct)) {
1620         vu_panic(dev, "Failed UFFDIO_API: %s", strerror(errno));
1621         close(dev->postcopy_ufd);
1622         dev->postcopy_ufd = -1;
1623         goto out;
1624     }
1625     /* TODO: Stash feature flags somewhere */
1626 #endif
1627 
1628 out:
1629     /* Return a ufd to the QEMU */
1630     vmsg->fd_num = 1;
1631     vmsg->fds[0] = dev->postcopy_ufd;
1632     return true; /* = send a reply */
1633 }
1634 
1635 static bool
1636 vu_set_postcopy_listen(VuDev *dev, VhostUserMsg *vmsg)
1637 {
1638     if (dev->nregions) {
1639         vu_panic(dev, "Regions already registered at postcopy-listen");
1640         vmsg_set_reply_u64(vmsg, -1);
1641         return true;
1642     }
1643     dev->postcopy_listening = true;
1644 
1645     vmsg_set_reply_u64(vmsg, 0);
1646     return true;
1647 }
1648 
1649 static bool
1650 vu_set_postcopy_end(VuDev *dev, VhostUserMsg *vmsg)
1651 {
1652     DPRINT("%s: Entry\n", __func__);
1653     dev->postcopy_listening = false;
1654     if (dev->postcopy_ufd > 0) {
1655         close(dev->postcopy_ufd);
1656         dev->postcopy_ufd = -1;
1657         DPRINT("%s: Done close\n", __func__);
1658     }
1659 
1660     vmsg_set_reply_u64(vmsg, 0);
1661     DPRINT("%s: exit\n", __func__);
1662     return true;
1663 }
1664 
1665 static inline uint64_t
1666 vu_inflight_queue_size(uint16_t queue_size)
1667 {
1668     return ALIGN_UP(sizeof(VuDescStateSplit) * queue_size +
1669            sizeof(uint16_t), INFLIGHT_ALIGNMENT);
1670 }
1671 
#ifdef MFD_ALLOW_SEALING
/*
 * Create a sealable memfd named @name, size it to @size bytes, apply the
 * F_ADD_SEALS @flags, and map it shared read/write.
 *
 * On success, returns the mapping and stores the memfd in *fd; the caller
 * then owns both.  On failure, returns NULL, closes any memfd created here,
 * and sets *fd to -1, so the caller is never left holding a closed
 * descriptor number.
 */
static void *
memfd_alloc(const char *name, size_t size, unsigned int flags, int *fd)
{
    void *ptr;

    *fd = memfd_create(name, MFD_ALLOW_SEALING);
    if (*fd < 0) {
        *fd = -1;
        return NULL;
    }

    if (ftruncate(*fd, size) < 0 ||
        fcntl(*fd, F_ADD_SEALS, flags) < 0) {
        goto err_close;
    }

    ptr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, *fd, 0);
    if (ptr == MAP_FAILED) {
        goto err_close;
    }

    return ptr;

err_close:
    close(*fd);
    *fd = -1;
    return NULL;
}
#endif
1705 
1706 static bool
1707 vu_get_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
1708 {
1709     int fd = -1;
1710     void *addr = NULL;
1711     uint64_t mmap_size;
1712     uint16_t num_queues, queue_size;
1713 
1714     if (vmsg->size != sizeof(vmsg->payload.inflight)) {
1715         vu_panic(dev, "Invalid get_inflight_fd message:%d", vmsg->size);
1716         vmsg->payload.inflight.mmap_size = 0;
1717         return true;
1718     }
1719 
1720     num_queues = vmsg->payload.inflight.num_queues;
1721     queue_size = vmsg->payload.inflight.queue_size;
1722 
1723     DPRINT("set_inflight_fd num_queues: %"PRId16"\n", num_queues);
1724     DPRINT("set_inflight_fd queue_size: %"PRId16"\n", queue_size);
1725 
1726     mmap_size = vu_inflight_queue_size(queue_size) * num_queues;
1727 
1728 #ifdef MFD_ALLOW_SEALING
1729     addr = memfd_alloc("vhost-inflight", mmap_size,
1730                        F_SEAL_GROW | F_SEAL_SHRINK | F_SEAL_SEAL,
1731                        &fd);
1732 #else
1733     vu_panic(dev, "Not implemented: memfd support is missing");
1734 #endif
1735 
1736     if (!addr) {
1737         vu_panic(dev, "Failed to alloc vhost inflight area");
1738         vmsg->payload.inflight.mmap_size = 0;
1739         return true;
1740     }
1741 
1742     memset(addr, 0, mmap_size);
1743 
1744     dev->inflight_info.addr = addr;
1745     dev->inflight_info.size = vmsg->payload.inflight.mmap_size = mmap_size;
1746     dev->inflight_info.fd = vmsg->fds[0] = fd;
1747     vmsg->fd_num = 1;
1748     vmsg->payload.inflight.mmap_offset = 0;
1749 
1750     DPRINT("send inflight mmap_size: %"PRId64"\n",
1751            vmsg->payload.inflight.mmap_size);
1752     DPRINT("send inflight mmap offset: %"PRId64"\n",
1753            vmsg->payload.inflight.mmap_offset);
1754 
1755     return true;
1756 }
1757 
1758 static bool
1759 vu_set_inflight_fd(VuDev *dev, VhostUserMsg *vmsg)
1760 {
1761     int fd, i;
1762     uint64_t mmap_size, mmap_offset;
1763     uint16_t num_queues, queue_size;
1764     void *rc;
1765 
1766     if (vmsg->fd_num != 1 ||
1767         vmsg->size != sizeof(vmsg->payload.inflight)) {
1768         vu_panic(dev, "Invalid set_inflight_fd message size:%d fds:%d",
1769                  vmsg->size, vmsg->fd_num);
1770         return false;
1771     }
1772 
1773     fd = vmsg->fds[0];
1774     mmap_size = vmsg->payload.inflight.mmap_size;
1775     mmap_offset = vmsg->payload.inflight.mmap_offset;
1776     num_queues = vmsg->payload.inflight.num_queues;
1777     queue_size = vmsg->payload.inflight.queue_size;
1778 
1779     DPRINT("set_inflight_fd mmap_size: %"PRId64"\n", mmap_size);
1780     DPRINT("set_inflight_fd mmap_offset: %"PRId64"\n", mmap_offset);
1781     DPRINT("set_inflight_fd num_queues: %"PRId16"\n", num_queues);
1782     DPRINT("set_inflight_fd queue_size: %"PRId16"\n", queue_size);
1783 
1784     rc = mmap(0, mmap_size, PROT_READ | PROT_WRITE, MAP_SHARED,
1785               fd, mmap_offset);
1786 
1787     if (rc == MAP_FAILED) {
1788         vu_panic(dev, "set_inflight_fd mmap error: %s", strerror(errno));
1789         return false;
1790     }
1791 
1792     if (dev->inflight_info.fd) {
1793         close(dev->inflight_info.fd);
1794     }
1795 
1796     if (dev->inflight_info.addr) {
1797         munmap(dev->inflight_info.addr, dev->inflight_info.size);
1798     }
1799 
1800     dev->inflight_info.fd = fd;
1801     dev->inflight_info.addr = rc;
1802     dev->inflight_info.size = mmap_size;
1803 
1804     for (i = 0; i < num_queues; i++) {
1805         dev->vq[i].inflight = (VuVirtqInflight *)rc;
1806         dev->vq[i].inflight->desc_num = queue_size;
1807         rc = (void *)((char *)rc + vu_inflight_queue_size(queue_size));
1808     }
1809 
1810     return false;
1811 }
1812 
1813 static bool
1814 vu_handle_vring_kick(VuDev *dev, VhostUserMsg *vmsg)
1815 {
1816     unsigned int index = vmsg->payload.state.index;
1817 
1818     if (index >= dev->max_queues) {
1819         vu_panic(dev, "Invalid queue index: %u", index);
1820         return false;
1821     }
1822 
1823     DPRINT("Got kick message: handler:%p idx:%u\n",
1824            dev->vq[index].handler, index);
1825 
1826     if (!dev->vq[index].started) {
1827         dev->vq[index].started = true;
1828 
1829         if (dev->iface->queue_set_started) {
1830             dev->iface->queue_set_started(dev, index, true);
1831         }
1832     }
1833 
1834     if (dev->vq[index].handler) {
1835         dev->vq[index].handler(dev, index);
1836     }
1837 
1838     return false;
1839 }
1840 
1841 static bool vu_handle_get_max_memslots(VuDev *dev, VhostUserMsg *vmsg)
1842 {
1843     vmsg_set_reply_u64(vmsg, VHOST_USER_MAX_RAM_SLOTS);
1844 
1845     DPRINT("u64: 0x%016"PRIx64"\n", (uint64_t) VHOST_USER_MAX_RAM_SLOTS);
1846 
1847     return true;
1848 }
1849 
/*
 * Execute one vhost-user request that has already been read into @vmsg.
 *
 * The generic header and any received fds are logged first.  If the device
 * supplies iface->process_msg and that callback returns true, the request
 * counts as handled by the device, and its do_reply decision is returned.
 * Otherwise the request is dispatched to the built-in handler for
 * vmsg->request.
 *
 * Returns true when @vmsg now holds a reply that the caller must send, false
 * when there is nothing to send.  VHOST_USER_NONE terminates the process with
 * exit(0).  An unknown request has its fds closed and triggers vu_panic().
 */
static bool
vu_process_message(VuDev *dev, VhostUserMsg *vmsg)
{
    int do_reply = 0;

    /* Print out generic part of the request. */
    DPRINT("================ Vhost user message ================\n");
    DPRINT("Request: %s (%d)\n", vu_request_to_string(vmsg->request),
           vmsg->request);
    DPRINT("Flags:   0x%x\n", vmsg->flags);
    DPRINT("Size:    %u\n", vmsg->size);

    if (vmsg->fd_num) {
        int i;
        DPRINT("Fds:");
        for (i = 0; i < vmsg->fd_num; i++) {
            DPRINT(" %d", vmsg->fds[i]);
        }
        DPRINT("\n");
    }

    /* Give the device the first chance to handle (or override) the request. */
    if (dev->iface->process_msg &&
        dev->iface->process_msg(dev, vmsg, &do_reply)) {
        return do_reply;
    }

    /* Built-in handlers; each returns whether @vmsg must be sent as a reply. */
    switch (vmsg->request) {
    case VHOST_USER_GET_FEATURES:
        return vu_get_features_exec(dev, vmsg);
    case VHOST_USER_SET_FEATURES:
        return vu_set_features_exec(dev, vmsg);
    case VHOST_USER_GET_PROTOCOL_FEATURES:
        return vu_get_protocol_features_exec(dev, vmsg);
    case VHOST_USER_SET_PROTOCOL_FEATURES:
        return vu_set_protocol_features_exec(dev, vmsg);
    case VHOST_USER_SET_OWNER:
        return vu_set_owner_exec(dev, vmsg);
    case VHOST_USER_RESET_OWNER:
        return vu_reset_device_exec(dev, vmsg);
    case VHOST_USER_SET_MEM_TABLE:
        return vu_set_mem_table_exec(dev, vmsg);
    case VHOST_USER_SET_LOG_BASE:
        return vu_set_log_base_exec(dev, vmsg);
    case VHOST_USER_SET_LOG_FD:
        return vu_set_log_fd_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_NUM:
        return vu_set_vring_num_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_ADDR:
        return vu_set_vring_addr_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_BASE:
        return vu_set_vring_base_exec(dev, vmsg);
    case VHOST_USER_GET_VRING_BASE:
        return vu_get_vring_base_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_KICK:
        return vu_set_vring_kick_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_CALL:
        return vu_set_vring_call_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_ERR:
        return vu_set_vring_err_exec(dev, vmsg);
    case VHOST_USER_GET_QUEUE_NUM:
        return vu_get_queue_num_exec(dev, vmsg);
    case VHOST_USER_SET_VRING_ENABLE:
        return vu_set_vring_enable_exec(dev, vmsg);
    case VHOST_USER_SET_SLAVE_REQ_FD:
        return vu_set_slave_req_fd(dev, vmsg);
    case VHOST_USER_GET_CONFIG:
        return vu_get_config(dev, vmsg);
    case VHOST_USER_SET_CONFIG:
        return vu_set_config(dev, vmsg);
    case VHOST_USER_NONE:
        /* if you need processing before exit, override iface->process_msg */
        exit(0);
    case VHOST_USER_POSTCOPY_ADVISE:
        return vu_set_postcopy_advise(dev, vmsg);
    case VHOST_USER_POSTCOPY_LISTEN:
        return vu_set_postcopy_listen(dev, vmsg);
    case VHOST_USER_POSTCOPY_END:
        return vu_set_postcopy_end(dev, vmsg);
    case VHOST_USER_GET_INFLIGHT_FD:
        return vu_get_inflight_fd(dev, vmsg);
    case VHOST_USER_SET_INFLIGHT_FD:
        return vu_set_inflight_fd(dev, vmsg);
    case VHOST_USER_VRING_KICK:
        return vu_handle_vring_kick(dev, vmsg);
    case VHOST_USER_GET_MAX_MEM_SLOTS:
        return vu_handle_get_max_memslots(dev, vmsg);
    case VHOST_USER_ADD_MEM_REG:
        return vu_add_mem_reg(dev, vmsg);
    case VHOST_USER_REM_MEM_REG:
        return vu_rem_mem_reg(dev, vmsg);
    default:
        /* Unknown request: don't leak any fds that came with it. */
        vmsg_close_fds(vmsg);
        vu_panic(dev, "Unhandled request: %d", vmsg->request);
    }

    return false;
}
1947 
1948 bool
1949 vu_dispatch(VuDev *dev)
1950 {
1951     VhostUserMsg vmsg = { 0, };
1952     int reply_requested;
1953     bool need_reply, success = false;
1954 
1955     if (!dev->read_msg(dev, dev->sock, &vmsg)) {
1956         goto end;
1957     }
1958 
1959     need_reply = vmsg.flags & VHOST_USER_NEED_REPLY_MASK;
1960 
1961     reply_requested = vu_process_message(dev, &vmsg);
1962     if (!reply_requested && need_reply) {
1963         vmsg_set_reply_u64(&vmsg, 0);
1964         reply_requested = 1;
1965     }
1966 
1967     if (!reply_requested) {
1968         success = true;
1969         goto end;
1970     }
1971 
1972     if (!vu_send_reply(dev, dev->sock, &vmsg)) {
1973         goto end;
1974     }
1975 
1976     success = true;
1977 
1978 end:
1979     free(vmsg.data);
1980     return success;
1981 }
1982 
1983 void
1984 vu_deinit(VuDev *dev)
1985 {
1986     unsigned int i;
1987 
1988     for (i = 0; i < dev->nregions; i++) {
1989         VuDevRegion *r = &dev->regions[i];
1990         void *m = (void *) (uintptr_t) r->mmap_addr;
1991         if (m != MAP_FAILED) {
1992             munmap(m, r->size + r->mmap_offset);
1993         }
1994     }
1995     dev->nregions = 0;
1996 
1997     for (i = 0; i < dev->max_queues; i++) {
1998         VuVirtq *vq = &dev->vq[i];
1999 
2000         if (vq->call_fd != -1) {
2001             close(vq->call_fd);
2002             vq->call_fd = -1;
2003         }
2004 
2005         if (vq->kick_fd != -1) {
2006             dev->remove_watch(dev, vq->kick_fd);
2007             close(vq->kick_fd);
2008             vq->kick_fd = -1;
2009         }
2010 
2011         if (vq->err_fd != -1) {
2012             close(vq->err_fd);
2013             vq->err_fd = -1;
2014         }
2015 
2016         if (vq->resubmit_list) {
2017             free(vq->resubmit_list);
2018             vq->resubmit_list = NULL;
2019         }
2020 
2021         vq->inflight = NULL;
2022     }
2023 
2024     if (dev->inflight_info.addr) {
2025         munmap(dev->inflight_info.addr, dev->inflight_info.size);
2026         dev->inflight_info.addr = NULL;
2027     }
2028 
2029     if (dev->inflight_info.fd > 0) {
2030         close(dev->inflight_info.fd);
2031         dev->inflight_info.fd = -1;
2032     }
2033 
2034     vu_close_log(dev);
2035     if (dev->slave_fd != -1) {
2036         close(dev->slave_fd);
2037         dev->slave_fd = -1;
2038     }
2039     pthread_mutex_destroy(&dev->slave_mutex);
2040 
2041     if (dev->sock != -1) {
2042         close(dev->sock);
2043     }
2044 
2045     free(dev->vq);
2046     dev->vq = NULL;
2047 }
2048 
2049 bool
2050 vu_init(VuDev *dev,
2051         uint16_t max_queues,
2052         int socket,
2053         vu_panic_cb panic,
2054         vu_read_msg_cb read_msg,
2055         vu_set_watch_cb set_watch,
2056         vu_remove_watch_cb remove_watch,
2057         const VuDevIface *iface)
2058 {
2059     uint16_t i;
2060 
2061     assert(max_queues > 0);
2062     assert(socket >= 0);
2063     assert(set_watch);
2064     assert(remove_watch);
2065     assert(iface);
2066     assert(panic);
2067 
2068     memset(dev, 0, sizeof(*dev));
2069 
2070     dev->sock = socket;
2071     dev->panic = panic;
2072     dev->read_msg = read_msg ? read_msg : vu_message_read_default;
2073     dev->set_watch = set_watch;
2074     dev->remove_watch = remove_watch;
2075     dev->iface = iface;
2076     dev->log_call_fd = -1;
2077     pthread_mutex_init(&dev->slave_mutex, NULL);
2078     dev->slave_fd = -1;
2079     dev->max_queues = max_queues;
2080 
2081     dev->vq = malloc(max_queues * sizeof(dev->vq[0]));
2082     if (!dev->vq) {
2083         DPRINT("%s: failed to malloc virtqueues\n", __func__);
2084         return false;
2085     }
2086 
2087     for (i = 0; i < max_queues; i++) {
2088         dev->vq[i] = (VuVirtq) {
2089             .call_fd = -1, .kick_fd = -1, .err_fd = -1,
2090             .notification = true,
2091         };
2092     }
2093 
2094     return true;
2095 }
2096 
2097 VuVirtq *
2098 vu_get_queue(VuDev *dev, int qidx)
2099 {
2100     assert(qidx < dev->max_queues);
2101     return &dev->vq[qidx];
2102 }
2103 
2104 bool
2105 vu_queue_enabled(VuDev *dev, VuVirtq *vq)
2106 {
2107     return vq->enable;
2108 }
2109 
2110 bool
2111 vu_queue_started(const VuDev *dev, const VuVirtq *vq)
2112 {
2113     return vq->started;
2114 }
2115 
2116 static inline uint16_t
2117 vring_avail_flags(VuVirtq *vq)
2118 {
2119     return le16toh(vq->vring.avail->flags);
2120 }
2121 
2122 static inline uint16_t
2123 vring_avail_idx(VuVirtq *vq)
2124 {
2125     vq->shadow_avail_idx = le16toh(vq->vring.avail->idx);
2126 
2127     return vq->shadow_avail_idx;
2128 }
2129 
2130 static inline uint16_t
2131 vring_avail_ring(VuVirtq *vq, int i)
2132 {
2133     return le16toh(vq->vring.avail->ring[i]);
2134 }
2135 
2136 static inline uint16_t
2137 vring_get_used_event(VuVirtq *vq)
2138 {
2139     return vring_avail_ring(vq, vq->vring.num);
2140 }
2141 
/*
 * Return how many heads the guest has published in the avail ring beyond
 * @idx. This re-reads avail->idx from guest memory, which also refreshes
 * vq->shadow_avail_idx.
 *
 * Returns -1 (after vu_panic()) if the count exceeds the ring size.
 */
static int
virtqueue_num_heads(VuDev *dev, VuVirtq *vq, unsigned int idx)
{
    /* Computed in uint16_t so that avail index wrap-around is handled. */
    uint16_t num_heads = vring_avail_idx(vq) - idx;

    /* Check it isn't doing very strange things with descriptor numbers. */
    if (num_heads > vq->vring.num) {
        /* NOTE(review): it is the avail index that moved, despite the wording. */
        vu_panic(dev, "Guest moved used index from %u to %u",
                 idx, vq->shadow_avail_idx);
        return -1;
    }
    if (num_heads) {
        /* On success, callers read a descriptor at vq->last_avail_idx.
         * Make sure descriptor read does not bypass avail index read. */
        smp_rmb();
    }

    return num_heads;
}
2161 
2162 static bool
2163 virtqueue_get_head(VuDev *dev, VuVirtq *vq,
2164                    unsigned int idx, unsigned int *head)
2165 {
2166     /* Grab the next descriptor number they're advertising, and increment
2167      * the index we've seen. */
2168     *head = vring_avail_ring(vq, idx % vq->vring.num);
2169 
2170     /* If their number is silly, that's a fatal mistake. */
2171     if (*head >= vq->vring.num) {
2172         vu_panic(dev, "Guest says index %u is available", *head);
2173         return false;
2174     }
2175 
2176     return true;
2177 }
2178 
2179 static int
2180 virtqueue_read_indirect_desc(VuDev *dev, struct vring_desc *desc,
2181                              uint64_t addr, size_t len)
2182 {
2183     struct vring_desc *ori_desc;
2184     uint64_t read_len;
2185 
2186     if (len > (VIRTQUEUE_MAX_SIZE * sizeof(struct vring_desc))) {
2187         return -1;
2188     }
2189 
2190     if (len == 0) {
2191         return -1;
2192     }
2193 
2194     while (len) {
2195         read_len = len;
2196         ori_desc = vu_gpa_to_va(dev, &read_len, addr);
2197         if (!ori_desc) {
2198             return -1;
2199         }
2200 
2201         memcpy(desc, ori_desc, read_len);
2202         len -= read_len;
2203         addr += read_len;
2204         desc += read_len;
2205     }
2206 
2207     return 0;
2208 }
2209 
/* Results of virtqueue_read_next_desc(). */
enum {
    VIRTQUEUE_READ_DESC_ERROR = -1, /* malformed chain; vu_panic() was called */
    VIRTQUEUE_READ_DESC_DONE,       /* (0) end of chain */
    VIRTQUEUE_READ_DESC_MORE,       /* (1) more buffers in chain */
};
2215 
2216 static int
2217 virtqueue_read_next_desc(VuDev *dev, struct vring_desc *desc,
2218                          int i, unsigned int max, unsigned int *next)
2219 {
2220     /* If this descriptor says it doesn't chain, we're done. */
2221     if (!(le16toh(desc[i].flags) & VRING_DESC_F_NEXT)) {
2222         return VIRTQUEUE_READ_DESC_DONE;
2223     }
2224 
2225     /* Check they're not leading us off end of descriptors. */
2226     *next = le16toh(desc[i].next);
2227     /* Make sure compiler knows to grab that: we don't want it changing! */
2228     smp_wmb();
2229 
2230     if (*next >= max) {
2231         vu_panic(dev, "Desc next is %u", *next);
2232         return VIRTQUEUE_READ_DESC_ERROR;
2233     }
2234 
2235     return VIRTQUEUE_READ_DESC_MORE;
2236 }
2237 
/*
 * Sum the buffer space the guest has made available on @vq, without
 * consuming anything.
 *
 * Starting at vq->last_avail_idx, walk every published descriptor chain,
 * following indirect tables. Each descriptor's length is added to the "in"
 * total if it is device-writable (VRING_DESC_F_WRITE) and to the "out"
 * total otherwise.
 *
 * The walk stops early once in_total >= @max_in_bytes and
 * out_total >= @max_out_bytes. Both totals are 0 if the device is broken
 * or the avail ring is unmapped. They are reset to 0 when a malformed ring
 * is detected, in which case vu_panic() has been called.
 *
 * Results are stored through @in_bytes and @out_bytes when those are
 * non-NULL.
 */
void
vu_queue_get_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int *in_bytes,
                         unsigned int *out_bytes,
                         unsigned max_in_bytes, unsigned max_out_bytes)
{
    unsigned int idx;
    unsigned int total_bufs, in_total, out_total;
    int rc;

    idx = vq->last_avail_idx;

    total_bufs = in_total = out_total = 0;
    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        goto done;
    }

    while ((rc = virtqueue_num_heads(dev, vq, idx)) > 0) {
        unsigned int max, desc_len, num_bufs, indirect = 0;
        uint64_t desc_addr, read_len;
        struct vring_desc *desc;
        /* Bounce buffer for an indirect table that is not contiguous in one region. */
        struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
        unsigned int i;

        max = vq->vring.num;
        /* Loop detection in the main table counts descriptors across all chains walked. */
        num_bufs = total_bufs;
        if (!virtqueue_get_head(dev, vq, idx++, &i)) {
            goto err;
        }
        desc = vq->vring.desc;

        if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
            if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
                vu_panic(dev, "Invalid size for indirect buffer table");
                goto err;
            }

            /* If we've got too many, that implies a descriptor loop. */
            if (num_bufs >= max) {
                vu_panic(dev, "Looped descriptor");
                goto err;
            }

            /* loop over the indirect descriptor table */
            indirect = 1;
            desc_addr = le64toh(desc[i].addr);
            desc_len = le32toh(desc[i].len);
            max = desc_len / sizeof(struct vring_desc);
            read_len = desc_len;
            desc = vu_gpa_to_va(dev, &read_len, desc_addr);
            if (unlikely(desc && read_len != desc_len)) {
                /* Failed to use zero copy */
                desc = NULL;
                if (!virtqueue_read_indirect_desc(dev, desc_buf,
                                                  desc_addr,
                                                  desc_len)) {
                    desc = desc_buf;
                }
            }
            if (!desc) {
                vu_panic(dev, "Invalid indirect buffer table");
                goto err;
            }
            /* Inside the indirect table, count and index from its start. */
            num_bufs = i = 0;
        }

        do {
            /* If we've got too many, that implies a descriptor loop. */
            if (++num_bufs > max) {
                vu_panic(dev, "Looped descriptor");
                goto err;
            }

            if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
                in_total += le32toh(desc[i].len);
            } else {
                out_total += le32toh(desc[i].len);
            }
            /* Caller only needs to know whether these thresholds are reached. */
            if (in_total >= max_in_bytes && out_total >= max_out_bytes) {
                goto done;
            }
            rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
        } while (rc == VIRTQUEUE_READ_DESC_MORE);

        if (rc == VIRTQUEUE_READ_DESC_ERROR) {
            goto err;
        }

        /* An indirect chain occupies a single slot of the main table. */
        if (!indirect) {
            total_bufs = num_bufs;
        } else {
            total_bufs++;
        }
    }
    if (rc < 0) {
        goto err;
    }
done:
    if (in_bytes) {
        *in_bytes = in_total;
    }
    if (out_bytes) {
        *out_bytes = out_total;
    }
    return;

err:
    in_total = out_total = 0;
    goto done;
}
2348 
2349 bool
2350 vu_queue_avail_bytes(VuDev *dev, VuVirtq *vq, unsigned int in_bytes,
2351                      unsigned int out_bytes)
2352 {
2353     unsigned int in_total, out_total;
2354 
2355     vu_queue_get_avail_bytes(dev, vq, &in_total, &out_total,
2356                              in_bytes, out_bytes);
2357 
2358     return in_bytes <= in_total && out_bytes <= out_total;
2359 }
2360 
2361 /* Fetch avail_idx from VQ memory only when we really need to know if
2362  * guest has added some buffers. */
2363 bool
2364 vu_queue_empty(VuDev *dev, VuVirtq *vq)
2365 {
2366     if (unlikely(dev->broken) ||
2367         unlikely(!vq->vring.avail)) {
2368         return true;
2369     }
2370 
2371     if (vq->shadow_avail_idx != vq->last_avail_idx) {
2372         return false;
2373     }
2374 
2375     return vring_avail_idx(vq) == vq->last_avail_idx;
2376 }
2377 
/*
 * Decide whether the guest should be notified about new used buffers.
 *
 * - With VIRTIO_F_NOTIFY_ON_EMPTY negotiated, always notify when nothing is
 *   in use and the avail ring is empty.
 * - Without VIRTIO_RING_F_EVENT_IDX, notify unless the guest set
 *   VRING_AVAIL_F_NO_INTERRUPT.
 * - With EVENT_IDX, notify when the last signalled position is not valid,
 *   or when vring_need_event() says the guest's used_event lies between the
 *   old signalled position and the current used_idx. In this case
 *   vq->signalled_used is updated to used_idx and marked valid.
 */
static bool
vring_notify(VuDev *dev, VuVirtq *vq)
{
    uint16_t old, new;
    bool v;

    /* We need to expose used array entries before checking used event. */
    smp_mb();

    /* Always notify when queue is empty (when feature acknowledge) */
    if (vu_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
        !vq->inuse && vu_queue_empty(dev, vq)) {
        return true;
    }

    if (!vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        return !(vring_avail_flags(vq) & VRING_AVAIL_F_NO_INTERRUPT);
    }

    v = vq->signalled_used_valid;
    vq->signalled_used_valid = true;
    old = vq->signalled_used;
    new = vq->signalled_used = vq->used_idx;
    return !v || vring_need_event(vring_get_used_event(vq), new, old);
}
2403 
/*
 * Notify the guest about used buffers on @vq. Nothing is sent if the device
 * is broken, the ring is unmapped, or vring_notify() says the notification
 * can be suppressed.
 *
 * If the queue has no call eventfd and both INBAND_NOTIFICATIONS and
 * SLAVE_REQ were negotiated, a VHOST_USER_SLAVE_VRING_CALL message is
 * written to dev->slave_fd. When @sync is true and REPLY_ACK was negotiated,
 * the message requests a reply, which is then read back with
 * vu_message_read_default(). Otherwise 1 is written to the call eventfd.
 *
 * NOTE(review): failures of vu_message_write()/vu_message_read_default()
 * are ignored, and dev->slave_mutex is not held across the request/reply
 * pair. Confirm this is safe when other code uses the slave channel
 * concurrently.
 */
static void _vu_queue_notify(VuDev *dev, VuVirtq *vq, bool sync)
{
    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        return;
    }

    if (!vring_notify(dev, vq)) {
        DPRINT("skipped notify...\n");
        return;
    }

    /* No call eventfd: fall back to an in-band message on the slave channel. */
    if (vq->call_fd < 0 &&
        vu_has_protocol_feature(dev,
                                VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS) &&
        vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_SLAVE_REQ)) {
        VhostUserMsg vmsg = {
            .request = VHOST_USER_SLAVE_VRING_CALL,
            .flags = VHOST_USER_VERSION,
            .size = sizeof(vmsg.payload.state),
            .payload.state = {
                .index = vq - dev->vq,
            },
        };
        bool ack = sync &&
                   vu_has_protocol_feature(dev,
                                           VHOST_USER_PROTOCOL_F_REPLY_ACK);

        if (ack) {
            vmsg.flags |= VHOST_USER_NEED_REPLY_MASK;
        }

        vu_message_write(dev, dev->slave_fd, &vmsg);
        if (ack) {
            vu_message_read_default(dev, dev->slave_fd, &vmsg);
        }
        return;
    }

    if (eventfd_write(vq->call_fd, 1) < 0) {
        vu_panic(dev, "Error writing eventfd: %s", strerror(errno));
    }
}
2447 
2448 void vu_queue_notify(VuDev *dev, VuVirtq *vq)
2449 {
2450     _vu_queue_notify(dev, vq, false);
2451 }
2452 
2453 void vu_queue_notify_sync(VuDev *dev, VuVirtq *vq)
2454 {
2455     _vu_queue_notify(dev, vq, true);
2456 }
2457 
2458 static inline void
2459 vring_used_flags_set_bit(VuVirtq *vq, int mask)
2460 {
2461     uint16_t *flags;
2462 
2463     flags = (uint16_t *)((char*)vq->vring.used +
2464                          offsetof(struct vring_used, flags));
2465     *flags = htole16(le16toh(*flags) | mask);
2466 }
2467 
2468 static inline void
2469 vring_used_flags_unset_bit(VuVirtq *vq, int mask)
2470 {
2471     uint16_t *flags;
2472 
2473     flags = (uint16_t *)((char*)vq->vring.used +
2474                          offsetof(struct vring_used, flags));
2475     *flags = htole16(le16toh(*flags) & ~mask);
2476 }
2477 
2478 static inline void
2479 vring_set_avail_event(VuVirtq *vq, uint16_t val)
2480 {
2481     uint16_t val_le = htole16(val);
2482 
2483     if (!vq->notification) {
2484         return;
2485     }
2486 
2487     memcpy(&vq->vring.used->ring[vq->vring.num], &val_le, sizeof(uint16_t));
2488 }
2489 
/*
 * Enable or disable guest-to-device notifications for @vq and record the
 * choice in vq->notification.
 *
 * With VIRTIO_RING_F_EVENT_IDX, the avail_event slot is set to the current
 * avail index. vring_set_avail_event() only writes it while
 * vq->notification is set, so disabling leaves the old value in place.
 * Without EVENT_IDX, VRING_USED_F_NO_NOTIFY in the used ring flags is
 * cleared (enable) or set (disable).
 */
void
vu_queue_set_notification(VuDev *dev, VuVirtq *vq, int enable)
{
    vq->notification = enable;
    if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        vring_set_avail_event(vq, vring_avail_idx(vq));
    } else if (enable) {
        vring_used_flags_unset_bit(vq, VRING_USED_F_NO_NOTIFY);
    } else {
        vring_used_flags_set_bit(vq, VRING_USED_F_NO_NOTIFY);
    }
    if (enable) {
        /* Expose avail event/used flags before caller checks the avail idx. */
        smp_mb();
    }
}
2506 
2507 static bool
2508 virtqueue_map_desc(VuDev *dev,
2509                    unsigned int *p_num_sg, struct iovec *iov,
2510                    unsigned int max_num_sg, bool is_write,
2511                    uint64_t pa, size_t sz)
2512 {
2513     unsigned num_sg = *p_num_sg;
2514 
2515     assert(num_sg <= max_num_sg);
2516 
2517     if (!sz) {
2518         vu_panic(dev, "virtio: zero sized buffers are not allowed");
2519         return false;
2520     }
2521 
2522     while (sz) {
2523         uint64_t len = sz;
2524 
2525         if (num_sg == max_num_sg) {
2526             vu_panic(dev, "virtio: too many descriptors in indirect table");
2527             return false;
2528         }
2529 
2530         iov[num_sg].iov_base = vu_gpa_to_va(dev, &len, pa);
2531         if (iov[num_sg].iov_base == NULL) {
2532             vu_panic(dev, "virtio: invalid address for buffers");
2533             return false;
2534         }
2535         iov[num_sg].iov_len = len;
2536         num_sg++;
2537         sz -= len;
2538         pa += len;
2539     }
2540 
2541     *p_num_sg = num_sg;
2542     return true;
2543 }
2544 
2545 static void *
2546 virtqueue_alloc_element(size_t sz,
2547                                      unsigned out_num, unsigned in_num)
2548 {
2549     VuVirtqElement *elem;
2550     size_t in_sg_ofs = ALIGN_UP(sz, __alignof__(elem->in_sg[0]));
2551     size_t out_sg_ofs = in_sg_ofs + in_num * sizeof(elem->in_sg[0]);
2552     size_t out_sg_end = out_sg_ofs + out_num * sizeof(elem->out_sg[0]);
2553 
2554     assert(sz >= sizeof(VuVirtqElement));
2555     elem = malloc(out_sg_end);
2556     elem->out_num = out_num;
2557     elem->in_num = in_num;
2558     elem->in_sg = (void *)elem + in_sg_ofs;
2559     elem->out_sg = (void *)elem + out_sg_ofs;
2560     return elem;
2561 }
2562 
2563 static void *
2564 vu_queue_map_desc(VuDev *dev, VuVirtq *vq, unsigned int idx, size_t sz)
2565 {
2566     struct vring_desc *desc = vq->vring.desc;
2567     uint64_t desc_addr, read_len;
2568     unsigned int desc_len;
2569     unsigned int max = vq->vring.num;
2570     unsigned int i = idx;
2571     VuVirtqElement *elem;
2572     unsigned int out_num = 0, in_num = 0;
2573     struct iovec iov[VIRTQUEUE_MAX_SIZE];
2574     struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
2575     int rc;
2576 
2577     if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
2578         if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
2579             vu_panic(dev, "Invalid size for indirect buffer table");
2580             return NULL;
2581         }
2582 
2583         /* loop over the indirect descriptor table */
2584         desc_addr = le64toh(desc[i].addr);
2585         desc_len = le32toh(desc[i].len);
2586         max = desc_len / sizeof(struct vring_desc);
2587         read_len = desc_len;
2588         desc = vu_gpa_to_va(dev, &read_len, desc_addr);
2589         if (unlikely(desc && read_len != desc_len)) {
2590             /* Failed to use zero copy */
2591             desc = NULL;
2592             if (!virtqueue_read_indirect_desc(dev, desc_buf,
2593                                               desc_addr,
2594                                               desc_len)) {
2595                 desc = desc_buf;
2596             }
2597         }
2598         if (!desc) {
2599             vu_panic(dev, "Invalid indirect buffer table");
2600             return NULL;
2601         }
2602         i = 0;
2603     }
2604 
2605     /* Collect all the descriptors */
2606     do {
2607         if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
2608             if (!virtqueue_map_desc(dev, &in_num, iov + out_num,
2609                                VIRTQUEUE_MAX_SIZE - out_num, true,
2610                                le64toh(desc[i].addr),
2611                                le32toh(desc[i].len))) {
2612                 return NULL;
2613             }
2614         } else {
2615             if (in_num) {
2616                 vu_panic(dev, "Incorrect order for descriptors");
2617                 return NULL;
2618             }
2619             if (!virtqueue_map_desc(dev, &out_num, iov,
2620                                VIRTQUEUE_MAX_SIZE, false,
2621                                le64toh(desc[i].addr),
2622                                le32toh(desc[i].len))) {
2623                 return NULL;
2624             }
2625         }
2626 
2627         /* If we've got too many, that implies a descriptor loop. */
2628         if ((in_num + out_num) > max) {
2629             vu_panic(dev, "Looped descriptor");
2630             return NULL;
2631         }
2632         rc = virtqueue_read_next_desc(dev, desc, i, max, &i);
2633     } while (rc == VIRTQUEUE_READ_DESC_MORE);
2634 
2635     if (rc == VIRTQUEUE_READ_DESC_ERROR) {
2636         vu_panic(dev, "read descriptor error");
2637         return NULL;
2638     }
2639 
2640     /* Now copy what we have collected and mapped */
2641     elem = virtqueue_alloc_element(sz, out_num, in_num);
2642     elem->index = idx;
2643     for (i = 0; i < out_num; i++) {
2644         elem->out_sg[i] = iov[i];
2645     }
2646     for (i = 0; i < in_num; i++) {
2647         elem->in_sg[i] = iov[out_num + i];
2648     }
2649 
2650     return elem;
2651 }
2652 
2653 static int
2654 vu_queue_inflight_get(VuDev *dev, VuVirtq *vq, int desc_idx)
2655 {
2656     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2657         return 0;
2658     }
2659 
2660     if (unlikely(!vq->inflight)) {
2661         return -1;
2662     }
2663 
2664     vq->inflight->desc[desc_idx].counter = vq->counter++;
2665     vq->inflight->desc[desc_idx].inflight = 1;
2666 
2667     return 0;
2668 }
2669 
2670 static int
2671 vu_queue_inflight_pre_put(VuDev *dev, VuVirtq *vq, int desc_idx)
2672 {
2673     if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2674         return 0;
2675     }
2676 
2677     if (unlikely(!vq->inflight)) {
2678         return -1;
2679     }
2680 
2681     vq->inflight->last_batch_head = desc_idx;
2682 
2683     return 0;
2684 }
2685 
/*
 * Inflight bookkeeping after @desc_idx has been published in the used
 * ring: clear the descriptor's inflight flag, then copy vq->used_idx into
 * the shared inflight region.
 *
 * Does nothing and returns 0 when INFLIGHT_SHMFD was not negotiated.
 * Returns -1 if the region is not set up.
 */
static int
vu_queue_inflight_post_put(VuDev *dev, VuVirtq *vq, int desc_idx)
{
    if (!vu_has_protocol_feature(dev, VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
        return 0;
    }

    if (unlikely(!vq->inflight)) {
        return -1;
    }

    /* barrier() presumably a compiler barrier: keep these stores in program order. */
    barrier();

    vq->inflight->desc[desc_idx].inflight = 0;

    barrier();

    vq->inflight->used_idx = vq->used_idx;

    return 0;
}
2707 
/*
 * Take the next available element from @vq and map its buffers into a
 * newly allocated object of @sz bytes (at least sizeof(VuVirtqElement)).
 *
 * Entries in vq->resubmit_list are handed out first, and the list is freed
 * once it is drained. That path does not touch vq->inuse or the inflight
 * flags. Otherwise the next head is read from the avail ring and
 * last_avail_idx is advanced. Under EVENT_IDX the avail event is updated.
 * On success, vq->inuse is incremented and the head is marked inflight.
 *
 * Returns NULL if the device is broken, the ring is unmapped, no buffer is
 * available, or an error occurred.
 */
void *
vu_queue_pop(VuDev *dev, VuVirtq *vq, size_t sz)
{
    int i;
    unsigned int head;
    VuVirtqElement *elem;

    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        return NULL;
    }

    /* Replay descriptors recorded as inflight, last entry first. */
    if (unlikely(vq->resubmit_list && vq->resubmit_num > 0)) {
        i = (--vq->resubmit_num);
        elem = vu_queue_map_desc(dev, vq, vq->resubmit_list[i].index, sz);

        if (!vq->resubmit_num) {
            free(vq->resubmit_list);
            vq->resubmit_list = NULL;
        }

        return elem;
    }

    if (vu_queue_empty(dev, vq)) {
        return NULL;
    }
    /*
     * Needed after virtio_queue_empty(), see comment in
     * virtqueue_num_heads().
     */
    smp_rmb();

    if (vq->inuse >= vq->vring.num) {
        vu_panic(dev, "Virtqueue size exceeded");
        return NULL;
    }

    /* last_avail_idx advances even if mapping the chain below fails. */
    if (!virtqueue_get_head(dev, vq, vq->last_avail_idx++, &head)) {
        return NULL;
    }

    if (vu_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
        vring_set_avail_event(vq, vq->last_avail_idx);
    }

    elem = vu_queue_map_desc(dev, vq, head, sz);

    if (!elem) {
        return NULL;
    }

    vq->inuse++;

    vu_queue_inflight_get(dev, vq, head);

    return elem;
}
2766 
2767 static void
2768 vu_queue_detach_element(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
2769                         size_t len)
2770 {
2771     vq->inuse--;
2772     /* unmap, when DMA support is added */
2773 }
2774 
2775 void
2776 vu_queue_unpop(VuDev *dev, VuVirtq *vq, VuVirtqElement *elem,
2777                size_t len)
2778 {
2779     vq->last_avail_idx--;
2780     vu_queue_detach_element(dev, vq, elem, len);
2781 }
2782 
2783 bool
2784 vu_queue_rewind(VuDev *dev, VuVirtq *vq, unsigned int num)
2785 {
2786     if (num > vq->inuse) {
2787         return false;
2788     }
2789     vq->last_avail_idx -= num;
2790     vq->inuse -= num;
2791     return true;
2792 }
2793 
2794 static inline
2795 void vring_used_write(VuDev *dev, VuVirtq *vq,
2796                       struct vring_used_elem *uelem, int i)
2797 {
2798     struct vring_used *used = vq->vring.used;
2799 
2800     used->ring[i] = *uelem;
2801     vu_log_write(dev, vq->vring.log_guest_addr +
2802                  offsetof(struct vring_used, ring[i]),
2803                  sizeof(used->ring[i]));
2804 }
2805 
2806 
/*
 * Report to vu_log_write() the guest memory the device wrote for @elem.
 *
 * Walks elem's descriptor chain, following an indirect table if present.
 * For each device-writable descriptor, in chain order, it logs up to @len
 * bytes in total, stopping once @len bytes have been accounted for or the
 * chain ends. On a malformed chain it calls vu_panic() and returns.
 */
static void
vu_log_queue_fill(VuDev *dev, VuVirtq *vq,
                  const VuVirtqElement *elem,
                  unsigned int len)
{
    struct vring_desc *desc = vq->vring.desc;
    unsigned int i, max, min, desc_len;
    uint64_t desc_addr, read_len;
    /* Bounce buffer for an indirect table that is not contiguous in one region. */
    struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
    unsigned num_bufs = 0;

    max = vq->vring.num;
    i = elem->index;

    if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
        if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
            vu_panic(dev, "Invalid size for indirect buffer table");
            return;
        }

        /* loop over the indirect descriptor table */
        desc_addr = le64toh(desc[i].addr);
        desc_len = le32toh(desc[i].len);
        max = desc_len / sizeof(struct vring_desc);
        read_len = desc_len;
        desc = vu_gpa_to_va(dev, &read_len, desc_addr);
        if (unlikely(desc && read_len != desc_len)) {
            /* Failed to use zero copy */
            desc = NULL;
            if (!virtqueue_read_indirect_desc(dev, desc_buf,
                                              desc_addr,
                                              desc_len)) {
                desc = desc_buf;
            }
        }
        if (!desc) {
            vu_panic(dev, "Invalid indirect buffer table");
            return;
        }
        i = 0;
    }

    do {
        /* More descriptors than the table holds implies a loop. */
        if (++num_bufs > max) {
            vu_panic(dev, "Looped descriptor");
            return;
        }

        /* Only device-writable buffers can have been dirtied by the device. */
        if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
            min = MIN(le32toh(desc[i].len), len);
            vu_log_write(dev, le64toh(desc[i].addr), min);
            len -= min;
        }

    } while (len > 0 &&
             (virtqueue_read_next_desc(dev, desc, i, max, &i)
              == VIRTQUEUE_READ_DESC_MORE));
}
2865 
2866 void
2867 vu_queue_fill(VuDev *dev, VuVirtq *vq,
2868               const VuVirtqElement *elem,
2869               unsigned int len, unsigned int idx)
2870 {
2871     struct vring_used_elem uelem;
2872 
2873     if (unlikely(dev->broken) ||
2874         unlikely(!vq->vring.avail)) {
2875         return;
2876     }
2877 
2878     vu_log_queue_fill(dev, vq, elem, len);
2879 
2880     idx = (idx + vq->used_idx) % vq->vring.num;
2881 
2882     uelem.id = htole32(elem->index);
2883     uelem.len = htole32(len);
2884     vring_used_write(dev, vq, &uelem, idx);
2885 }
2886 
2887 static inline
2888 void vring_used_idx_set(VuDev *dev, VuVirtq *vq, uint16_t val)
2889 {
2890     vq->vring.used->idx = htole16(val);
2891     vu_log_write(dev,
2892                  vq->vring.log_guest_addr + offsetof(struct vring_used, idx),
2893                  sizeof(vq->vring.used->idx));
2894 
2895     vq->used_idx = val;
2896 }
2897 
/*
 * Publish @count used entries previously written with vu_queue_fill().
 *
 * After a write barrier (so the entries become visible first), used->idx
 * is advanced by @count and @count is subtracted from vq->inuse. If the
 * last signalled position falls within the window just published,
 * signalled_used_valid is cleared, which makes the next vring_notify()
 * notify unconditionally.
 *
 * Does nothing if the device is broken or the ring is unmapped.
 */
void
vu_queue_flush(VuDev *dev, VuVirtq *vq, unsigned int count)
{
    uint16_t old, new;

    if (unlikely(dev->broken) ||
        unlikely(!vq->vring.avail)) {
        return;
    }

    /* Make sure buffer is written before we update index. */
    smp_wmb();

    old = vq->used_idx;
    new = old + count;
    vring_used_idx_set(dev, vq, new);
    vq->inuse -= count;
    /* Wrap-aware test: is signalled_used inside the range just published? */
    if (unlikely((int16_t)(new - vq->signalled_used) < (uint16_t)(new - old))) {
        vq->signalled_used_valid = false;
    }
}
2919 
2920 void
2921 vu_queue_push(VuDev *dev, VuVirtq *vq,
2922               const VuVirtqElement *elem, unsigned int len)
2923 {
2924     vu_queue_fill(dev, vq, elem, len, 0);
2925     vu_queue_inflight_pre_put(dev, vq, elem->index);
2926     vu_queue_flush(dev, vq, 1);
2927     vu_queue_inflight_post_put(dev, vq, elem->index);
2928 }
2929