1 /*
2  * VDUSE (vDPA Device in Userspace) library
3  *
4  * Copyright (C) 2022 Bytedance Inc. and/or its affiliates. All rights reserved.
5  *   Portions of codes and concepts borrowed from libvhost-user.c, so:
6  *     Copyright IBM, Corp. 2007
7  *     Copyright (c) 2016 Red Hat, Inc.
8  *
9  * Author:
10  *   Xie Yongji <xieyongji@bytedance.com>
11  *   Anthony Liguori <aliguori@us.ibm.com>
12  *   Marc-André Lureau <mlureau@redhat.com>
13  *   Victor Kaplansky <victork@redhat.com>
14  *
15  * This work is licensed under the terms of the GNU GPL, version 2 or
16  * later.  See the COPYING file in the top-level directory.
17  */
18 
19 #include <stdlib.h>
20 #include <stdio.h>
21 #include <stdbool.h>
22 #include <stddef.h>
23 #include <errno.h>
24 #include <string.h>
25 #include <assert.h>
26 #include <endian.h>
27 #include <unistd.h>
28 #include <limits.h>
29 #include <fcntl.h>
30 #include <inttypes.h>
31 
32 #include <sys/ioctl.h>
33 #include <sys/eventfd.h>
34 #include <sys/mman.h>
35 
36 #include "include/atomic.h"
37 #include "linux-headers/linux/virtio_ring.h"
38 #include "linux-headers/linux/virtio_config.h"
39 #include "linux-headers/linux/vduse.h"
40 #include "libvduse.h"
41 
/* Alignment constant for virtqueue memory; not referenced in this part of the file. */
#define VDUSE_VQ_ALIGN 4096
/* Capacity of the per-device IOVA region cache (VduseDev.regions[]). */
#define MAX_IOVA_REGIONS 256

/* Each per-virtqueue reconnect-log slice is padded to a multiple of this (bytes). */
#define LOG_ALIGNMENT 64

/* Round number down to multiple (m must be non-zero; arguments may be evaluated more than once) */
#define ALIGN_DOWN(n, m) ((n) / (m) * (m))

/* Round number up to multiple (m must be non-zero; m is evaluated more than once) */
#define ALIGN_UP(n, m) ALIGN_DOWN((n) + (m) - 1, (m))

/* Branch-prediction hint; only defined here if no header provided one already. */
#ifndef unlikely
#define unlikely(x)   __builtin_expect(!!(x), 0)
#endif
56 
/*
 * One per-descriptor entry of the reconnect (inflight) log.
 * The log is a file mapped MAP_SHARED by vduse_log_get() and opened without
 * O_TRUNC, so an existing file's contents are reinterpreted with this layout;
 * keep the field order and sizes stable.
 */
typedef struct VduseDescStateSplit {
    /* 1 while the request headed by this descriptor is in flight:
     * set by vduse_queue_inflight_get(), cleared by ..._post_put(). */
    uint8_t inflight;
    uint8_t padding[5];
    /* Not referenced in this part of the file. */
    uint16_t next;
    /* Submission stamp taken from VduseVirtq.counter; used to order resubmission. */
    uint64_t counter;
} VduseDescStateSplit;

/* Per-virtqueue header of the reconnect log, followed by desc_num entries. */
typedef struct VduseVirtqLogInflight {
    /* Not referenced in this part of the file. */
    uint64_t features;
    /* Not referenced in this part of the file. */
    uint16_t version;
    /* Number of desc[] entries; set to VIRTQUEUE_MAX_SIZE when the log is attached. */
    uint16_t desc_num;
    /* Head index of the element being pushed, recorded before used->idx moves. */
    uint16_t last_batch_head;
    /* Copy of the used index, written after the head's inflight flag is cleared. */
    uint16_t used_idx;
    VduseDescStateSplit desc[];
} VduseVirtqLogInflight;

/* Wrapper for one virtqueue's slice of the reconnect log. */
typedef struct VduseVirtqLog {
    VduseVirtqLogInflight inflight;
} VduseVirtqLog;

/* Entry of the resubmit list rebuilt from the log when a queue is enabled. */
typedef struct VduseVirtqInflightDesc {
    /* Descriptor head index to re-map. */
    uint16_t index;
    /* Stamp copied from the log; the list is sorted on it. */
    uint64_t counter;
} VduseVirtqInflightDesc;
81 
/* Geometry and host mapping of one split virtqueue. */
typedef struct VduseRing {
    /* Number of ring entries (from VDUSE_VQ_GET_INFO). */
    unsigned int num;
    /* IOVAs of the descriptor table, driver (avail) and device (used) areas. */
    uint64_t desc_addr;
    uint64_t avail_addr;
    uint64_t used_addr;
    /* Host addresses of the same areas, resolved via iova_to_va(); NULL when unset. */
    struct vring_desc *desc;
    struct vring_avail *avail;
    struct vring_used *used;
} VduseRing;

/* Runtime state of one virtqueue. */
struct VduseVirtq {
    VduseRing vring;
    /* Next avail-ring index to consume. */
    uint16_t last_avail_idx;
    /* Cached copy of avail->idx, refreshed by vring_avail_idx(). */
    uint16_t shadow_avail_idx;
    /* Device-side used index, mirrored into used->idx by vring_used_idx_set(). */
    uint16_t used_idx;
    /* Used index at the last notification decision (EVENT_IDX path). */
    uint16_t signalled_used;
    /* false forces the next vduse_queue_should_notify() to notify. */
    bool signalled_used_valid;
    /* Queue number. */
    int index;
    /* Elements popped but not yet pushed back. */
    int inuse;
    /* Set once vduse_queue_enable() has wired the queue up. */
    bool ready;
    /* Kick eventfd registered with VDUSE_VQ_SETUP_KICKFD, or -1. */
    int fd;
    /* Owning device. */
    VduseDev *dev;
    /* Requests recovered from the log; vduse_queue_pop() hands them out first. */
    VduseVirtqInflightDesc *resubmit_list;
    uint16_t resubmit_num;
    /* Next submission stamp written into the log. */
    uint64_t counter;
    /* This queue's slice of the reconnect log, or NULL if none attached. */
    VduseVirtqLog *log;
};

/* One cached IOVA -> host mapping created from a VDUSE_IOTLB_GET_FD fd. */
typedef struct VduseIovaRegion {
    /* First IOVA covered. */
    uint64_t iova;
    /* Length in bytes. */
    uint64_t size;
    /* Byte offset inside the mapping where `iova` lives. */
    uint64_t mmap_offset;
    /* Base returned by mmap(); 0 marks a free slot. */
    uint64_t mmap_addr;
} VduseIovaRegion;

/* A VDUSE device instance. */
struct VduseDev {
    /* Array of num_queues virtqueues. */
    VduseVirtq *vqs;
    /* IOVA region cache used by iova_to_va(). */
    VduseIovaRegion regions[MAX_IOVA_REGIONS];
    /* Number of occupied regions[] slots. */
    int num_regions;
    /* Not referenced in this part of the file. */
    char *name;
    /* Not referenced in this part of the file. */
    uint32_t device_id;
    /* Not referenced in this part of the file. */
    uint32_t vendor_id;
    uint16_t num_queues;
    /* Not referenced in this part of the file. */
    uint16_t queue_size;
    /* Negotiated features (VDUSE_DEV_GET_FEATURES); cleared on stop. */
    uint64_t features;
    /* Callbacks invoked on queue enable/disable. */
    const VduseOps *ops;
    /* Device fd used for requests and ioctls. */
    int fd;
    /* Not referenced in this part of the file. */
    int ctrl_fd;
    /* Opaque user pointer returned by vduse_dev_get_priv(). */
    void *priv;
    /* Base of the reconnect-log mapping, or NULL if none attached. */
    void *log;
};
133 
134 static inline size_t vduse_vq_log_size(uint16_t queue_size)
135 {
136     return ALIGN_UP(sizeof(VduseDescStateSplit) * queue_size +
137                     sizeof(VduseVirtqLogInflight), LOG_ALIGNMENT);
138 }
139 
/*
 * Open (creating with mode 0600 if missing) @filename, resize it to @size
 * bytes and map it shared read/write. The file is not truncated on open,
 * so data already within the first @size bytes is visible in the mapping.
 *
 * Returns the mapping, or MAP_FAILED on any error.
 */
static void *vduse_log_get(const char *filename, size_t size)
{
    void *map;
    int fd = open(filename, O_RDWR | O_CREAT, 0600);

    if (fd < 0) {
        return MAP_FAILED;
    }

    if (ftruncate(fd, size) != 0) {
        map = MAP_FAILED;
    } else {
        map = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
    }

    /* A shared mapping stays valid after its descriptor is closed. */
    close(fd);
    return map;
}
160 
/* Test bit @fbit (0..63) of a feature mask. */
static inline bool has_feature(uint64_t features, unsigned int fbit)
{
    assert(fbit < 64);
    return ((features >> fbit) & 1ULL) != 0;
}
166 
167 static inline bool vduse_dev_has_feature(VduseDev *dev, unsigned int fbit)
168 {
169     return has_feature(dev->features, fbit);
170 }
171 
172 uint64_t vduse_get_virtio_features(void)
173 {
174     return (1ULL << VIRTIO_F_IOMMU_PLATFORM) |
175            (1ULL << VIRTIO_F_VERSION_1) |
176            (1ULL << VIRTIO_F_NOTIFY_ON_EMPTY) |
177            (1ULL << VIRTIO_RING_F_EVENT_IDX) |
178            (1ULL << VIRTIO_RING_F_INDIRECT_DESC);
179 }
180 
181 VduseDev *vduse_queue_get_dev(VduseVirtq *vq)
182 {
183     return vq->dev;
184 }
185 
186 int vduse_queue_get_fd(VduseVirtq *vq)
187 {
188     return vq->fd;
189 }
190 
191 void *vduse_dev_get_priv(VduseDev *dev)
192 {
193     return dev->priv;
194 }
195 
196 VduseVirtq *vduse_dev_get_queue(VduseDev *dev, int index)
197 {
198     return &dev->vqs[index];
199 }
200 
201 int vduse_dev_get_fd(VduseDev *dev)
202 {
203     return dev->fd;
204 }
205 
206 static int vduse_inject_irq(VduseDev *dev, int index)
207 {
208     return ioctl(dev->fd, VDUSE_VQ_INJECT_IRQ, &index);
209 }
210 
211 static int inflight_desc_compare(const void *a, const void *b)
212 {
213     VduseVirtqInflightDesc *desc0 = (VduseVirtqInflightDesc *)a,
214                            *desc1 = (VduseVirtqInflightDesc *)b;
215 
216     if (desc1->counter > desc0->counter &&
217         (desc1->counter - desc0->counter) < VIRTQUEUE_MAX_SIZE * 2) {
218         return 1;
219     }
220 
221     return -1;
222 }
223 
224 static int vduse_queue_check_inflights(VduseVirtq *vq)
225 {
226     int i = 0;
227     VduseDev *dev = vq->dev;
228 
229     vq->used_idx = le16toh(vq->vring.used->idx);
230     vq->resubmit_num = 0;
231     vq->resubmit_list = NULL;
232     vq->counter = 0;
233 
234     if (unlikely(vq->log->inflight.used_idx != vq->used_idx)) {
235         if (vq->log->inflight.last_batch_head > VIRTQUEUE_MAX_SIZE) {
236             return -1;
237         }
238 
239         vq->log->inflight.desc[vq->log->inflight.last_batch_head].inflight = 0;
240 
241         barrier();
242 
243         vq->log->inflight.used_idx = vq->used_idx;
244     }
245 
246     for (i = 0; i < vq->log->inflight.desc_num; i++) {
247         if (vq->log->inflight.desc[i].inflight == 1) {
248             vq->inuse++;
249         }
250     }
251 
252     vq->shadow_avail_idx = vq->last_avail_idx = vq->inuse + vq->used_idx;
253 
254     if (vq->inuse) {
255         vq->resubmit_list = calloc(vq->inuse, sizeof(VduseVirtqInflightDesc));
256         if (!vq->resubmit_list) {
257             return -1;
258         }
259 
260         for (i = 0; i < vq->log->inflight.desc_num; i++) {
261             if (vq->log->inflight.desc[i].inflight) {
262                 vq->resubmit_list[vq->resubmit_num].index = i;
263                 vq->resubmit_list[vq->resubmit_num].counter =
264                                         vq->log->inflight.desc[i].counter;
265                 vq->resubmit_num++;
266             }
267         }
268 
269         if (vq->resubmit_num > 1) {
270             qsort(vq->resubmit_list, vq->resubmit_num,
271                   sizeof(VduseVirtqInflightDesc), inflight_desc_compare);
272         }
273         vq->counter = vq->resubmit_list[0].counter + 1;
274     }
275 
276     vduse_inject_irq(dev, vq->index);
277 
278     return 0;
279 }
280 
281 static int vduse_queue_inflight_get(VduseVirtq *vq, int desc_idx)
282 {
283     vq->log->inflight.desc[desc_idx].counter = vq->counter++;
284 
285     barrier();
286 
287     vq->log->inflight.desc[desc_idx].inflight = 1;
288 
289     return 0;
290 }
291 
292 static int vduse_queue_inflight_pre_put(VduseVirtq *vq, int desc_idx)
293 {
294     vq->log->inflight.last_batch_head = desc_idx;
295 
296     return 0;
297 }
298 
299 static int vduse_queue_inflight_post_put(VduseVirtq *vq, int desc_idx)
300 {
301     vq->log->inflight.desc[desc_idx].inflight = 0;
302 
303     barrier();
304 
305     vq->log->inflight.used_idx = vq->used_idx;
306 
307     return 0;
308 }
309 
310 static void vduse_iova_remove_region(VduseDev *dev, uint64_t start,
311                                      uint64_t last)
312 {
313     int i;
314 
315     if (last == start) {
316         return;
317     }
318 
319     for (i = 0; i < MAX_IOVA_REGIONS; i++) {
320         if (!dev->regions[i].mmap_addr) {
321             continue;
322         }
323 
324         if (start <= dev->regions[i].iova &&
325             last >= (dev->regions[i].iova + dev->regions[i].size - 1)) {
326             munmap((void *)(uintptr_t)dev->regions[i].mmap_addr,
327                    dev->regions[i].mmap_offset + dev->regions[i].size);
328             dev->regions[i].mmap_addr = 0;
329             dev->num_regions--;
330         }
331     }
332 }
333 
334 static int vduse_iova_add_region(VduseDev *dev, int fd,
335                                  uint64_t offset, uint64_t start,
336                                  uint64_t last, int prot)
337 {
338     int i;
339     uint64_t size = last - start + 1;
340     void *mmap_addr = mmap(0, size + offset, prot, MAP_SHARED, fd, 0);
341 
342     if (mmap_addr == MAP_FAILED) {
343         close(fd);
344         return -EINVAL;
345     }
346 
347     for (i = 0; i < MAX_IOVA_REGIONS; i++) {
348         if (!dev->regions[i].mmap_addr) {
349             dev->regions[i].mmap_addr = (uint64_t)(uintptr_t)mmap_addr;
350             dev->regions[i].mmap_offset = offset;
351             dev->regions[i].iova = start;
352             dev->regions[i].size = size;
353             dev->num_regions++;
354             break;
355         }
356     }
357     assert(i < MAX_IOVA_REGIONS);
358     close(fd);
359 
360     return 0;
361 }
362 
363 static int perm_to_prot(uint8_t perm)
364 {
365     int prot = 0;
366 
367     switch (perm) {
368     case VDUSE_ACCESS_WO:
369         prot |= PROT_WRITE;
370         break;
371     case VDUSE_ACCESS_RO:
372         prot |= PROT_READ;
373         break;
374     case VDUSE_ACCESS_RW:
375         prot |= PROT_READ | PROT_WRITE;
376         break;
377     default:
378         break;
379     }
380 
381     return prot;
382 }
383 
384 static inline void *iova_to_va(VduseDev *dev, uint64_t *plen, uint64_t iova)
385 {
386     int i, ret;
387     struct vduse_iotlb_entry entry;
388 
389     for (i = 0; i < MAX_IOVA_REGIONS; i++) {
390         VduseIovaRegion *r = &dev->regions[i];
391 
392         if (!r->mmap_addr) {
393             continue;
394         }
395 
396         if ((iova >= r->iova) && (iova < (r->iova + r->size))) {
397             if ((iova + *plen) > (r->iova + r->size)) {
398                 *plen = r->iova + r->size - iova;
399             }
400             return (void *)(uintptr_t)(iova - r->iova +
401                    r->mmap_addr + r->mmap_offset);
402         }
403     }
404 
405     entry.start = iova;
406     entry.last = iova + 1;
407     ret = ioctl(dev->fd, VDUSE_IOTLB_GET_FD, &entry);
408     if (ret < 0) {
409         return NULL;
410     }
411 
412     if (!vduse_iova_add_region(dev, ret, entry.offset, entry.start,
413                                entry.last, perm_to_prot(entry.perm))) {
414         return iova_to_va(dev, plen, iova);
415     }
416 
417     return NULL;
418 }
419 
420 static inline uint16_t vring_avail_flags(VduseVirtq *vq)
421 {
422     return le16toh(vq->vring.avail->flags);
423 }
424 
425 static inline uint16_t vring_avail_idx(VduseVirtq *vq)
426 {
427     vq->shadow_avail_idx = le16toh(vq->vring.avail->idx);
428 
429     return vq->shadow_avail_idx;
430 }
431 
432 static inline uint16_t vring_avail_ring(VduseVirtq *vq, int i)
433 {
434     return le16toh(vq->vring.avail->ring[i]);
435 }
436 
437 static inline uint16_t vring_get_used_event(VduseVirtq *vq)
438 {
439     return vring_avail_ring(vq, vq->vring.num);
440 }
441 
442 static bool vduse_queue_get_head(VduseVirtq *vq, unsigned int idx,
443                                  unsigned int *head)
444 {
445     /*
446      * Grab the next descriptor number they're advertising, and increment
447      * the index we've seen.
448      */
449     *head = vring_avail_ring(vq, idx % vq->vring.num);
450 
451     /* If their number is silly, that's a fatal mistake. */
452     if (*head >= vq->vring.num) {
453         fprintf(stderr, "Guest says index %u is available\n", *head);
454         return false;
455     }
456 
457     return true;
458 }
459 
460 static int
461 vduse_queue_read_indirect_desc(VduseDev *dev, struct vring_desc *desc,
462                                uint64_t addr, size_t len)
463 {
464     struct vring_desc *ori_desc;
465     uint64_t read_len;
466 
467     if (len > (VIRTQUEUE_MAX_SIZE * sizeof(struct vring_desc))) {
468         return -1;
469     }
470 
471     if (len == 0) {
472         return -1;
473     }
474 
475     while (len) {
476         read_len = len;
477         ori_desc = iova_to_va(dev, &read_len, addr);
478         if (!ori_desc) {
479             return -1;
480         }
481 
482         memcpy(desc, ori_desc, read_len);
483         len -= read_len;
484         addr += read_len;
485         desc += read_len;
486     }
487 
488     return 0;
489 }
490 
491 enum {
492     VIRTQUEUE_READ_DESC_ERROR = -1,
493     VIRTQUEUE_READ_DESC_DONE = 0,   /* end of chain */
494     VIRTQUEUE_READ_DESC_MORE = 1,   /* more buffers in chain */
495 };
496 
497 static int vduse_queue_read_next_desc(struct vring_desc *desc, int i,
498                                       unsigned int max, unsigned int *next)
499 {
500     /* If this descriptor says it doesn't chain, we're done. */
501     if (!(le16toh(desc[i].flags) & VRING_DESC_F_NEXT)) {
502         return VIRTQUEUE_READ_DESC_DONE;
503     }
504 
505     /* Check they're not leading us off end of descriptors. */
506     *next = desc[i].next;
507     /* Make sure compiler knows to grab that: we don't want it changing! */
508     smp_wmb();
509 
510     if (*next >= max) {
511         fprintf(stderr, "Desc next is %u\n", *next);
512         return VIRTQUEUE_READ_DESC_ERROR;
513     }
514 
515     return VIRTQUEUE_READ_DESC_MORE;
516 }
517 
518 /*
519  * Fetch avail_idx from VQ memory only when we really need to know if
520  * guest has added some buffers.
521  */
522 static bool vduse_queue_empty(VduseVirtq *vq)
523 {
524     if (unlikely(!vq->vring.avail)) {
525         return true;
526     }
527 
528     if (vq->shadow_avail_idx != vq->last_avail_idx) {
529         return false;
530     }
531 
532     return vring_avail_idx(vq) == vq->last_avail_idx;
533 }
534 
535 static bool vduse_queue_should_notify(VduseVirtq *vq)
536 {
537     VduseDev *dev = vq->dev;
538     uint16_t old, new;
539     bool v;
540 
541     /* We need to expose used array entries before checking used event. */
542     smp_mb();
543 
544     /* Always notify when queue is empty (when feature acknowledge) */
545     if (vduse_dev_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
546         !vq->inuse && vduse_queue_empty(vq)) {
547         return true;
548     }
549 
550     if (!vduse_dev_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
551         return !(vring_avail_flags(vq) & VRING_AVAIL_F_NO_INTERRUPT);
552     }
553 
554     v = vq->signalled_used_valid;
555     vq->signalled_used_valid = true;
556     old = vq->signalled_used;
557     new = vq->signalled_used = vq->used_idx;
558     return !v || vring_need_event(vring_get_used_event(vq), new, old);
559 }
560 
561 void vduse_queue_notify(VduseVirtq *vq)
562 {
563     VduseDev *dev = vq->dev;
564 
565     if (unlikely(!vq->vring.avail)) {
566         return;
567     }
568 
569     if (!vduse_queue_should_notify(vq)) {
570         return;
571     }
572 
573     if (vduse_inject_irq(dev, vq->index) < 0) {
574         fprintf(stderr, "Error inject irq for vq %d: %s\n",
575                 vq->index, strerror(errno));
576     }
577 }
578 
579 static inline void vring_set_avail_event(VduseVirtq *vq, uint16_t val)
580 {
581     *((uint16_t *)&vq->vring.used->ring[vq->vring.num]) = htole16(val);
582 }
583 
584 static bool vduse_queue_map_single_desc(VduseVirtq *vq, unsigned int *p_num_sg,
585                                    struct iovec *iov, unsigned int max_num_sg,
586                                    bool is_write, uint64_t pa, size_t sz)
587 {
588     unsigned num_sg = *p_num_sg;
589     VduseDev *dev = vq->dev;
590 
591     assert(num_sg <= max_num_sg);
592 
593     if (!sz) {
594         fprintf(stderr, "virtio: zero sized buffers are not allowed\n");
595         return false;
596     }
597 
598     while (sz) {
599         uint64_t len = sz;
600 
601         if (num_sg == max_num_sg) {
602             fprintf(stderr,
603                     "virtio: too many descriptors in indirect table\n");
604             return false;
605         }
606 
607         iov[num_sg].iov_base = iova_to_va(dev, &len, pa);
608         if (iov[num_sg].iov_base == NULL) {
609             fprintf(stderr, "virtio: invalid address for buffers\n");
610             return false;
611         }
612         iov[num_sg++].iov_len = len;
613         sz -= len;
614         pa += len;
615     }
616 
617     *p_num_sg = num_sg;
618     return true;
619 }
620 
621 static void *vduse_queue_alloc_element(size_t sz, unsigned out_num,
622                                        unsigned in_num)
623 {
624     VduseVirtqElement *elem;
625     size_t in_sg_ofs = ALIGN_UP(sz, __alignof__(elem->in_sg[0]));
626     size_t out_sg_ofs = in_sg_ofs + in_num * sizeof(elem->in_sg[0]);
627     size_t out_sg_end = out_sg_ofs + out_num * sizeof(elem->out_sg[0]);
628 
629     assert(sz >= sizeof(VduseVirtqElement));
630     elem = malloc(out_sg_end);
631     if (!elem) {
632         return NULL;
633     }
634     elem->out_num = out_num;
635     elem->in_num = in_num;
636     elem->in_sg = (void *)elem + in_sg_ofs;
637     elem->out_sg = (void *)elem + out_sg_ofs;
638     return elem;
639 }
640 
641 static void *vduse_queue_map_desc(VduseVirtq *vq, unsigned int idx, size_t sz)
642 {
643     struct vring_desc *desc = vq->vring.desc;
644     VduseDev *dev = vq->dev;
645     uint64_t desc_addr, read_len;
646     unsigned int desc_len;
647     unsigned int max = vq->vring.num;
648     unsigned int i = idx;
649     VduseVirtqElement *elem;
650     struct iovec iov[VIRTQUEUE_MAX_SIZE];
651     struct vring_desc desc_buf[VIRTQUEUE_MAX_SIZE];
652     unsigned int out_num = 0, in_num = 0;
653     int rc;
654 
655     if (le16toh(desc[i].flags) & VRING_DESC_F_INDIRECT) {
656         if (le32toh(desc[i].len) % sizeof(struct vring_desc)) {
657             fprintf(stderr, "Invalid size for indirect buffer table\n");
658             return NULL;
659         }
660 
661         /* loop over the indirect descriptor table */
662         desc_addr = le64toh(desc[i].addr);
663         desc_len = le32toh(desc[i].len);
664         max = desc_len / sizeof(struct vring_desc);
665         read_len = desc_len;
666         desc = iova_to_va(dev, &read_len, desc_addr);
667         if (unlikely(desc && read_len != desc_len)) {
668             /* Failed to use zero copy */
669             desc = NULL;
670             if (!vduse_queue_read_indirect_desc(dev, desc_buf,
671                                                 desc_addr,
672                                                 desc_len)) {
673                 desc = desc_buf;
674             }
675         }
676         if (!desc) {
677             fprintf(stderr, "Invalid indirect buffer table\n");
678             return NULL;
679         }
680         i = 0;
681     }
682 
683     /* Collect all the descriptors */
684     do {
685         if (le16toh(desc[i].flags) & VRING_DESC_F_WRITE) {
686             if (!vduse_queue_map_single_desc(vq, &in_num, iov + out_num,
687                                              VIRTQUEUE_MAX_SIZE - out_num,
688                                              true, le64toh(desc[i].addr),
689                                              le32toh(desc[i].len))) {
690                 return NULL;
691             }
692         } else {
693             if (in_num) {
694                 fprintf(stderr, "Incorrect order for descriptors\n");
695                 return NULL;
696             }
697             if (!vduse_queue_map_single_desc(vq, &out_num, iov,
698                                              VIRTQUEUE_MAX_SIZE, false,
699                                              le64toh(desc[i].addr),
700                                              le32toh(desc[i].len))) {
701                 return NULL;
702             }
703         }
704 
705         /* If we've got too many, that implies a descriptor loop. */
706         if ((in_num + out_num) > max) {
707             fprintf(stderr, "Looped descriptor\n");
708             return NULL;
709         }
710         rc = vduse_queue_read_next_desc(desc, i, max, &i);
711     } while (rc == VIRTQUEUE_READ_DESC_MORE);
712 
713     if (rc == VIRTQUEUE_READ_DESC_ERROR) {
714         fprintf(stderr, "read descriptor error\n");
715         return NULL;
716     }
717 
718     /* Now copy what we have collected and mapped */
719     elem = vduse_queue_alloc_element(sz, out_num, in_num);
720     if (!elem) {
721         fprintf(stderr, "read descriptor error\n");
722         return NULL;
723     }
724     elem->index = idx;
725     for (i = 0; i < out_num; i++) {
726         elem->out_sg[i] = iov[i];
727     }
728     for (i = 0; i < in_num; i++) {
729         elem->in_sg[i] = iov[out_num + i];
730     }
731 
732     return elem;
733 }
734 
735 void *vduse_queue_pop(VduseVirtq *vq, size_t sz)
736 {
737     unsigned int head;
738     VduseVirtqElement *elem;
739     VduseDev *dev = vq->dev;
740     int i;
741 
742     if (unlikely(!vq->vring.avail)) {
743         return NULL;
744     }
745 
746     if (unlikely(vq->resubmit_list && vq->resubmit_num > 0)) {
747         i = (--vq->resubmit_num);
748         elem = vduse_queue_map_desc(vq, vq->resubmit_list[i].index, sz);
749 
750         if (!vq->resubmit_num) {
751             free(vq->resubmit_list);
752             vq->resubmit_list = NULL;
753         }
754 
755         return elem;
756     }
757 
758     if (vduse_queue_empty(vq)) {
759         return NULL;
760     }
761     /* Needed after virtio_queue_empty() */
762     smp_rmb();
763 
764     if (vq->inuse >= vq->vring.num) {
765         fprintf(stderr, "Virtqueue size exceeded: %d\n", vq->inuse);
766         return NULL;
767     }
768 
769     if (!vduse_queue_get_head(vq, vq->last_avail_idx++, &head)) {
770         return NULL;
771     }
772 
773     if (vduse_dev_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
774         vring_set_avail_event(vq, vq->last_avail_idx);
775     }
776 
777     elem = vduse_queue_map_desc(vq, head, sz);
778 
779     if (!elem) {
780         return NULL;
781     }
782 
783     vq->inuse++;
784 
785     vduse_queue_inflight_get(vq, head);
786 
787     return elem;
788 }
789 
790 static inline void vring_used_write(VduseVirtq *vq,
791                                     struct vring_used_elem *uelem, int i)
792 {
793     struct vring_used *used = vq->vring.used;
794 
795     used->ring[i] = *uelem;
796 }
797 
798 static void vduse_queue_fill(VduseVirtq *vq, const VduseVirtqElement *elem,
799                              unsigned int len, unsigned int idx)
800 {
801     struct vring_used_elem uelem;
802 
803     if (unlikely(!vq->vring.used)) {
804         return;
805     }
806 
807     idx = (idx + vq->used_idx) % vq->vring.num;
808 
809     uelem.id = htole32(elem->index);
810     uelem.len = htole32(len);
811     vring_used_write(vq, &uelem, idx);
812 }
813 
814 static inline void vring_used_idx_set(VduseVirtq *vq, uint16_t val)
815 {
816     vq->vring.used->idx = htole16(val);
817     vq->used_idx = val;
818 }
819 
820 static void vduse_queue_flush(VduseVirtq *vq, unsigned int count)
821 {
822     uint16_t old, new;
823 
824     if (unlikely(!vq->vring.used)) {
825         return;
826     }
827 
828     /* Make sure buffer is written before we update index. */
829     smp_wmb();
830 
831     old = vq->used_idx;
832     new = old + count;
833     vring_used_idx_set(vq, new);
834     vq->inuse -= count;
835     if (unlikely((int16_t)(new - vq->signalled_used) < (uint16_t)(new - old))) {
836         vq->signalled_used_valid = false;
837     }
838 }
839 
840 void vduse_queue_push(VduseVirtq *vq, const VduseVirtqElement *elem,
841                       unsigned int len)
842 {
843     vduse_queue_fill(vq, elem, len, 0);
844     vduse_queue_inflight_pre_put(vq, elem->index);
845     vduse_queue_flush(vq, 1);
846     vduse_queue_inflight_post_put(vq, elem->index);
847 }
848 
849 static int vduse_queue_update_vring(VduseVirtq *vq, uint64_t desc_addr,
850                                     uint64_t avail_addr, uint64_t used_addr)
851 {
852     struct VduseDev *dev = vq->dev;
853     uint64_t len;
854 
855     len = sizeof(struct vring_desc);
856     vq->vring.desc = iova_to_va(dev, &len, desc_addr);
857     if (len != sizeof(struct vring_desc)) {
858         return -EINVAL;
859     }
860 
861     len = sizeof(struct vring_avail);
862     vq->vring.avail = iova_to_va(dev, &len, avail_addr);
863     if (len != sizeof(struct vring_avail)) {
864         return -EINVAL;
865     }
866 
867     len = sizeof(struct vring_used);
868     vq->vring.used = iova_to_va(dev, &len, used_addr);
869     if (len != sizeof(struct vring_used)) {
870         return -EINVAL;
871     }
872 
873     if (!vq->vring.desc || !vq->vring.avail || !vq->vring.used) {
874         fprintf(stderr, "Failed to get vq[%d] iova mapping\n", vq->index);
875         return -EINVAL;
876     }
877 
878     return 0;
879 }
880 
881 static void vduse_queue_enable(VduseVirtq *vq)
882 {
883     struct VduseDev *dev = vq->dev;
884     struct vduse_vq_info vq_info;
885     struct vduse_vq_eventfd vq_eventfd;
886     int fd;
887 
888     vq_info.index = vq->index;
889     if (ioctl(dev->fd, VDUSE_VQ_GET_INFO, &vq_info)) {
890         fprintf(stderr, "Failed to get vq[%d] info: %s\n",
891                 vq->index, strerror(errno));
892         return;
893     }
894 
895     if (!vq_info.ready) {
896         return;
897     }
898 
899     vq->vring.num = vq_info.num;
900     vq->vring.desc_addr = vq_info.desc_addr;
901     vq->vring.avail_addr = vq_info.driver_addr;
902     vq->vring.used_addr = vq_info.device_addr;
903 
904     if (vduse_queue_update_vring(vq, vq_info.desc_addr,
905                                  vq_info.driver_addr, vq_info.device_addr)) {
906         fprintf(stderr, "Failed to update vring for vq[%d]\n", vq->index);
907         return;
908     }
909 
910     fd = eventfd(0, EFD_NONBLOCK | EFD_CLOEXEC);
911     if (fd < 0) {
912         fprintf(stderr, "Failed to init eventfd for vq[%d]\n", vq->index);
913         return;
914     }
915 
916     vq_eventfd.index = vq->index;
917     vq_eventfd.fd = fd;
918     if (ioctl(dev->fd, VDUSE_VQ_SETUP_KICKFD, &vq_eventfd)) {
919         fprintf(stderr, "Failed to setup kick fd for vq[%d]\n", vq->index);
920         close(fd);
921         return;
922     }
923 
924     vq->fd = fd;
925     vq->signalled_used_valid = false;
926     vq->ready = true;
927 
928     if (vduse_queue_check_inflights(vq)) {
929         fprintf(stderr, "Failed to check inflights for vq[%d]\n", vq->index);
930         close(fd);
931         return;
932     }
933 
934     dev->ops->enable_queue(dev, vq);
935 }
936 
937 static void vduse_queue_disable(VduseVirtq *vq)
938 {
939     struct VduseDev *dev = vq->dev;
940     struct vduse_vq_eventfd eventfd;
941 
942     if (!vq->ready) {
943         return;
944     }
945 
946     dev->ops->disable_queue(dev, vq);
947 
948     eventfd.index = vq->index;
949     eventfd.fd = VDUSE_EVENTFD_DEASSIGN;
950     ioctl(dev->fd, VDUSE_VQ_SETUP_KICKFD, &eventfd);
951     close(vq->fd);
952 
953     assert(vq->inuse == 0);
954 
955     vq->vring.num = 0;
956     vq->vring.desc_addr = 0;
957     vq->vring.avail_addr = 0;
958     vq->vring.used_addr = 0;
959     vq->vring.desc = 0;
960     vq->vring.avail = 0;
961     vq->vring.used = 0;
962     vq->ready = false;
963     vq->fd = -1;
964 }
965 
966 static void vduse_dev_start_dataplane(VduseDev *dev)
967 {
968     int i;
969 
970     if (ioctl(dev->fd, VDUSE_DEV_GET_FEATURES, &dev->features)) {
971         fprintf(stderr, "Failed to get features: %s\n", strerror(errno));
972         return;
973     }
974     assert(vduse_dev_has_feature(dev, VIRTIO_F_VERSION_1));
975 
976     for (i = 0; i < dev->num_queues; i++) {
977         vduse_queue_enable(&dev->vqs[i]);
978     }
979 }
980 
981 static void vduse_dev_stop_dataplane(VduseDev *dev)
982 {
983     size_t log_size = dev->num_queues * vduse_vq_log_size(VIRTQUEUE_MAX_SIZE);
984     int i;
985 
986     for (i = 0; i < dev->num_queues; i++) {
987         vduse_queue_disable(&dev->vqs[i]);
988     }
989     if (dev->log) {
990         memset(dev->log, 0, log_size);
991     }
992     dev->features = 0;
993     vduse_iova_remove_region(dev, 0, ULONG_MAX);
994 }
995 
996 int vduse_dev_handler(VduseDev *dev)
997 {
998     struct vduse_dev_request req;
999     struct vduse_dev_response resp = { 0 };
1000     VduseVirtq *vq;
1001     int i, ret;
1002 
1003     ret = read(dev->fd, &req, sizeof(req));
1004     if (ret != sizeof(req)) {
1005         fprintf(stderr, "Read request error [%d]: %s\n",
1006                 ret, strerror(errno));
1007         return -errno;
1008     }
1009     resp.request_id = req.request_id;
1010 
1011     switch (req.type) {
1012     case VDUSE_GET_VQ_STATE:
1013         vq = &dev->vqs[req.vq_state.index];
1014         resp.vq_state.split.avail_index = vq->last_avail_idx;
1015         resp.result = VDUSE_REQ_RESULT_OK;
1016         break;
1017     case VDUSE_SET_STATUS:
1018         if (req.s.status & VIRTIO_CONFIG_S_DRIVER_OK) {
1019             vduse_dev_start_dataplane(dev);
1020         } else if (req.s.status == 0) {
1021             vduse_dev_stop_dataplane(dev);
1022         }
1023         resp.result = VDUSE_REQ_RESULT_OK;
1024         break;
1025     case VDUSE_UPDATE_IOTLB:
1026         /* The iova will be updated by iova_to_va() later, so just remove it */
1027         vduse_iova_remove_region(dev, req.iova.start, req.iova.last);
1028         for (i = 0; i < dev->num_queues; i++) {
1029             VduseVirtq *vq = &dev->vqs[i];
1030             if (vq->ready) {
1031                 if (vduse_queue_update_vring(vq, vq->vring.desc_addr,
1032                                              vq->vring.avail_addr,
1033                                              vq->vring.used_addr)) {
1034                     fprintf(stderr, "Failed to update vring for vq[%d]\n",
1035                             vq->index);
1036                 }
1037             }
1038         }
1039         resp.result = VDUSE_REQ_RESULT_OK;
1040         break;
1041     default:
1042         resp.result = VDUSE_REQ_RESULT_FAILED;
1043         break;
1044     }
1045 
1046     ret = write(dev->fd, &resp, sizeof(resp));
1047     if (ret != sizeof(resp)) {
1048         fprintf(stderr, "Write request %d error [%d]: %s\n",
1049                 req.type, ret, strerror(errno));
1050         return -errno;
1051     }
1052     return 0;
1053 }
1054 
1055 int vduse_dev_update_config(VduseDev *dev, uint32_t size,
1056                             uint32_t offset, char *buffer)
1057 {
1058     int ret;
1059     struct vduse_config_data *data;
1060 
1061     data = malloc(offsetof(struct vduse_config_data, buffer) + size);
1062     if (!data) {
1063         return -ENOMEM;
1064     }
1065 
1066     data->offset = offset;
1067     data->length = size;
1068     memcpy(data->buffer, buffer, size);
1069 
1070     ret = ioctl(dev->fd, VDUSE_DEV_SET_CONFIG, data);
1071     free(data);
1072 
1073     if (ret) {
1074         return -errno;
1075     }
1076 
1077     if (ioctl(dev->fd, VDUSE_DEV_INJECT_CONFIG_IRQ)) {
1078         return -errno;
1079     }
1080 
1081     return 0;
1082 }
1083 
1084 int vduse_dev_setup_queue(VduseDev *dev, int index, int max_size)
1085 {
1086     VduseVirtq *vq = &dev->vqs[index];
1087     struct vduse_vq_config vq_config = { 0 };
1088 
1089     if (max_size > VIRTQUEUE_MAX_SIZE) {
1090         return -EINVAL;
1091     }
1092 
1093     vq_config.index = vq->index;
1094     vq_config.max_size = max_size;
1095 
1096     if (ioctl(dev->fd, VDUSE_VQ_SETUP, &vq_config)) {
1097         return -errno;
1098     }
1099 
1100     vduse_queue_enable(vq);
1101 
1102     return 0;
1103 }
1104 
1105 int vduse_set_reconnect_log_file(VduseDev *dev, const char *filename)
1106 {
1107 
1108     size_t log_size = dev->num_queues * vduse_vq_log_size(VIRTQUEUE_MAX_SIZE);
1109     void *log;
1110     int i;
1111 
1112     dev->log = log = vduse_log_get(filename, log_size);
1113     if (log == MAP_FAILED) {
1114         fprintf(stderr, "Failed to get vduse log\n");
1115         return -EINVAL;
1116     }
1117 
1118     for (i = 0; i < dev->num_queues; i++) {
1119         dev->vqs[i].log = log;
1120         dev->vqs[i].log->inflight.desc_num = VIRTQUEUE_MAX_SIZE;
1121         log = (void *)((char *)log + vduse_vq_log_size(VIRTQUEUE_MAX_SIZE));
1122     }
1123 
1124     return 0;
1125 }
1126 
1127 static int vduse_dev_init_vqs(VduseDev *dev, uint16_t num_queues)
1128 {
1129     VduseVirtq *vqs;
1130     int i;
1131 
1132     vqs = calloc(sizeof(VduseVirtq), num_queues);
1133     if (!vqs) {
1134         return -ENOMEM;
1135     }
1136 
1137     for (i = 0; i < num_queues; i++) {
1138         vqs[i].index = i;
1139         vqs[i].dev = dev;
1140         vqs[i].fd = -1;
1141     }
1142     dev->vqs = vqs;
1143 
1144     return 0;
1145 }
1146 
1147 static int vduse_dev_init(VduseDev *dev, const char *name,
1148                           uint16_t num_queues, const VduseOps *ops,
1149                           void *priv)
1150 {
1151     char *dev_path, *dev_name;
1152     int ret, fd;
1153 
1154     dev_path = malloc(strlen(name) + strlen("/dev/vduse/") + 1);
1155     if (!dev_path) {
1156         return -ENOMEM;
1157     }
1158     sprintf(dev_path, "/dev/vduse/%s", name);
1159 
1160     fd = open(dev_path, O_RDWR);
1161     free(dev_path);
1162     if (fd < 0) {
1163         fprintf(stderr, "Failed to open vduse dev %s: %s\n",
1164                 name, strerror(errno));
1165         return -errno;
1166     }
1167 
1168     if (ioctl(fd, VDUSE_DEV_GET_FEATURES, &dev->features)) {
1169         fprintf(stderr, "Failed to get features: %s\n", strerror(errno));
1170         close(fd);
1171         return -errno;
1172     }
1173 
1174     dev_name = strdup(name);
1175     if (!dev_name) {
1176         close(fd);
1177         return -ENOMEM;
1178     }
1179 
1180     ret = vduse_dev_init_vqs(dev, num_queues);
1181     if (ret) {
1182         free(dev_name);
1183         close(fd);
1184         return ret;
1185     }
1186 
1187     dev->name = dev_name;
1188     dev->num_queues = num_queues;
1189     dev->fd = fd;
1190     dev->ops = ops;
1191     dev->priv = priv;
1192 
1193     return 0;
1194 }
1195 
1196 static inline bool vduse_name_is_invalid(const char *name)
1197 {
1198     return strlen(name) >= VDUSE_NAME_MAX || strstr(name, "..");
1199 }
1200 
1201 VduseDev *vduse_dev_create_by_fd(int fd, uint16_t num_queues,
1202                                  const VduseOps *ops, void *priv)
1203 {
1204     VduseDev *dev;
1205     int ret;
1206 
1207     if (!ops || !ops->enable_queue || !ops->disable_queue) {
1208         fprintf(stderr, "Invalid parameter for vduse\n");
1209         return NULL;
1210     }
1211 
1212     dev = calloc(sizeof(VduseDev), 1);
1213     if (!dev) {
1214         fprintf(stderr, "Failed to allocate vduse device\n");
1215         return NULL;
1216     }
1217 
1218     if (ioctl(fd, VDUSE_DEV_GET_FEATURES, &dev->features)) {
1219         fprintf(stderr, "Failed to get features: %s\n", strerror(errno));
1220         free(dev);
1221         return NULL;
1222     }
1223 
1224     ret = vduse_dev_init_vqs(dev, num_queues);
1225     if (ret) {
1226         fprintf(stderr, "Failed to init vqs\n");
1227         free(dev);
1228         return NULL;
1229     }
1230 
1231     dev->num_queues = num_queues;
1232     dev->fd = fd;
1233     dev->ops = ops;
1234     dev->priv = priv;
1235 
1236     return dev;
1237 }
1238 
1239 VduseDev *vduse_dev_create_by_name(const char *name, uint16_t num_queues,
1240                                    const VduseOps *ops, void *priv)
1241 {
1242     VduseDev *dev;
1243     int ret;
1244 
1245     if (!name || vduse_name_is_invalid(name) || !ops ||
1246         !ops->enable_queue || !ops->disable_queue) {
1247         fprintf(stderr, "Invalid parameter for vduse\n");
1248         return NULL;
1249     }
1250 
1251     dev = calloc(sizeof(VduseDev), 1);
1252     if (!dev) {
1253         fprintf(stderr, "Failed to allocate vduse device\n");
1254         return NULL;
1255     }
1256 
1257     ret = vduse_dev_init(dev, name, num_queues, ops, priv);
1258     if (ret < 0) {
1259         fprintf(stderr, "Failed to init vduse device %s: %s\n",
1260                 name, strerror(-ret));
1261         free(dev);
1262         return NULL;
1263     }
1264 
1265     return dev;
1266 }
1267 
1268 VduseDev *vduse_dev_create(const char *name, uint32_t device_id,
1269                            uint32_t vendor_id, uint64_t features,
1270                            uint16_t num_queues, uint32_t config_size,
1271                            char *config, const VduseOps *ops, void *priv)
1272 {
1273     VduseDev *dev;
1274     int ret, ctrl_fd;
1275     uint64_t version;
1276     struct vduse_dev_config *dev_config;
1277     size_t size = offsetof(struct vduse_dev_config, config);
1278 
1279     if (!name || vduse_name_is_invalid(name) ||
1280         !has_feature(features,  VIRTIO_F_VERSION_1) || !config ||
1281         !config_size || !ops || !ops->enable_queue || !ops->disable_queue) {
1282         fprintf(stderr, "Invalid parameter for vduse\n");
1283         return NULL;
1284     }
1285 
1286     dev = calloc(sizeof(VduseDev), 1);
1287     if (!dev) {
1288         fprintf(stderr, "Failed to allocate vduse device\n");
1289         return NULL;
1290     }
1291 
1292     ctrl_fd = open("/dev/vduse/control", O_RDWR);
1293     if (ctrl_fd < 0) {
1294         fprintf(stderr, "Failed to open /dev/vduse/control: %s\n",
1295                 strerror(errno));
1296         goto err_ctrl;
1297     }
1298 
1299     version = VDUSE_API_VERSION;
1300     if (ioctl(ctrl_fd, VDUSE_SET_API_VERSION, &version)) {
1301         fprintf(stderr, "Failed to set api version %" PRIu64 ": %s\n",
1302                 version, strerror(errno));
1303         goto err_dev;
1304     }
1305 
1306     dev_config = calloc(size + config_size, 1);
1307     if (!dev_config) {
1308         fprintf(stderr, "Failed to allocate config space\n");
1309         goto err_dev;
1310     }
1311 
1312     assert(!vduse_name_is_invalid(name));
1313     strcpy(dev_config->name, name);
1314     dev_config->device_id = device_id;
1315     dev_config->vendor_id = vendor_id;
1316     dev_config->features = features;
1317     dev_config->vq_num = num_queues;
1318     dev_config->vq_align = VDUSE_VQ_ALIGN;
1319     dev_config->config_size = config_size;
1320     memcpy(dev_config->config, config, config_size);
1321 
1322     ret = ioctl(ctrl_fd, VDUSE_CREATE_DEV, dev_config);
1323     free(dev_config);
1324     if (ret && errno != EEXIST) {
1325         fprintf(stderr, "Failed to create vduse device %s: %s\n",
1326                 name, strerror(errno));
1327         goto err_dev;
1328     }
1329     dev->ctrl_fd = ctrl_fd;
1330 
1331     ret = vduse_dev_init(dev, name, num_queues, ops, priv);
1332     if (ret < 0) {
1333         fprintf(stderr, "Failed to init vduse device %s: %s\n",
1334                 name, strerror(-ret));
1335         goto err;
1336     }
1337 
1338     return dev;
1339 err:
1340     ioctl(ctrl_fd, VDUSE_DESTROY_DEV, name);
1341 err_dev:
1342     close(ctrl_fd);
1343 err_ctrl:
1344     free(dev);
1345 
1346     return NULL;
1347 }
1348 
1349 int vduse_dev_destroy(VduseDev *dev)
1350 {
1351     size_t log_size = dev->num_queues * vduse_vq_log_size(VIRTQUEUE_MAX_SIZE);
1352     int i, ret = 0;
1353 
1354     if (dev->log) {
1355         munmap(dev->log, log_size);
1356     }
1357     for (i = 0; i < dev->num_queues; i++) {
1358         free(dev->vqs[i].resubmit_list);
1359     }
1360     free(dev->vqs);
1361     if (dev->fd >= 0) {
1362         close(dev->fd);
1363         dev->fd = -1;
1364     }
1365     if (dev->ctrl_fd >= 0) {
1366         if (ioctl(dev->ctrl_fd, VDUSE_DESTROY_DEV, dev->name)) {
1367             ret = -errno;
1368         }
1369         close(dev->ctrl_fd);
1370         dev->ctrl_fd = -1;
1371     }
1372     free(dev->name);
1373     free(dev);
1374 
1375     return ret;
1376 }
1377