xref: /openbmc/qemu/hw/virtio/vhost-user.c (revision 2055dbc1)
1 /*
2  * vhost-user
3  *
4  * Copyright (c) 2013 Virtual Open Systems Sarl.
5  *
6  * This work is licensed under the terms of the GNU GPL, version 2 or later.
7  * See the COPYING file in the top-level directory.
8  *
9  */
10 
11 #include "qemu/osdep.h"
12 #include "qapi/error.h"
13 #include "hw/virtio/vhost.h"
14 #include "hw/virtio/vhost-user.h"
15 #include "hw/virtio/vhost-backend.h"
16 #include "hw/virtio/virtio.h"
17 #include "hw/virtio/virtio-net.h"
18 #include "chardev/char-fe.h"
19 #include "sysemu/kvm.h"
20 #include "qemu/error-report.h"
21 #include "qemu/main-loop.h"
22 #include "qemu/sockets.h"
23 #include "sysemu/cryptodev.h"
24 #include "migration/migration.h"
25 #include "migration/postcopy-ram.h"
26 #include "trace.h"
27 
28 #include <sys/ioctl.h>
29 #include <sys/socket.h>
30 #include <sys/un.h>
31 
32 #include "standard-headers/linux/vhost_types.h"
33 
34 #ifdef CONFIG_LINUX
35 #include <linux/userfaultfd.h>
36 #endif
37 
38 #define VHOST_MEMORY_BASELINE_NREGIONS    8
39 #define VHOST_USER_F_PROTOCOL_FEATURES 30
40 #define VHOST_USER_SLAVE_MAX_FDS     8
41 
42 /*
43  * Set maximum number of RAM slots supported to
44  * the maximum number supported by the target
45  * hardware plaform.
46  */
47 #if defined(TARGET_X86) || defined(TARGET_X86_64) || \
48     defined(TARGET_ARM) || defined(TARGET_ARM_64)
49 #include "hw/acpi/acpi.h"
50 #define VHOST_USER_MAX_RAM_SLOTS ACPI_MAX_RAM_SLOTS
51 
52 #elif defined(TARGET_PPC) || defined(TARGET_PPC_64)
53 #include "hw/ppc/spapr.h"
54 #define VHOST_USER_MAX_RAM_SLOTS SPAPR_MAX_RAM_SLOTS
55 
56 #else
57 #define VHOST_USER_MAX_RAM_SLOTS 512
58 #endif
59 
60 /*
61  * Maximum size of virtio device config space
62  */
63 #define VHOST_USER_MAX_CONFIG_SIZE 256
64 
65 enum VhostUserProtocolFeature {
66     VHOST_USER_PROTOCOL_F_MQ = 0,
67     VHOST_USER_PROTOCOL_F_LOG_SHMFD = 1,
68     VHOST_USER_PROTOCOL_F_RARP = 2,
69     VHOST_USER_PROTOCOL_F_REPLY_ACK = 3,
70     VHOST_USER_PROTOCOL_F_NET_MTU = 4,
71     VHOST_USER_PROTOCOL_F_SLAVE_REQ = 5,
72     VHOST_USER_PROTOCOL_F_CROSS_ENDIAN = 6,
73     VHOST_USER_PROTOCOL_F_CRYPTO_SESSION = 7,
74     VHOST_USER_PROTOCOL_F_PAGEFAULT = 8,
75     VHOST_USER_PROTOCOL_F_CONFIG = 9,
76     VHOST_USER_PROTOCOL_F_SLAVE_SEND_FD = 10,
77     VHOST_USER_PROTOCOL_F_HOST_NOTIFIER = 11,
78     VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD = 12,
79     VHOST_USER_PROTOCOL_F_RESET_DEVICE = 13,
80     /* Feature 14 reserved for VHOST_USER_PROTOCOL_F_INBAND_NOTIFICATIONS. */
81     VHOST_USER_PROTOCOL_F_CONFIGURE_MEM_SLOTS = 15,
82     VHOST_USER_PROTOCOL_F_MAX
83 };
84 
85 #define VHOST_USER_PROTOCOL_FEATURE_MASK ((1 << VHOST_USER_PROTOCOL_F_MAX) - 1)
86 
87 typedef enum VhostUserRequest {
88     VHOST_USER_NONE = 0,
89     VHOST_USER_GET_FEATURES = 1,
90     VHOST_USER_SET_FEATURES = 2,
91     VHOST_USER_SET_OWNER = 3,
92     VHOST_USER_RESET_OWNER = 4,
93     VHOST_USER_SET_MEM_TABLE = 5,
94     VHOST_USER_SET_LOG_BASE = 6,
95     VHOST_USER_SET_LOG_FD = 7,
96     VHOST_USER_SET_VRING_NUM = 8,
97     VHOST_USER_SET_VRING_ADDR = 9,
98     VHOST_USER_SET_VRING_BASE = 10,
99     VHOST_USER_GET_VRING_BASE = 11,
100     VHOST_USER_SET_VRING_KICK = 12,
101     VHOST_USER_SET_VRING_CALL = 13,
102     VHOST_USER_SET_VRING_ERR = 14,
103     VHOST_USER_GET_PROTOCOL_FEATURES = 15,
104     VHOST_USER_SET_PROTOCOL_FEATURES = 16,
105     VHOST_USER_GET_QUEUE_NUM = 17,
106     VHOST_USER_SET_VRING_ENABLE = 18,
107     VHOST_USER_SEND_RARP = 19,
108     VHOST_USER_NET_SET_MTU = 20,
109     VHOST_USER_SET_SLAVE_REQ_FD = 21,
110     VHOST_USER_IOTLB_MSG = 22,
111     VHOST_USER_SET_VRING_ENDIAN = 23,
112     VHOST_USER_GET_CONFIG = 24,
113     VHOST_USER_SET_CONFIG = 25,
114     VHOST_USER_CREATE_CRYPTO_SESSION = 26,
115     VHOST_USER_CLOSE_CRYPTO_SESSION = 27,
116     VHOST_USER_POSTCOPY_ADVISE  = 28,
117     VHOST_USER_POSTCOPY_LISTEN  = 29,
118     VHOST_USER_POSTCOPY_END     = 30,
119     VHOST_USER_GET_INFLIGHT_FD = 31,
120     VHOST_USER_SET_INFLIGHT_FD = 32,
121     VHOST_USER_GPU_SET_SOCKET = 33,
122     VHOST_USER_RESET_DEVICE = 34,
123     /* Message number 35 reserved for VHOST_USER_VRING_KICK. */
124     VHOST_USER_GET_MAX_MEM_SLOTS = 36,
125     VHOST_USER_ADD_MEM_REG = 37,
126     VHOST_USER_REM_MEM_REG = 38,
127     VHOST_USER_MAX
128 } VhostUserRequest;
129 
130 typedef enum VhostUserSlaveRequest {
131     VHOST_USER_SLAVE_NONE = 0,
132     VHOST_USER_SLAVE_IOTLB_MSG = 1,
133     VHOST_USER_SLAVE_CONFIG_CHANGE_MSG = 2,
134     VHOST_USER_SLAVE_VRING_HOST_NOTIFIER_MSG = 3,
135     VHOST_USER_SLAVE_MAX
136 }  VhostUserSlaveRequest;
137 
138 typedef struct VhostUserMemoryRegion {
139     uint64_t guest_phys_addr;
140     uint64_t memory_size;
141     uint64_t userspace_addr;
142     uint64_t mmap_offset;
143 } VhostUserMemoryRegion;
144 
145 typedef struct VhostUserMemory {
146     uint32_t nregions;
147     uint32_t padding;
148     VhostUserMemoryRegion regions[VHOST_MEMORY_BASELINE_NREGIONS];
149 } VhostUserMemory;
150 
151 typedef struct VhostUserMemRegMsg {
152     uint32_t padding;
153     VhostUserMemoryRegion region;
154 } VhostUserMemRegMsg;
155 
156 typedef struct VhostUserLog {
157     uint64_t mmap_size;
158     uint64_t mmap_offset;
159 } VhostUserLog;
160 
161 typedef struct VhostUserConfig {
162     uint32_t offset;
163     uint32_t size;
164     uint32_t flags;
165     uint8_t region[VHOST_USER_MAX_CONFIG_SIZE];
166 } VhostUserConfig;
167 
168 #define VHOST_CRYPTO_SYM_HMAC_MAX_KEY_LEN    512
169 #define VHOST_CRYPTO_SYM_CIPHER_MAX_KEY_LEN  64
170 
171 typedef struct VhostUserCryptoSession {
172     /* session id for success, -1 on errors */
173     int64_t session_id;
174     CryptoDevBackendSymSessionInfo session_setup_data;
175     uint8_t key[VHOST_CRYPTO_SYM_CIPHER_MAX_KEY_LEN];
176     uint8_t auth_key[VHOST_CRYPTO_SYM_HMAC_MAX_KEY_LEN];
177 } VhostUserCryptoSession;
178 
179 static VhostUserConfig c __attribute__ ((unused));
180 #define VHOST_USER_CONFIG_HDR_SIZE (sizeof(c.offset) \
181                                    + sizeof(c.size) \
182                                    + sizeof(c.flags))
183 
184 typedef struct VhostUserVringArea {
185     uint64_t u64;
186     uint64_t size;
187     uint64_t offset;
188 } VhostUserVringArea;
189 
190 typedef struct VhostUserInflight {
191     uint64_t mmap_size;
192     uint64_t mmap_offset;
193     uint16_t num_queues;
194     uint16_t queue_size;
195 } VhostUserInflight;
196 
197 typedef struct {
198     VhostUserRequest request;
199 
200 #define VHOST_USER_VERSION_MASK     (0x3)
201 #define VHOST_USER_REPLY_MASK       (0x1<<2)
202 #define VHOST_USER_NEED_REPLY_MASK  (0x1 << 3)
203     uint32_t flags;
204     uint32_t size; /* the following payload size */
205 } QEMU_PACKED VhostUserHeader;
206 
207 typedef union {
208 #define VHOST_USER_VRING_IDX_MASK   (0xff)
209 #define VHOST_USER_VRING_NOFD_MASK  (0x1<<8)
210         uint64_t u64;
211         struct vhost_vring_state state;
212         struct vhost_vring_addr addr;
213         VhostUserMemory memory;
214         VhostUserMemRegMsg mem_reg;
215         VhostUserLog log;
216         struct vhost_iotlb_msg iotlb;
217         VhostUserConfig config;
218         VhostUserCryptoSession session;
219         VhostUserVringArea area;
220         VhostUserInflight inflight;
221 } VhostUserPayload;
222 
223 typedef struct VhostUserMsg {
224     VhostUserHeader hdr;
225     VhostUserPayload payload;
226 } QEMU_PACKED VhostUserMsg;
227 
228 static VhostUserMsg m __attribute__ ((unused));
229 #define VHOST_USER_HDR_SIZE (sizeof(VhostUserHeader))
230 
231 #define VHOST_USER_PAYLOAD_SIZE (sizeof(VhostUserPayload))
232 
233 /* The version of the protocol we support */
234 #define VHOST_USER_VERSION    (0x1)
235 
236 struct vhost_user {
237     struct vhost_dev *dev;
238     /* Shared between vhost devs of the same virtio device */
239     VhostUserState *user;
240     int slave_fd;
241     NotifierWithReturn postcopy_notifier;
242     struct PostCopyFD  postcopy_fd;
243     uint64_t           postcopy_client_bases[VHOST_USER_MAX_RAM_SLOTS];
244     /* Length of the region_rb and region_rb_offset arrays */
245     size_t             region_rb_len;
246     /* RAMBlock associated with a given region */
247     RAMBlock         **region_rb;
248     /* The offset from the start of the RAMBlock to the start of the
249      * vhost region.
250      */
251     ram_addr_t        *region_rb_offset;
252 
253     /* True once we've entered postcopy_listen */
254     bool               postcopy_listen;
255 
256     /* Our current regions */
257     int num_shadow_regions;
258     struct vhost_memory_region shadow_regions[VHOST_USER_MAX_RAM_SLOTS];
259 };
260 
261 struct scrub_regions {
262     struct vhost_memory_region *region;
263     int reg_idx;
264     int fd_idx;
265 };
266 
267 static bool ioeventfd_enabled(void)
268 {
269     return !kvm_enabled() || kvm_eventfds_enabled();
270 }
271 
272 static int vhost_user_read_header(struct vhost_dev *dev, VhostUserMsg *msg)
273 {
274     struct vhost_user *u = dev->opaque;
275     CharBackend *chr = u->user->chr;
276     uint8_t *p = (uint8_t *) msg;
277     int r, size = VHOST_USER_HDR_SIZE;
278 
279     r = qemu_chr_fe_read_all(chr, p, size);
280     if (r != size) {
281         error_report("Failed to read msg header. Read %d instead of %d."
282                      " Original request %d.", r, size, msg->hdr.request);
283         return -1;
284     }
285 
286     /* validate received flags */
287     if (msg->hdr.flags != (VHOST_USER_REPLY_MASK | VHOST_USER_VERSION)) {
288         error_report("Failed to read msg header."
289                 " Flags 0x%x instead of 0x%x.", msg->hdr.flags,
290                 VHOST_USER_REPLY_MASK | VHOST_USER_VERSION);
291         return -1;
292     }
293 
294     return 0;
295 }
296 
297 static int vhost_user_read(struct vhost_dev *dev, VhostUserMsg *msg)
298 {
299     struct vhost_user *u = dev->opaque;
300     CharBackend *chr = u->user->chr;
301     uint8_t *p = (uint8_t *) msg;
302     int r, size;
303 
304     if (vhost_user_read_header(dev, msg) < 0) {
305         return -1;
306     }
307 
308     /* validate message size is sane */
309     if (msg->hdr.size > VHOST_USER_PAYLOAD_SIZE) {
310         error_report("Failed to read msg header."
311                 " Size %d exceeds the maximum %zu.", msg->hdr.size,
312                 VHOST_USER_PAYLOAD_SIZE);
313         return -1;
314     }
315 
316     if (msg->hdr.size) {
317         p += VHOST_USER_HDR_SIZE;
318         size = msg->hdr.size;
319         r = qemu_chr_fe_read_all(chr, p, size);
320         if (r != size) {
321             error_report("Failed to read msg payload."
322                          " Read %d instead of %d.", r, msg->hdr.size);
323             return -1;
324         }
325     }
326 
327     return 0;
328 }
329 
330 static int process_message_reply(struct vhost_dev *dev,
331                                  const VhostUserMsg *msg)
332 {
333     VhostUserMsg msg_reply;
334 
335     if ((msg->hdr.flags & VHOST_USER_NEED_REPLY_MASK) == 0) {
336         return 0;
337     }
338 
339     if (vhost_user_read(dev, &msg_reply) < 0) {
340         return -1;
341     }
342 
343     if (msg_reply.hdr.request != msg->hdr.request) {
344         error_report("Received unexpected msg type."
345                      "Expected %d received %d",
346                      msg->hdr.request, msg_reply.hdr.request);
347         return -1;
348     }
349 
350     return msg_reply.payload.u64 ? -1 : 0;
351 }
352 
353 static bool vhost_user_one_time_request(VhostUserRequest request)
354 {
355     switch (request) {
356     case VHOST_USER_SET_OWNER:
357     case VHOST_USER_RESET_OWNER:
358     case VHOST_USER_SET_MEM_TABLE:
359     case VHOST_USER_GET_QUEUE_NUM:
360     case VHOST_USER_NET_SET_MTU:
361         return true;
362     default:
363         return false;
364     }
365 }
366 
367 /* most non-init callers ignore the error */
368 static int vhost_user_write(struct vhost_dev *dev, VhostUserMsg *msg,
369                             int *fds, int fd_num)
370 {
371     struct vhost_user *u = dev->opaque;
372     CharBackend *chr = u->user->chr;
373     int ret, size = VHOST_USER_HDR_SIZE + msg->hdr.size;
374 
375     /*
376      * For non-vring specific requests, like VHOST_USER_SET_MEM_TABLE,
377      * we just need send it once in the first time. For later such
378      * request, we just ignore it.
379      */
380     if (vhost_user_one_time_request(msg->hdr.request) && dev->vq_index != 0) {
381         msg->hdr.flags &= ~VHOST_USER_NEED_REPLY_MASK;
382         return 0;
383     }
384 
385     if (qemu_chr_fe_set_msgfds(chr, fds, fd_num) < 0) {
386         error_report("Failed to set msg fds.");
387         return -1;
388     }
389 
390     ret = qemu_chr_fe_write_all(chr, (const uint8_t *) msg, size);
391     if (ret != size) {
392         error_report("Failed to write msg."
393                      " Wrote %d instead of %d.", ret, size);
394         return -1;
395     }
396 
397     return 0;
398 }
399 
400 int vhost_user_gpu_set_socket(struct vhost_dev *dev, int fd)
401 {
402     VhostUserMsg msg = {
403         .hdr.request = VHOST_USER_GPU_SET_SOCKET,
404         .hdr.flags = VHOST_USER_VERSION,
405     };
406 
407     return vhost_user_write(dev, &msg, &fd, 1);
408 }
409 
410 static int vhost_user_set_log_base(struct vhost_dev *dev, uint64_t base,
411                                    struct vhost_log *log)
412 {
413     int fds[VHOST_USER_MAX_RAM_SLOTS];
414     size_t fd_num = 0;
415     bool shmfd = virtio_has_feature(dev->protocol_features,
416                                     VHOST_USER_PROTOCOL_F_LOG_SHMFD);
417     VhostUserMsg msg = {
418         .hdr.request = VHOST_USER_SET_LOG_BASE,
419         .hdr.flags = VHOST_USER_VERSION,
420         .payload.log.mmap_size = log->size * sizeof(*(log->log)),
421         .payload.log.mmap_offset = 0,
422         .hdr.size = sizeof(msg.payload.log),
423     };
424 
425     if (shmfd && log->fd != -1) {
426         fds[fd_num++] = log->fd;
427     }
428 
429     if (vhost_user_write(dev, &msg, fds, fd_num) < 0) {
430         return -1;
431     }
432 
433     if (shmfd) {
434         msg.hdr.size = 0;
435         if (vhost_user_read(dev, &msg) < 0) {
436             return -1;
437         }
438 
439         if (msg.hdr.request != VHOST_USER_SET_LOG_BASE) {
440             error_report("Received unexpected msg type. "
441                          "Expected %d received %d",
442                          VHOST_USER_SET_LOG_BASE, msg.hdr.request);
443             return -1;
444         }
445     }
446 
447     return 0;
448 }
449 
450 static MemoryRegion *vhost_user_get_mr_data(uint64_t addr, ram_addr_t *offset,
451                                             int *fd)
452 {
453     MemoryRegion *mr;
454 
455     assert((uintptr_t)addr == addr);
456     mr = memory_region_from_host((void *)(uintptr_t)addr, offset);
457     *fd = memory_region_get_fd(mr);
458 
459     return mr;
460 }
461 
462 static void vhost_user_fill_msg_region(VhostUserMemoryRegion *dst,
463                                        struct vhost_memory_region *src)
464 {
465     assert(src != NULL && dst != NULL);
466     dst->userspace_addr = src->userspace_addr;
467     dst->memory_size = src->memory_size;
468     dst->guest_phys_addr = src->guest_phys_addr;
469 }
470 
471 static int vhost_user_fill_set_mem_table_msg(struct vhost_user *u,
472                                              struct vhost_dev *dev,
473                                              VhostUserMsg *msg,
474                                              int *fds, size_t *fd_num,
475                                              bool track_ramblocks)
476 {
477     int i, fd;
478     ram_addr_t offset;
479     MemoryRegion *mr;
480     struct vhost_memory_region *reg;
481     VhostUserMemoryRegion region_buffer;
482 
483     msg->hdr.request = VHOST_USER_SET_MEM_TABLE;
484 
485     for (i = 0; i < dev->mem->nregions; ++i) {
486         reg = dev->mem->regions + i;
487 
488         mr = vhost_user_get_mr_data(reg->userspace_addr, &offset, &fd);
489         if (fd > 0) {
490             if (track_ramblocks) {
491                 assert(*fd_num < VHOST_MEMORY_BASELINE_NREGIONS);
492                 trace_vhost_user_set_mem_table_withfd(*fd_num, mr->name,
493                                                       reg->memory_size,
494                                                       reg->guest_phys_addr,
495                                                       reg->userspace_addr,
496                                                       offset);
497                 u->region_rb_offset[i] = offset;
498                 u->region_rb[i] = mr->ram_block;
499             } else if (*fd_num == VHOST_MEMORY_BASELINE_NREGIONS) {
500                 error_report("Failed preparing vhost-user memory table msg");
501                 return -1;
502             }
503             vhost_user_fill_msg_region(&region_buffer, reg);
504             msg->payload.memory.regions[*fd_num] = region_buffer;
505             msg->payload.memory.regions[*fd_num].mmap_offset = offset;
506             fds[(*fd_num)++] = fd;
507         } else if (track_ramblocks) {
508             u->region_rb_offset[i] = 0;
509             u->region_rb[i] = NULL;
510         }
511     }
512 
513     msg->payload.memory.nregions = *fd_num;
514 
515     if (!*fd_num) {
516         error_report("Failed initializing vhost-user memory map, "
517                      "consider using -object memory-backend-file share=on");
518         return -1;
519     }
520 
521     msg->hdr.size = sizeof(msg->payload.memory.nregions);
522     msg->hdr.size += sizeof(msg->payload.memory.padding);
523     msg->hdr.size += *fd_num * sizeof(VhostUserMemoryRegion);
524 
525     return 1;
526 }
527 
528 static inline bool reg_equal(struct vhost_memory_region *shadow_reg,
529                              struct vhost_memory_region *vdev_reg)
530 {
531     return shadow_reg->guest_phys_addr == vdev_reg->guest_phys_addr &&
532         shadow_reg->userspace_addr == vdev_reg->userspace_addr &&
533         shadow_reg->memory_size == vdev_reg->memory_size;
534 }
535 
536 static void scrub_shadow_regions(struct vhost_dev *dev,
537                                  struct scrub_regions *add_reg,
538                                  int *nr_add_reg,
539                                  struct scrub_regions *rem_reg,
540                                  int *nr_rem_reg, uint64_t *shadow_pcb,
541                                  bool track_ramblocks)
542 {
543     struct vhost_user *u = dev->opaque;
544     bool found[VHOST_USER_MAX_RAM_SLOTS] = {};
545     struct vhost_memory_region *reg, *shadow_reg;
546     int i, j, fd, add_idx = 0, rm_idx = 0, fd_num = 0;
547     ram_addr_t offset;
548     MemoryRegion *mr;
549     bool matching;
550 
551     /*
552      * Find memory regions present in our shadow state which are not in
553      * the device's current memory state.
554      *
555      * Mark regions in both the shadow and device state as "found".
556      */
557     for (i = 0; i < u->num_shadow_regions; i++) {
558         shadow_reg = &u->shadow_regions[i];
559         matching = false;
560 
561         for (j = 0; j < dev->mem->nregions; j++) {
562             reg = &dev->mem->regions[j];
563 
564             mr = vhost_user_get_mr_data(reg->userspace_addr, &offset, &fd);
565 
566             if (reg_equal(shadow_reg, reg)) {
567                 matching = true;
568                 found[j] = true;
569                 if (track_ramblocks) {
570                     /*
571                      * Reset postcopy client bases, region_rb, and
572                      * region_rb_offset in case regions are removed.
573                      */
574                     if (fd > 0) {
575                         u->region_rb_offset[j] = offset;
576                         u->region_rb[j] = mr->ram_block;
577                         shadow_pcb[j] = u->postcopy_client_bases[i];
578                     } else {
579                         u->region_rb_offset[j] = 0;
580                         u->region_rb[j] = NULL;
581                     }
582                 }
583                 break;
584             }
585         }
586 
587         /*
588          * If the region was not found in the current device memory state
589          * create an entry for it in the removed list.
590          */
591         if (!matching) {
592             rem_reg[rm_idx].region = shadow_reg;
593             rem_reg[rm_idx++].reg_idx = i;
594         }
595     }
596 
597     /*
598      * For regions not marked "found", create entries in the added list.
599      *
600      * Note their indexes in the device memory state and the indexes of their
601      * file descriptors.
602      */
603     for (i = 0; i < dev->mem->nregions; i++) {
604         reg = &dev->mem->regions[i];
605         mr = vhost_user_get_mr_data(reg->userspace_addr, &offset, &fd);
606         if (fd > 0) {
607             ++fd_num;
608         }
609 
610         /*
611          * If the region was in both the shadow and device state we don't
612          * need to send a VHOST_USER_ADD_MEM_REG message for it.
613          */
614         if (found[i]) {
615             continue;
616         }
617 
618         add_reg[add_idx].region = reg;
619         add_reg[add_idx].reg_idx = i;
620         add_reg[add_idx++].fd_idx = fd_num;
621     }
622     *nr_rem_reg = rm_idx;
623     *nr_add_reg = add_idx;
624 
625     return;
626 }
627 
628 static int send_remove_regions(struct vhost_dev *dev,
629                                struct scrub_regions *remove_reg,
630                                int nr_rem_reg, VhostUserMsg *msg,
631                                bool reply_supported)
632 {
633     struct vhost_user *u = dev->opaque;
634     struct vhost_memory_region *shadow_reg;
635     int i, fd, shadow_reg_idx, ret;
636     ram_addr_t offset;
637     VhostUserMemoryRegion region_buffer;
638 
639     /*
640      * The regions in remove_reg appear in the same order they do in the
641      * shadow table. Therefore we can minimize memory copies by iterating
642      * through remove_reg backwards.
643      */
644     for (i = nr_rem_reg - 1; i >= 0; i--) {
645         shadow_reg = remove_reg[i].region;
646         shadow_reg_idx = remove_reg[i].reg_idx;
647 
648         vhost_user_get_mr_data(shadow_reg->userspace_addr, &offset, &fd);
649 
650         if (fd > 0) {
651             msg->hdr.request = VHOST_USER_REM_MEM_REG;
652             vhost_user_fill_msg_region(&region_buffer, shadow_reg);
653             msg->payload.mem_reg.region = region_buffer;
654 
655             if (vhost_user_write(dev, msg, &fd, 1) < 0) {
656                 return -1;
657             }
658 
659             if (reply_supported) {
660                 ret = process_message_reply(dev, msg);
661                 if (ret) {
662                     return ret;
663                 }
664             }
665         }
666 
667         /*
668          * At this point we know the backend has unmapped the region. It is now
669          * safe to remove it from the shadow table.
670          */
671         memmove(&u->shadow_regions[shadow_reg_idx],
672                 &u->shadow_regions[shadow_reg_idx + 1],
673                 sizeof(struct vhost_memory_region) *
674                 (u->num_shadow_regions - shadow_reg_idx));
675         u->num_shadow_regions--;
676     }
677 
678     return 0;
679 }
680 
681 static int send_add_regions(struct vhost_dev *dev,
682                             struct scrub_regions *add_reg, int nr_add_reg,
683                             VhostUserMsg *msg, uint64_t *shadow_pcb,
684                             bool reply_supported, bool track_ramblocks)
685 {
686     struct vhost_user *u = dev->opaque;
687     int i, fd, ret, reg_idx, reg_fd_idx;
688     struct vhost_memory_region *reg;
689     MemoryRegion *mr;
690     ram_addr_t offset;
691     VhostUserMsg msg_reply;
692     VhostUserMemoryRegion region_buffer;
693 
694     for (i = 0; i < nr_add_reg; i++) {
695         reg = add_reg[i].region;
696         reg_idx = add_reg[i].reg_idx;
697         reg_fd_idx = add_reg[i].fd_idx;
698 
699         mr = vhost_user_get_mr_data(reg->userspace_addr, &offset, &fd);
700 
701         if (fd > 0) {
702             if (track_ramblocks) {
703                 trace_vhost_user_set_mem_table_withfd(reg_fd_idx, mr->name,
704                                                       reg->memory_size,
705                                                       reg->guest_phys_addr,
706                                                       reg->userspace_addr,
707                                                       offset);
708                 u->region_rb_offset[reg_idx] = offset;
709                 u->region_rb[reg_idx] = mr->ram_block;
710             }
711             msg->hdr.request = VHOST_USER_ADD_MEM_REG;
712             vhost_user_fill_msg_region(&region_buffer, reg);
713             msg->payload.mem_reg.region = region_buffer;
714             msg->payload.mem_reg.region.mmap_offset = offset;
715 
716             if (vhost_user_write(dev, msg, &fd, 1) < 0) {
717                 return -1;
718             }
719 
720             if (track_ramblocks) {
721                 uint64_t reply_gpa;
722 
723                 if (vhost_user_read(dev, &msg_reply) < 0) {
724                     return -1;
725                 }
726 
727                 reply_gpa = msg_reply.payload.mem_reg.region.guest_phys_addr;
728 
729                 if (msg_reply.hdr.request != VHOST_USER_ADD_MEM_REG) {
730                     error_report("%s: Received unexpected msg type."
731                                  "Expected %d received %d", __func__,
732                                  VHOST_USER_ADD_MEM_REG,
733                                  msg_reply.hdr.request);
734                     return -1;
735                 }
736 
737                 /*
738                  * We're using the same structure, just reusing one of the
739                  * fields, so it should be the same size.
740                  */
741                 if (msg_reply.hdr.size != msg->hdr.size) {
742                     error_report("%s: Unexpected size for postcopy reply "
743                                  "%d vs %d", __func__, msg_reply.hdr.size,
744                                  msg->hdr.size);
745                     return -1;
746                 }
747 
748                 /* Get the postcopy client base from the backend's reply. */
749                 if (reply_gpa == dev->mem->regions[reg_idx].guest_phys_addr) {
750                     shadow_pcb[reg_idx] =
751                         msg_reply.payload.mem_reg.region.userspace_addr;
752                     trace_vhost_user_set_mem_table_postcopy(
753                         msg_reply.payload.mem_reg.region.userspace_addr,
754                         msg->payload.mem_reg.region.userspace_addr,
755                         reg_fd_idx, reg_idx);
756                 } else {
757                     error_report("%s: invalid postcopy reply for region. "
758                                  "Got guest physical address %" PRIX64 ", expected "
759                                  "%" PRIX64, __func__, reply_gpa,
760                                  dev->mem->regions[reg_idx].guest_phys_addr);
761                     return -1;
762                 }
763             } else if (reply_supported) {
764                 ret = process_message_reply(dev, msg);
765                 if (ret) {
766                     return ret;
767                 }
768             }
769         } else if (track_ramblocks) {
770             u->region_rb_offset[reg_idx] = 0;
771             u->region_rb[reg_idx] = NULL;
772         }
773 
774         /*
775          * At this point, we know the backend has mapped in the new
776          * region, if the region has a valid file descriptor.
777          *
778          * The region should now be added to the shadow table.
779          */
780         u->shadow_regions[u->num_shadow_regions].guest_phys_addr =
781             reg->guest_phys_addr;
782         u->shadow_regions[u->num_shadow_regions].userspace_addr =
783             reg->userspace_addr;
784         u->shadow_regions[u->num_shadow_regions].memory_size =
785             reg->memory_size;
786         u->num_shadow_regions++;
787     }
788 
789     return 0;
790 }
791 
792 static int vhost_user_add_remove_regions(struct vhost_dev *dev,
793                                          VhostUserMsg *msg,
794                                          bool reply_supported,
795                                          bool track_ramblocks)
796 {
797     struct vhost_user *u = dev->opaque;
798     struct scrub_regions add_reg[VHOST_USER_MAX_RAM_SLOTS];
799     struct scrub_regions rem_reg[VHOST_USER_MAX_RAM_SLOTS];
800     uint64_t shadow_pcb[VHOST_USER_MAX_RAM_SLOTS] = {};
801     int nr_add_reg, nr_rem_reg;
802 
803     msg->hdr.size = sizeof(msg->payload.mem_reg.padding) +
804         sizeof(VhostUserMemoryRegion);
805 
806     /* Find the regions which need to be removed or added. */
807     scrub_shadow_regions(dev, add_reg, &nr_add_reg, rem_reg, &nr_rem_reg,
808                          shadow_pcb, track_ramblocks);
809 
810     if (nr_rem_reg && send_remove_regions(dev, rem_reg, nr_rem_reg, msg,
811                 reply_supported) < 0)
812     {
813         goto err;
814     }
815 
816     if (nr_add_reg && send_add_regions(dev, add_reg, nr_add_reg, msg,
817                 shadow_pcb, reply_supported, track_ramblocks) < 0)
818     {
819         goto err;
820     }
821 
822     if (track_ramblocks) {
823         memcpy(u->postcopy_client_bases, shadow_pcb,
824                sizeof(uint64_t) * VHOST_USER_MAX_RAM_SLOTS);
825         /*
826          * Now we've registered this with the postcopy code, we ack to the
827          * client, because now we're in the position to be able to deal with
828          * any faults it generates.
829          */
830         /* TODO: Use this for failure cases as well with a bad value. */
831         msg->hdr.size = sizeof(msg->payload.u64);
832         msg->payload.u64 = 0; /* OK */
833 
834         if (vhost_user_write(dev, msg, NULL, 0) < 0) {
835             return -1;
836         }
837     }
838 
839     return 0;
840 
841 err:
842     if (track_ramblocks) {
843         memcpy(u->postcopy_client_bases, shadow_pcb,
844                sizeof(uint64_t) * VHOST_USER_MAX_RAM_SLOTS);
845     }
846 
847     return -1;
848 }
849 
850 static int vhost_user_set_mem_table_postcopy(struct vhost_dev *dev,
851                                              struct vhost_memory *mem,
852                                              bool reply_supported,
853                                              bool config_mem_slots)
854 {
855     struct vhost_user *u = dev->opaque;
856     int fds[VHOST_MEMORY_BASELINE_NREGIONS];
857     size_t fd_num = 0;
858     VhostUserMsg msg_reply;
859     int region_i, msg_i;
860 
861     VhostUserMsg msg = {
862         .hdr.flags = VHOST_USER_VERSION,
863     };
864 
865     if (u->region_rb_len < dev->mem->nregions) {
866         u->region_rb = g_renew(RAMBlock*, u->region_rb, dev->mem->nregions);
867         u->region_rb_offset = g_renew(ram_addr_t, u->region_rb_offset,
868                                       dev->mem->nregions);
869         memset(&(u->region_rb[u->region_rb_len]), '\0',
870                sizeof(RAMBlock *) * (dev->mem->nregions - u->region_rb_len));
871         memset(&(u->region_rb_offset[u->region_rb_len]), '\0',
872                sizeof(ram_addr_t) * (dev->mem->nregions - u->region_rb_len));
873         u->region_rb_len = dev->mem->nregions;
874     }
875 
876     if (config_mem_slots) {
877         if (vhost_user_add_remove_regions(dev, &msg, reply_supported,
878                                           true) < 0) {
879             return -1;
880         }
881     } else {
882         if (vhost_user_fill_set_mem_table_msg(u, dev, &msg, fds, &fd_num,
883                                               true) < 0) {
884             return -1;
885         }
886 
887         if (vhost_user_write(dev, &msg, fds, fd_num) < 0) {
888             return -1;
889         }
890 
891         if (vhost_user_read(dev, &msg_reply) < 0) {
892             return -1;
893         }
894 
895         if (msg_reply.hdr.request != VHOST_USER_SET_MEM_TABLE) {
896             error_report("%s: Received unexpected msg type."
897                          "Expected %d received %d", __func__,
898                          VHOST_USER_SET_MEM_TABLE, msg_reply.hdr.request);
899             return -1;
900         }
901 
902         /*
903          * We're using the same structure, just reusing one of the
904          * fields, so it should be the same size.
905          */
906         if (msg_reply.hdr.size != msg.hdr.size) {
907             error_report("%s: Unexpected size for postcopy reply "
908                          "%d vs %d", __func__, msg_reply.hdr.size,
909                          msg.hdr.size);
910             return -1;
911         }
912 
913         memset(u->postcopy_client_bases, 0,
914                sizeof(uint64_t) * VHOST_USER_MAX_RAM_SLOTS);
915 
916         /*
917          * They're in the same order as the regions that were sent
918          * but some of the regions were skipped (above) if they
919          * didn't have fd's
920          */
921         for (msg_i = 0, region_i = 0;
922              region_i < dev->mem->nregions;
923              region_i++) {
924             if (msg_i < fd_num &&
925                 msg_reply.payload.memory.regions[msg_i].guest_phys_addr ==
926                 dev->mem->regions[region_i].guest_phys_addr) {
927                 u->postcopy_client_bases[region_i] =
928                     msg_reply.payload.memory.regions[msg_i].userspace_addr;
929                 trace_vhost_user_set_mem_table_postcopy(
930                     msg_reply.payload.memory.regions[msg_i].userspace_addr,
931                     msg.payload.memory.regions[msg_i].userspace_addr,
932                     msg_i, region_i);
933                 msg_i++;
934             }
935         }
936         if (msg_i != fd_num) {
937             error_report("%s: postcopy reply not fully consumed "
938                          "%d vs %zd",
939                          __func__, msg_i, fd_num);
940             return -1;
941         }
942 
943         /*
944          * Now we've registered this with the postcopy code, we ack to the
945          * client, because now we're in the position to be able to deal
946          * with any faults it generates.
947          */
948         /* TODO: Use this for failure cases as well with a bad value. */
949         msg.hdr.size = sizeof(msg.payload.u64);
950         msg.payload.u64 = 0; /* OK */
951         if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
952             return -1;
953         }
954     }
955 
956     return 0;
957 }
958 
959 static int vhost_user_set_mem_table(struct vhost_dev *dev,
960                                     struct vhost_memory *mem)
961 {
962     struct vhost_user *u = dev->opaque;
963     int fds[VHOST_MEMORY_BASELINE_NREGIONS];
964     size_t fd_num = 0;
965     bool do_postcopy = u->postcopy_listen && u->postcopy_fd.handler;
966     bool reply_supported = virtio_has_feature(dev->protocol_features,
967                                               VHOST_USER_PROTOCOL_F_REPLY_ACK);
968     bool config_mem_slots =
969         virtio_has_feature(dev->protocol_features,
970                            VHOST_USER_PROTOCOL_F_CONFIGURE_MEM_SLOTS);
971 
972     if (do_postcopy) {
973         /*
974          * Postcopy has enough differences that it's best done in it's own
975          * version
976          */
977         return vhost_user_set_mem_table_postcopy(dev, mem, reply_supported,
978                                                  config_mem_slots);
979     }
980 
981     VhostUserMsg msg = {
982         .hdr.flags = VHOST_USER_VERSION,
983     };
984 
985     if (reply_supported) {
986         msg.hdr.flags |= VHOST_USER_NEED_REPLY_MASK;
987     }
988 
989     if (config_mem_slots) {
990         if (vhost_user_add_remove_regions(dev, &msg, reply_supported,
991                                           false) < 0) {
992             return -1;
993         }
994     } else {
995         if (vhost_user_fill_set_mem_table_msg(u, dev, &msg, fds, &fd_num,
996                                               false) < 0) {
997             return -1;
998         }
999         if (vhost_user_write(dev, &msg, fds, fd_num) < 0) {
1000             return -1;
1001         }
1002 
1003         if (reply_supported) {
1004             return process_message_reply(dev, &msg);
1005         }
1006     }
1007 
1008     return 0;
1009 }
1010 
1011 static int vhost_user_set_vring_addr(struct vhost_dev *dev,
1012                                      struct vhost_vring_addr *addr)
1013 {
1014     VhostUserMsg msg = {
1015         .hdr.request = VHOST_USER_SET_VRING_ADDR,
1016         .hdr.flags = VHOST_USER_VERSION,
1017         .payload.addr = *addr,
1018         .hdr.size = sizeof(msg.payload.addr),
1019     };
1020 
1021     if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
1022         return -1;
1023     }
1024 
1025     return 0;
1026 }
1027 
1028 static int vhost_user_set_vring_endian(struct vhost_dev *dev,
1029                                        struct vhost_vring_state *ring)
1030 {
1031     bool cross_endian = virtio_has_feature(dev->protocol_features,
1032                                            VHOST_USER_PROTOCOL_F_CROSS_ENDIAN);
1033     VhostUserMsg msg = {
1034         .hdr.request = VHOST_USER_SET_VRING_ENDIAN,
1035         .hdr.flags = VHOST_USER_VERSION,
1036         .payload.state = *ring,
1037         .hdr.size = sizeof(msg.payload.state),
1038     };
1039 
1040     if (!cross_endian) {
1041         error_report("vhost-user trying to send unhandled ioctl");
1042         return -1;
1043     }
1044 
1045     if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
1046         return -1;
1047     }
1048 
1049     return 0;
1050 }
1051 
1052 static int vhost_set_vring(struct vhost_dev *dev,
1053                            unsigned long int request,
1054                            struct vhost_vring_state *ring)
1055 {
1056     VhostUserMsg msg = {
1057         .hdr.request = request,
1058         .hdr.flags = VHOST_USER_VERSION,
1059         .payload.state = *ring,
1060         .hdr.size = sizeof(msg.payload.state),
1061     };
1062 
1063     if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
1064         return -1;
1065     }
1066 
1067     return 0;
1068 }
1069 
1070 static int vhost_user_set_vring_num(struct vhost_dev *dev,
1071                                     struct vhost_vring_state *ring)
1072 {
1073     return vhost_set_vring(dev, VHOST_USER_SET_VRING_NUM, ring);
1074 }
1075 
1076 static void vhost_user_host_notifier_restore(struct vhost_dev *dev,
1077                                              int queue_idx)
1078 {
1079     struct vhost_user *u = dev->opaque;
1080     VhostUserHostNotifier *n = &u->user->notifier[queue_idx];
1081     VirtIODevice *vdev = dev->vdev;
1082 
1083     if (n->addr && !n->set) {
1084         virtio_queue_set_host_notifier_mr(vdev, queue_idx, &n->mr, true);
1085         n->set = true;
1086     }
1087 }
1088 
1089 static void vhost_user_host_notifier_remove(struct vhost_dev *dev,
1090                                             int queue_idx)
1091 {
1092     struct vhost_user *u = dev->opaque;
1093     VhostUserHostNotifier *n = &u->user->notifier[queue_idx];
1094     VirtIODevice *vdev = dev->vdev;
1095 
1096     if (n->addr && n->set) {
1097         virtio_queue_set_host_notifier_mr(vdev, queue_idx, &n->mr, false);
1098         n->set = false;
1099     }
1100 }
1101 
1102 static int vhost_user_set_vring_base(struct vhost_dev *dev,
1103                                      struct vhost_vring_state *ring)
1104 {
1105     vhost_user_host_notifier_restore(dev, ring->index);
1106 
1107     return vhost_set_vring(dev, VHOST_USER_SET_VRING_BASE, ring);
1108 }
1109 
1110 static int vhost_user_set_vring_enable(struct vhost_dev *dev, int enable)
1111 {
1112     int i;
1113 
1114     if (!virtio_has_feature(dev->features, VHOST_USER_F_PROTOCOL_FEATURES)) {
1115         return -1;
1116     }
1117 
1118     for (i = 0; i < dev->nvqs; ++i) {
1119         struct vhost_vring_state state = {
1120             .index = dev->vq_index + i,
1121             .num   = enable,
1122         };
1123 
1124         vhost_set_vring(dev, VHOST_USER_SET_VRING_ENABLE, &state);
1125     }
1126 
1127     return 0;
1128 }
1129 
1130 static int vhost_user_get_vring_base(struct vhost_dev *dev,
1131                                      struct vhost_vring_state *ring)
1132 {
1133     VhostUserMsg msg = {
1134         .hdr.request = VHOST_USER_GET_VRING_BASE,
1135         .hdr.flags = VHOST_USER_VERSION,
1136         .payload.state = *ring,
1137         .hdr.size = sizeof(msg.payload.state),
1138     };
1139 
1140     vhost_user_host_notifier_remove(dev, ring->index);
1141 
1142     if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
1143         return -1;
1144     }
1145 
1146     if (vhost_user_read(dev, &msg) < 0) {
1147         return -1;
1148     }
1149 
1150     if (msg.hdr.request != VHOST_USER_GET_VRING_BASE) {
1151         error_report("Received unexpected msg type. Expected %d received %d",
1152                      VHOST_USER_GET_VRING_BASE, msg.hdr.request);
1153         return -1;
1154     }
1155 
1156     if (msg.hdr.size != sizeof(msg.payload.state)) {
1157         error_report("Received bad msg size.");
1158         return -1;
1159     }
1160 
1161     *ring = msg.payload.state;
1162 
1163     return 0;
1164 }
1165 
1166 static int vhost_set_vring_file(struct vhost_dev *dev,
1167                                 VhostUserRequest request,
1168                                 struct vhost_vring_file *file)
1169 {
1170     int fds[VHOST_USER_MAX_RAM_SLOTS];
1171     size_t fd_num = 0;
1172     VhostUserMsg msg = {
1173         .hdr.request = request,
1174         .hdr.flags = VHOST_USER_VERSION,
1175         .payload.u64 = file->index & VHOST_USER_VRING_IDX_MASK,
1176         .hdr.size = sizeof(msg.payload.u64),
1177     };
1178 
1179     if (ioeventfd_enabled() && file->fd > 0) {
1180         fds[fd_num++] = file->fd;
1181     } else {
1182         msg.payload.u64 |= VHOST_USER_VRING_NOFD_MASK;
1183     }
1184 
1185     if (vhost_user_write(dev, &msg, fds, fd_num) < 0) {
1186         return -1;
1187     }
1188 
1189     return 0;
1190 }
1191 
1192 static int vhost_user_set_vring_kick(struct vhost_dev *dev,
1193                                      struct vhost_vring_file *file)
1194 {
1195     return vhost_set_vring_file(dev, VHOST_USER_SET_VRING_KICK, file);
1196 }
1197 
1198 static int vhost_user_set_vring_call(struct vhost_dev *dev,
1199                                      struct vhost_vring_file *file)
1200 {
1201     return vhost_set_vring_file(dev, VHOST_USER_SET_VRING_CALL, file);
1202 }
1203 
1204 static int vhost_user_set_u64(struct vhost_dev *dev, int request, uint64_t u64)
1205 {
1206     VhostUserMsg msg = {
1207         .hdr.request = request,
1208         .hdr.flags = VHOST_USER_VERSION,
1209         .payload.u64 = u64,
1210         .hdr.size = sizeof(msg.payload.u64),
1211     };
1212 
1213     if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
1214         return -1;
1215     }
1216 
1217     return 0;
1218 }
1219 
1220 static int vhost_user_set_features(struct vhost_dev *dev,
1221                                    uint64_t features)
1222 {
1223     return vhost_user_set_u64(dev, VHOST_USER_SET_FEATURES, features);
1224 }
1225 
1226 static int vhost_user_set_protocol_features(struct vhost_dev *dev,
1227                                             uint64_t features)
1228 {
1229     return vhost_user_set_u64(dev, VHOST_USER_SET_PROTOCOL_FEATURES, features);
1230 }
1231 
1232 static int vhost_user_get_u64(struct vhost_dev *dev, int request, uint64_t *u64)
1233 {
1234     VhostUserMsg msg = {
1235         .hdr.request = request,
1236         .hdr.flags = VHOST_USER_VERSION,
1237     };
1238 
1239     if (vhost_user_one_time_request(request) && dev->vq_index != 0) {
1240         return 0;
1241     }
1242 
1243     if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
1244         return -1;
1245     }
1246 
1247     if (vhost_user_read(dev, &msg) < 0) {
1248         return -1;
1249     }
1250 
1251     if (msg.hdr.request != request) {
1252         error_report("Received unexpected msg type. Expected %d received %d",
1253                      request, msg.hdr.request);
1254         return -1;
1255     }
1256 
1257     if (msg.hdr.size != sizeof(msg.payload.u64)) {
1258         error_report("Received bad msg size.");
1259         return -1;
1260     }
1261 
1262     *u64 = msg.payload.u64;
1263 
1264     return 0;
1265 }
1266 
1267 static int vhost_user_get_features(struct vhost_dev *dev, uint64_t *features)
1268 {
1269     return vhost_user_get_u64(dev, VHOST_USER_GET_FEATURES, features);
1270 }
1271 
1272 static int vhost_user_set_owner(struct vhost_dev *dev)
1273 {
1274     VhostUserMsg msg = {
1275         .hdr.request = VHOST_USER_SET_OWNER,
1276         .hdr.flags = VHOST_USER_VERSION,
1277     };
1278 
1279     if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
1280         return -1;
1281     }
1282 
1283     return 0;
1284 }
1285 
1286 static int vhost_user_get_max_memslots(struct vhost_dev *dev,
1287                                        uint64_t *max_memslots)
1288 {
1289     uint64_t backend_max_memslots;
1290     int err;
1291 
1292     err = vhost_user_get_u64(dev, VHOST_USER_GET_MAX_MEM_SLOTS,
1293                              &backend_max_memslots);
1294     if (err < 0) {
1295         return err;
1296     }
1297 
1298     *max_memslots = backend_max_memslots;
1299 
1300     return 0;
1301 }
1302 
1303 static int vhost_user_reset_device(struct vhost_dev *dev)
1304 {
1305     VhostUserMsg msg = {
1306         .hdr.flags = VHOST_USER_VERSION,
1307     };
1308 
1309     msg.hdr.request = virtio_has_feature(dev->protocol_features,
1310                                          VHOST_USER_PROTOCOL_F_RESET_DEVICE)
1311         ? VHOST_USER_RESET_DEVICE
1312         : VHOST_USER_RESET_OWNER;
1313 
1314     if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
1315         return -1;
1316     }
1317 
1318     return 0;
1319 }
1320 
1321 static int vhost_user_slave_handle_config_change(struct vhost_dev *dev)
1322 {
1323     int ret = -1;
1324 
1325     if (!dev->config_ops) {
1326         return -1;
1327     }
1328 
1329     if (dev->config_ops->vhost_dev_config_notifier) {
1330         ret = dev->config_ops->vhost_dev_config_notifier(dev);
1331     }
1332 
1333     return ret;
1334 }
1335 
1336 static int vhost_user_slave_handle_vring_host_notifier(struct vhost_dev *dev,
1337                                                        VhostUserVringArea *area,
1338                                                        int fd)
1339 {
1340     int queue_idx = area->u64 & VHOST_USER_VRING_IDX_MASK;
1341     size_t page_size = qemu_real_host_page_size;
1342     struct vhost_user *u = dev->opaque;
1343     VhostUserState *user = u->user;
1344     VirtIODevice *vdev = dev->vdev;
1345     VhostUserHostNotifier *n;
1346     void *addr;
1347     char *name;
1348 
1349     if (!virtio_has_feature(dev->protocol_features,
1350                             VHOST_USER_PROTOCOL_F_HOST_NOTIFIER) ||
1351         vdev == NULL || queue_idx >= virtio_get_num_queues(vdev)) {
1352         return -1;
1353     }
1354 
1355     n = &user->notifier[queue_idx];
1356 
1357     if (n->addr) {
1358         virtio_queue_set_host_notifier_mr(vdev, queue_idx, &n->mr, false);
1359         object_unparent(OBJECT(&n->mr));
1360         munmap(n->addr, page_size);
1361         n->addr = NULL;
1362     }
1363 
1364     if (area->u64 & VHOST_USER_VRING_NOFD_MASK) {
1365         return 0;
1366     }
1367 
1368     /* Sanity check. */
1369     if (area->size != page_size) {
1370         return -1;
1371     }
1372 
1373     addr = mmap(NULL, page_size, PROT_READ | PROT_WRITE, MAP_SHARED,
1374                 fd, area->offset);
1375     if (addr == MAP_FAILED) {
1376         return -1;
1377     }
1378 
1379     name = g_strdup_printf("vhost-user/host-notifier@%p mmaps[%d]",
1380                            user, queue_idx);
1381     memory_region_init_ram_device_ptr(&n->mr, OBJECT(vdev), name,
1382                                       page_size, addr);
1383     g_free(name);
1384 
1385     if (virtio_queue_set_host_notifier_mr(vdev, queue_idx, &n->mr, true)) {
1386         munmap(addr, page_size);
1387         return -1;
1388     }
1389 
1390     n->addr = addr;
1391     n->set = true;
1392 
1393     return 0;
1394 }
1395 
1396 static void slave_read(void *opaque)
1397 {
1398     struct vhost_dev *dev = opaque;
1399     struct vhost_user *u = dev->opaque;
1400     VhostUserHeader hdr = { 0, };
1401     VhostUserPayload payload = { 0, };
1402     int size, ret = 0;
1403     struct iovec iov;
1404     struct msghdr msgh;
1405     int fd[VHOST_USER_SLAVE_MAX_FDS];
1406     char control[CMSG_SPACE(sizeof(fd))];
1407     struct cmsghdr *cmsg;
1408     int i, fdsize = 0;
1409 
1410     memset(&msgh, 0, sizeof(msgh));
1411     msgh.msg_iov = &iov;
1412     msgh.msg_iovlen = 1;
1413     msgh.msg_control = control;
1414     msgh.msg_controllen = sizeof(control);
1415 
1416     memset(fd, -1, sizeof(fd));
1417 
1418     /* Read header */
1419     iov.iov_base = &hdr;
1420     iov.iov_len = VHOST_USER_HDR_SIZE;
1421 
1422     do {
1423         size = recvmsg(u->slave_fd, &msgh, 0);
1424     } while (size < 0 && (errno == EINTR || errno == EAGAIN));
1425 
1426     if (size != VHOST_USER_HDR_SIZE) {
1427         error_report("Failed to read from slave.");
1428         goto err;
1429     }
1430 
1431     if (msgh.msg_flags & MSG_CTRUNC) {
1432         error_report("Truncated message.");
1433         goto err;
1434     }
1435 
1436     for (cmsg = CMSG_FIRSTHDR(&msgh); cmsg != NULL;
1437          cmsg = CMSG_NXTHDR(&msgh, cmsg)) {
1438             if (cmsg->cmsg_level == SOL_SOCKET &&
1439                 cmsg->cmsg_type == SCM_RIGHTS) {
1440                     fdsize = cmsg->cmsg_len - CMSG_LEN(0);
1441                     memcpy(fd, CMSG_DATA(cmsg), fdsize);
1442                     break;
1443             }
1444     }
1445 
1446     if (hdr.size > VHOST_USER_PAYLOAD_SIZE) {
1447         error_report("Failed to read msg header."
1448                 " Size %d exceeds the maximum %zu.", hdr.size,
1449                 VHOST_USER_PAYLOAD_SIZE);
1450         goto err;
1451     }
1452 
1453     /* Read payload */
1454     do {
1455         size = read(u->slave_fd, &payload, hdr.size);
1456     } while (size < 0 && (errno == EINTR || errno == EAGAIN));
1457 
1458     if (size != hdr.size) {
1459         error_report("Failed to read payload from slave.");
1460         goto err;
1461     }
1462 
1463     switch (hdr.request) {
1464     case VHOST_USER_SLAVE_IOTLB_MSG:
1465         ret = vhost_backend_handle_iotlb_msg(dev, &payload.iotlb);
1466         break;
1467     case VHOST_USER_SLAVE_CONFIG_CHANGE_MSG :
1468         ret = vhost_user_slave_handle_config_change(dev);
1469         break;
1470     case VHOST_USER_SLAVE_VRING_HOST_NOTIFIER_MSG:
1471         ret = vhost_user_slave_handle_vring_host_notifier(dev, &payload.area,
1472                                                           fd[0]);
1473         break;
1474     default:
1475         error_report("Received unexpected msg type: %d.", hdr.request);
1476         ret = -EINVAL;
1477     }
1478 
1479     /* Close the remaining file descriptors. */
1480     for (i = 0; i < fdsize; i++) {
1481         if (fd[i] != -1) {
1482             close(fd[i]);
1483         }
1484     }
1485 
1486     /*
1487      * REPLY_ACK feature handling. Other reply types has to be managed
1488      * directly in their request handlers.
1489      */
1490     if (hdr.flags & VHOST_USER_NEED_REPLY_MASK) {
1491         struct iovec iovec[2];
1492 
1493 
1494         hdr.flags &= ~VHOST_USER_NEED_REPLY_MASK;
1495         hdr.flags |= VHOST_USER_REPLY_MASK;
1496 
1497         payload.u64 = !!ret;
1498         hdr.size = sizeof(payload.u64);
1499 
1500         iovec[0].iov_base = &hdr;
1501         iovec[0].iov_len = VHOST_USER_HDR_SIZE;
1502         iovec[1].iov_base = &payload;
1503         iovec[1].iov_len = hdr.size;
1504 
1505         do {
1506             size = writev(u->slave_fd, iovec, ARRAY_SIZE(iovec));
1507         } while (size < 0 && (errno == EINTR || errno == EAGAIN));
1508 
1509         if (size != VHOST_USER_HDR_SIZE + hdr.size) {
1510             error_report("Failed to send msg reply to slave.");
1511             goto err;
1512         }
1513     }
1514 
1515     return;
1516 
1517 err:
1518     qemu_set_fd_handler(u->slave_fd, NULL, NULL, NULL);
1519     close(u->slave_fd);
1520     u->slave_fd = -1;
1521     for (i = 0; i < fdsize; i++) {
1522         if (fd[i] != -1) {
1523             close(fd[i]);
1524         }
1525     }
1526     return;
1527 }
1528 
1529 static int vhost_setup_slave_channel(struct vhost_dev *dev)
1530 {
1531     VhostUserMsg msg = {
1532         .hdr.request = VHOST_USER_SET_SLAVE_REQ_FD,
1533         .hdr.flags = VHOST_USER_VERSION,
1534     };
1535     struct vhost_user *u = dev->opaque;
1536     int sv[2], ret = 0;
1537     bool reply_supported = virtio_has_feature(dev->protocol_features,
1538                                               VHOST_USER_PROTOCOL_F_REPLY_ACK);
1539 
1540     if (!virtio_has_feature(dev->protocol_features,
1541                             VHOST_USER_PROTOCOL_F_SLAVE_REQ)) {
1542         return 0;
1543     }
1544 
1545     if (socketpair(PF_UNIX, SOCK_STREAM, 0, sv) == -1) {
1546         error_report("socketpair() failed");
1547         return -1;
1548     }
1549 
1550     u->slave_fd = sv[0];
1551     qemu_set_fd_handler(u->slave_fd, slave_read, NULL, dev);
1552 
1553     if (reply_supported) {
1554         msg.hdr.flags |= VHOST_USER_NEED_REPLY_MASK;
1555     }
1556 
1557     ret = vhost_user_write(dev, &msg, &sv[1], 1);
1558     if (ret) {
1559         goto out;
1560     }
1561 
1562     if (reply_supported) {
1563         ret = process_message_reply(dev, &msg);
1564     }
1565 
1566 out:
1567     close(sv[1]);
1568     if (ret) {
1569         qemu_set_fd_handler(u->slave_fd, NULL, NULL, NULL);
1570         close(u->slave_fd);
1571         u->slave_fd = -1;
1572     }
1573 
1574     return ret;
1575 }
1576 
1577 #ifdef CONFIG_LINUX
1578 /*
1579  * Called back from the postcopy fault thread when a fault is received on our
1580  * ufd.
1581  * TODO: This is Linux specific
1582  */
1583 static int vhost_user_postcopy_fault_handler(struct PostCopyFD *pcfd,
1584                                              void *ufd)
1585 {
1586     struct vhost_dev *dev = pcfd->data;
1587     struct vhost_user *u = dev->opaque;
1588     struct uffd_msg *msg = ufd;
1589     uint64_t faultaddr = msg->arg.pagefault.address;
1590     RAMBlock *rb = NULL;
1591     uint64_t rb_offset;
1592     int i;
1593 
1594     trace_vhost_user_postcopy_fault_handler(pcfd->idstr, faultaddr,
1595                                             dev->mem->nregions);
1596     for (i = 0; i < MIN(dev->mem->nregions, u->region_rb_len); i++) {
1597         trace_vhost_user_postcopy_fault_handler_loop(i,
1598                 u->postcopy_client_bases[i], dev->mem->regions[i].memory_size);
1599         if (faultaddr >= u->postcopy_client_bases[i]) {
1600             /* Ofset of the fault address in the vhost region */
1601             uint64_t region_offset = faultaddr - u->postcopy_client_bases[i];
1602             if (region_offset < dev->mem->regions[i].memory_size) {
1603                 rb_offset = region_offset + u->region_rb_offset[i];
1604                 trace_vhost_user_postcopy_fault_handler_found(i,
1605                         region_offset, rb_offset);
1606                 rb = u->region_rb[i];
1607                 return postcopy_request_shared_page(pcfd, rb, faultaddr,
1608                                                     rb_offset);
1609             }
1610         }
1611     }
1612     error_report("%s: Failed to find region for fault %" PRIx64,
1613                  __func__, faultaddr);
1614     return -1;
1615 }
1616 
1617 static int vhost_user_postcopy_waker(struct PostCopyFD *pcfd, RAMBlock *rb,
1618                                      uint64_t offset)
1619 {
1620     struct vhost_dev *dev = pcfd->data;
1621     struct vhost_user *u = dev->opaque;
1622     int i;
1623 
1624     trace_vhost_user_postcopy_waker(qemu_ram_get_idstr(rb), offset);
1625 
1626     if (!u) {
1627         return 0;
1628     }
1629     /* Translate the offset into an address in the clients address space */
1630     for (i = 0; i < MIN(dev->mem->nregions, u->region_rb_len); i++) {
1631         if (u->region_rb[i] == rb &&
1632             offset >= u->region_rb_offset[i] &&
1633             offset < (u->region_rb_offset[i] +
1634                       dev->mem->regions[i].memory_size)) {
1635             uint64_t client_addr = (offset - u->region_rb_offset[i]) +
1636                                    u->postcopy_client_bases[i];
1637             trace_vhost_user_postcopy_waker_found(client_addr);
1638             return postcopy_wake_shared(pcfd, client_addr, rb);
1639         }
1640     }
1641 
1642     trace_vhost_user_postcopy_waker_nomatch(qemu_ram_get_idstr(rb), offset);
1643     return 0;
1644 }
1645 #endif
1646 
1647 /*
1648  * Called at the start of an inbound postcopy on reception of the
1649  * 'advise' command.
1650  */
1651 static int vhost_user_postcopy_advise(struct vhost_dev *dev, Error **errp)
1652 {
1653 #ifdef CONFIG_LINUX
1654     struct vhost_user *u = dev->opaque;
1655     CharBackend *chr = u->user->chr;
1656     int ufd;
1657     VhostUserMsg msg = {
1658         .hdr.request = VHOST_USER_POSTCOPY_ADVISE,
1659         .hdr.flags = VHOST_USER_VERSION,
1660     };
1661 
1662     if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
1663         error_setg(errp, "Failed to send postcopy_advise to vhost");
1664         return -1;
1665     }
1666 
1667     if (vhost_user_read(dev, &msg) < 0) {
1668         error_setg(errp, "Failed to get postcopy_advise reply from vhost");
1669         return -1;
1670     }
1671 
1672     if (msg.hdr.request != VHOST_USER_POSTCOPY_ADVISE) {
1673         error_setg(errp, "Unexpected msg type. Expected %d received %d",
1674                      VHOST_USER_POSTCOPY_ADVISE, msg.hdr.request);
1675         return -1;
1676     }
1677 
1678     if (msg.hdr.size) {
1679         error_setg(errp, "Received bad msg size.");
1680         return -1;
1681     }
1682     ufd = qemu_chr_fe_get_msgfd(chr);
1683     if (ufd < 0) {
1684         error_setg(errp, "%s: Failed to get ufd", __func__);
1685         return -1;
1686     }
1687     qemu_set_nonblock(ufd);
1688 
1689     /* register ufd with userfault thread */
1690     u->postcopy_fd.fd = ufd;
1691     u->postcopy_fd.data = dev;
1692     u->postcopy_fd.handler = vhost_user_postcopy_fault_handler;
1693     u->postcopy_fd.waker = vhost_user_postcopy_waker;
1694     u->postcopy_fd.idstr = "vhost-user"; /* Need to find unique name */
1695     postcopy_register_shared_ufd(&u->postcopy_fd);
1696     return 0;
1697 #else
1698     error_setg(errp, "Postcopy not supported on non-Linux systems");
1699     return -1;
1700 #endif
1701 }
1702 
1703 /*
1704  * Called at the switch to postcopy on reception of the 'listen' command.
1705  */
1706 static int vhost_user_postcopy_listen(struct vhost_dev *dev, Error **errp)
1707 {
1708     struct vhost_user *u = dev->opaque;
1709     int ret;
1710     VhostUserMsg msg = {
1711         .hdr.request = VHOST_USER_POSTCOPY_LISTEN,
1712         .hdr.flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
1713     };
1714     u->postcopy_listen = true;
1715     trace_vhost_user_postcopy_listen();
1716     if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
1717         error_setg(errp, "Failed to send postcopy_listen to vhost");
1718         return -1;
1719     }
1720 
1721     ret = process_message_reply(dev, &msg);
1722     if (ret) {
1723         error_setg(errp, "Failed to receive reply to postcopy_listen");
1724         return ret;
1725     }
1726 
1727     return 0;
1728 }
1729 
1730 /*
1731  * Called at the end of postcopy
1732  */
1733 static int vhost_user_postcopy_end(struct vhost_dev *dev, Error **errp)
1734 {
1735     VhostUserMsg msg = {
1736         .hdr.request = VHOST_USER_POSTCOPY_END,
1737         .hdr.flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
1738     };
1739     int ret;
1740     struct vhost_user *u = dev->opaque;
1741 
1742     trace_vhost_user_postcopy_end_entry();
1743     if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
1744         error_setg(errp, "Failed to send postcopy_end to vhost");
1745         return -1;
1746     }
1747 
1748     ret = process_message_reply(dev, &msg);
1749     if (ret) {
1750         error_setg(errp, "Failed to receive reply to postcopy_end");
1751         return ret;
1752     }
1753     postcopy_unregister_shared_ufd(&u->postcopy_fd);
1754     close(u->postcopy_fd.fd);
1755     u->postcopy_fd.handler = NULL;
1756 
1757     trace_vhost_user_postcopy_end_exit();
1758 
1759     return 0;
1760 }
1761 
1762 static int vhost_user_postcopy_notifier(NotifierWithReturn *notifier,
1763                                         void *opaque)
1764 {
1765     struct PostcopyNotifyData *pnd = opaque;
1766     struct vhost_user *u = container_of(notifier, struct vhost_user,
1767                                          postcopy_notifier);
1768     struct vhost_dev *dev = u->dev;
1769 
1770     switch (pnd->reason) {
1771     case POSTCOPY_NOTIFY_PROBE:
1772         if (!virtio_has_feature(dev->protocol_features,
1773                                 VHOST_USER_PROTOCOL_F_PAGEFAULT)) {
1774             /* TODO: Get the device name into this error somehow */
1775             error_setg(pnd->errp,
1776                        "vhost-user backend not capable of postcopy");
1777             return -ENOENT;
1778         }
1779         break;
1780 
1781     case POSTCOPY_NOTIFY_INBOUND_ADVISE:
1782         return vhost_user_postcopy_advise(dev, pnd->errp);
1783 
1784     case POSTCOPY_NOTIFY_INBOUND_LISTEN:
1785         return vhost_user_postcopy_listen(dev, pnd->errp);
1786 
1787     case POSTCOPY_NOTIFY_INBOUND_END:
1788         return vhost_user_postcopy_end(dev, pnd->errp);
1789 
1790     default:
1791         /* We ignore notifications we don't know */
1792         break;
1793     }
1794 
1795     return 0;
1796 }
1797 
1798 static int vhost_user_backend_init(struct vhost_dev *dev, void *opaque)
1799 {
1800     uint64_t features, protocol_features, ram_slots;
1801     struct vhost_user *u;
1802     int err;
1803 
1804     assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);
1805 
1806     u = g_new0(struct vhost_user, 1);
1807     u->user = opaque;
1808     u->slave_fd = -1;
1809     u->dev = dev;
1810     dev->opaque = u;
1811 
1812     err = vhost_user_get_features(dev, &features);
1813     if (err < 0) {
1814         return err;
1815     }
1816 
1817     if (virtio_has_feature(features, VHOST_USER_F_PROTOCOL_FEATURES)) {
1818         dev->backend_features |= 1ULL << VHOST_USER_F_PROTOCOL_FEATURES;
1819 
1820         err = vhost_user_get_u64(dev, VHOST_USER_GET_PROTOCOL_FEATURES,
1821                                  &protocol_features);
1822         if (err < 0) {
1823             return err;
1824         }
1825 
1826         dev->protocol_features =
1827             protocol_features & VHOST_USER_PROTOCOL_FEATURE_MASK;
1828 
1829         if (!dev->config_ops || !dev->config_ops->vhost_dev_config_notifier) {
1830             /* Don't acknowledge CONFIG feature if device doesn't support it */
1831             dev->protocol_features &= ~(1ULL << VHOST_USER_PROTOCOL_F_CONFIG);
1832         } else if (!(protocol_features &
1833                     (1ULL << VHOST_USER_PROTOCOL_F_CONFIG))) {
1834             error_report("Device expects VHOST_USER_PROTOCOL_F_CONFIG "
1835                     "but backend does not support it.");
1836             return -1;
1837         }
1838 
1839         err = vhost_user_set_protocol_features(dev, dev->protocol_features);
1840         if (err < 0) {
1841             return err;
1842         }
1843 
1844         /* query the max queues we support if backend supports Multiple Queue */
1845         if (dev->protocol_features & (1ULL << VHOST_USER_PROTOCOL_F_MQ)) {
1846             err = vhost_user_get_u64(dev, VHOST_USER_GET_QUEUE_NUM,
1847                                      &dev->max_queues);
1848             if (err < 0) {
1849                 return err;
1850             }
1851         }
1852 
1853         if (virtio_has_feature(features, VIRTIO_F_IOMMU_PLATFORM) &&
1854                 !(virtio_has_feature(dev->protocol_features,
1855                     VHOST_USER_PROTOCOL_F_SLAVE_REQ) &&
1856                  virtio_has_feature(dev->protocol_features,
1857                     VHOST_USER_PROTOCOL_F_REPLY_ACK))) {
1858             error_report("IOMMU support requires reply-ack and "
1859                          "slave-req protocol features.");
1860             return -1;
1861         }
1862 
1863         /* get max memory regions if backend supports configurable RAM slots */
1864         if (!virtio_has_feature(dev->protocol_features,
1865                                 VHOST_USER_PROTOCOL_F_CONFIGURE_MEM_SLOTS)) {
1866             u->user->memory_slots = VHOST_MEMORY_BASELINE_NREGIONS;
1867         } else {
1868             err = vhost_user_get_max_memslots(dev, &ram_slots);
1869             if (err < 0) {
1870                 return err;
1871             }
1872 
1873             if (ram_slots < u->user->memory_slots) {
1874                 error_report("The backend specified a max ram slots limit "
1875                              "of %" PRIu64", when the prior validated limit was %d. "
1876                              "This limit should never decrease.", ram_slots,
1877                              u->user->memory_slots);
1878                 return -1;
1879             }
1880 
1881             u->user->memory_slots = MIN(ram_slots, VHOST_USER_MAX_RAM_SLOTS);
1882         }
1883     }
1884 
1885     if (dev->migration_blocker == NULL &&
1886         !virtio_has_feature(dev->protocol_features,
1887                             VHOST_USER_PROTOCOL_F_LOG_SHMFD)) {
1888         error_setg(&dev->migration_blocker,
1889                    "Migration disabled: vhost-user backend lacks "
1890                    "VHOST_USER_PROTOCOL_F_LOG_SHMFD feature.");
1891     }
1892 
1893     if (dev->vq_index == 0) {
1894         err = vhost_setup_slave_channel(dev);
1895         if (err < 0) {
1896             return err;
1897         }
1898     }
1899 
1900     u->postcopy_notifier.notify = vhost_user_postcopy_notifier;
1901     postcopy_add_notifier(&u->postcopy_notifier);
1902 
1903     return 0;
1904 }
1905 
1906 static int vhost_user_backend_cleanup(struct vhost_dev *dev)
1907 {
1908     struct vhost_user *u;
1909 
1910     assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);
1911 
1912     u = dev->opaque;
1913     if (u->postcopy_notifier.notify) {
1914         postcopy_remove_notifier(&u->postcopy_notifier);
1915         u->postcopy_notifier.notify = NULL;
1916     }
1917     u->postcopy_listen = false;
1918     if (u->postcopy_fd.handler) {
1919         postcopy_unregister_shared_ufd(&u->postcopy_fd);
1920         close(u->postcopy_fd.fd);
1921         u->postcopy_fd.handler = NULL;
1922     }
1923     if (u->slave_fd >= 0) {
1924         qemu_set_fd_handler(u->slave_fd, NULL, NULL, NULL);
1925         close(u->slave_fd);
1926         u->slave_fd = -1;
1927     }
1928     g_free(u->region_rb);
1929     u->region_rb = NULL;
1930     g_free(u->region_rb_offset);
1931     u->region_rb_offset = NULL;
1932     u->region_rb_len = 0;
1933     g_free(u);
1934     dev->opaque = 0;
1935 
1936     return 0;
1937 }
1938 
1939 static int vhost_user_get_vq_index(struct vhost_dev *dev, int idx)
1940 {
1941     assert(idx >= dev->vq_index && idx < dev->vq_index + dev->nvqs);
1942 
1943     return idx;
1944 }
1945 
1946 static int vhost_user_memslots_limit(struct vhost_dev *dev)
1947 {
1948     struct vhost_user *u = dev->opaque;
1949 
1950     return u->user->memory_slots;
1951 }
1952 
1953 static bool vhost_user_requires_shm_log(struct vhost_dev *dev)
1954 {
1955     assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);
1956 
1957     return virtio_has_feature(dev->protocol_features,
1958                               VHOST_USER_PROTOCOL_F_LOG_SHMFD);
1959 }
1960 
1961 static int vhost_user_migration_done(struct vhost_dev *dev, char* mac_addr)
1962 {
1963     VhostUserMsg msg = { };
1964 
1965     assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);
1966 
1967     /* If guest supports GUEST_ANNOUNCE do nothing */
1968     if (virtio_has_feature(dev->acked_features, VIRTIO_NET_F_GUEST_ANNOUNCE)) {
1969         return 0;
1970     }
1971 
1972     /* if backend supports VHOST_USER_PROTOCOL_F_RARP ask it to send the RARP */
1973     if (virtio_has_feature(dev->protocol_features,
1974                            VHOST_USER_PROTOCOL_F_RARP)) {
1975         msg.hdr.request = VHOST_USER_SEND_RARP;
1976         msg.hdr.flags = VHOST_USER_VERSION;
1977         memcpy((char *)&msg.payload.u64, mac_addr, 6);
1978         msg.hdr.size = sizeof(msg.payload.u64);
1979 
1980         return vhost_user_write(dev, &msg, NULL, 0);
1981     }
1982     return -1;
1983 }
1984 
1985 static bool vhost_user_can_merge(struct vhost_dev *dev,
1986                                  uint64_t start1, uint64_t size1,
1987                                  uint64_t start2, uint64_t size2)
1988 {
1989     ram_addr_t offset;
1990     int mfd, rfd;
1991 
1992     (void)vhost_user_get_mr_data(start1, &offset, &mfd);
1993     (void)vhost_user_get_mr_data(start2, &offset, &rfd);
1994 
1995     return mfd == rfd;
1996 }
1997 
1998 static int vhost_user_net_set_mtu(struct vhost_dev *dev, uint16_t mtu)
1999 {
2000     VhostUserMsg msg;
2001     bool reply_supported = virtio_has_feature(dev->protocol_features,
2002                                               VHOST_USER_PROTOCOL_F_REPLY_ACK);
2003 
2004     if (!(dev->protocol_features & (1ULL << VHOST_USER_PROTOCOL_F_NET_MTU))) {
2005         return 0;
2006     }
2007 
2008     msg.hdr.request = VHOST_USER_NET_SET_MTU;
2009     msg.payload.u64 = mtu;
2010     msg.hdr.size = sizeof(msg.payload.u64);
2011     msg.hdr.flags = VHOST_USER_VERSION;
2012     if (reply_supported) {
2013         msg.hdr.flags |= VHOST_USER_NEED_REPLY_MASK;
2014     }
2015 
2016     if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
2017         return -1;
2018     }
2019 
2020     /* If reply_ack supported, slave has to ack specified MTU is valid */
2021     if (reply_supported) {
2022         return process_message_reply(dev, &msg);
2023     }
2024 
2025     return 0;
2026 }
2027 
2028 static int vhost_user_send_device_iotlb_msg(struct vhost_dev *dev,
2029                                             struct vhost_iotlb_msg *imsg)
2030 {
2031     VhostUserMsg msg = {
2032         .hdr.request = VHOST_USER_IOTLB_MSG,
2033         .hdr.size = sizeof(msg.payload.iotlb),
2034         .hdr.flags = VHOST_USER_VERSION | VHOST_USER_NEED_REPLY_MASK,
2035         .payload.iotlb = *imsg,
2036     };
2037 
2038     if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
2039         return -EFAULT;
2040     }
2041 
2042     return process_message_reply(dev, &msg);
2043 }
2044 
2045 
/*
 * Enable/disable IOTLB callbacks.  Intentionally empty: the receive
 * channel is not dedicated to IOTLB messages, so there is nothing to
 * toggle.
 */
static void vhost_user_set_iotlb_callback(struct vhost_dev *dev, int enabled)
{
    (void)dev;
    (void)enabled;
}
2050 
2051 static int vhost_user_get_config(struct vhost_dev *dev, uint8_t *config,
2052                                  uint32_t config_len)
2053 {
2054     VhostUserMsg msg = {
2055         .hdr.request = VHOST_USER_GET_CONFIG,
2056         .hdr.flags = VHOST_USER_VERSION,
2057         .hdr.size = VHOST_USER_CONFIG_HDR_SIZE + config_len,
2058     };
2059 
2060     if (!virtio_has_feature(dev->protocol_features,
2061                 VHOST_USER_PROTOCOL_F_CONFIG)) {
2062         return -1;
2063     }
2064 
2065     if (config_len > VHOST_USER_MAX_CONFIG_SIZE) {
2066         return -1;
2067     }
2068 
2069     msg.payload.config.offset = 0;
2070     msg.payload.config.size = config_len;
2071     if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
2072         return -1;
2073     }
2074 
2075     if (vhost_user_read(dev, &msg) < 0) {
2076         return -1;
2077     }
2078 
2079     if (msg.hdr.request != VHOST_USER_GET_CONFIG) {
2080         error_report("Received unexpected msg type. Expected %d received %d",
2081                      VHOST_USER_GET_CONFIG, msg.hdr.request);
2082         return -1;
2083     }
2084 
2085     if (msg.hdr.size != VHOST_USER_CONFIG_HDR_SIZE + config_len) {
2086         error_report("Received bad msg size.");
2087         return -1;
2088     }
2089 
2090     memcpy(config, msg.payload.config.region, config_len);
2091 
2092     return 0;
2093 }
2094 
2095 static int vhost_user_set_config(struct vhost_dev *dev, const uint8_t *data,
2096                                  uint32_t offset, uint32_t size, uint32_t flags)
2097 {
2098     uint8_t *p;
2099     bool reply_supported = virtio_has_feature(dev->protocol_features,
2100                                               VHOST_USER_PROTOCOL_F_REPLY_ACK);
2101 
2102     VhostUserMsg msg = {
2103         .hdr.request = VHOST_USER_SET_CONFIG,
2104         .hdr.flags = VHOST_USER_VERSION,
2105         .hdr.size = VHOST_USER_CONFIG_HDR_SIZE + size,
2106     };
2107 
2108     if (!virtio_has_feature(dev->protocol_features,
2109                 VHOST_USER_PROTOCOL_F_CONFIG)) {
2110         return -1;
2111     }
2112 
2113     if (reply_supported) {
2114         msg.hdr.flags |= VHOST_USER_NEED_REPLY_MASK;
2115     }
2116 
2117     if (size > VHOST_USER_MAX_CONFIG_SIZE) {
2118         return -1;
2119     }
2120 
2121     msg.payload.config.offset = offset,
2122     msg.payload.config.size = size,
2123     msg.payload.config.flags = flags,
2124     p = msg.payload.config.region;
2125     memcpy(p, data, size);
2126 
2127     if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
2128         return -1;
2129     }
2130 
2131     if (reply_supported) {
2132         return process_message_reply(dev, &msg);
2133     }
2134 
2135     return 0;
2136 }
2137 
2138 static int vhost_user_crypto_create_session(struct vhost_dev *dev,
2139                                             void *session_info,
2140                                             uint64_t *session_id)
2141 {
2142     bool crypto_session = virtio_has_feature(dev->protocol_features,
2143                                        VHOST_USER_PROTOCOL_F_CRYPTO_SESSION);
2144     CryptoDevBackendSymSessionInfo *sess_info = session_info;
2145     VhostUserMsg msg = {
2146         .hdr.request = VHOST_USER_CREATE_CRYPTO_SESSION,
2147         .hdr.flags = VHOST_USER_VERSION,
2148         .hdr.size = sizeof(msg.payload.session),
2149     };
2150 
2151     assert(dev->vhost_ops->backend_type == VHOST_BACKEND_TYPE_USER);
2152 
2153     if (!crypto_session) {
2154         error_report("vhost-user trying to send unhandled ioctl");
2155         return -1;
2156     }
2157 
2158     memcpy(&msg.payload.session.session_setup_data, sess_info,
2159               sizeof(CryptoDevBackendSymSessionInfo));
2160     if (sess_info->key_len) {
2161         memcpy(&msg.payload.session.key, sess_info->cipher_key,
2162                sess_info->key_len);
2163     }
2164     if (sess_info->auth_key_len > 0) {
2165         memcpy(&msg.payload.session.auth_key, sess_info->auth_key,
2166                sess_info->auth_key_len);
2167     }
2168     if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
2169         error_report("vhost_user_write() return -1, create session failed");
2170         return -1;
2171     }
2172 
2173     if (vhost_user_read(dev, &msg) < 0) {
2174         error_report("vhost_user_read() return -1, create session failed");
2175         return -1;
2176     }
2177 
2178     if (msg.hdr.request != VHOST_USER_CREATE_CRYPTO_SESSION) {
2179         error_report("Received unexpected msg type. Expected %d received %d",
2180                      VHOST_USER_CREATE_CRYPTO_SESSION, msg.hdr.request);
2181         return -1;
2182     }
2183 
2184     if (msg.hdr.size != sizeof(msg.payload.session)) {
2185         error_report("Received bad msg size.");
2186         return -1;
2187     }
2188 
2189     if (msg.payload.session.session_id < 0) {
2190         error_report("Bad session id: %" PRId64 "",
2191                               msg.payload.session.session_id);
2192         return -1;
2193     }
2194     *session_id = msg.payload.session.session_id;
2195 
2196     return 0;
2197 }
2198 
2199 static int
2200 vhost_user_crypto_close_session(struct vhost_dev *dev, uint64_t session_id)
2201 {
2202     bool crypto_session = virtio_has_feature(dev->protocol_features,
2203                                        VHOST_USER_PROTOCOL_F_CRYPTO_SESSION);
2204     VhostUserMsg msg = {
2205         .hdr.request = VHOST_USER_CLOSE_CRYPTO_SESSION,
2206         .hdr.flags = VHOST_USER_VERSION,
2207         .hdr.size = sizeof(msg.payload.u64),
2208     };
2209     msg.payload.u64 = session_id;
2210 
2211     if (!crypto_session) {
2212         error_report("vhost-user trying to send unhandled ioctl");
2213         return -1;
2214     }
2215 
2216     if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
2217         error_report("vhost_user_write() return -1, close session failed");
2218         return -1;
2219     }
2220 
2221     return 0;
2222 }
2223 
2224 static bool vhost_user_mem_section_filter(struct vhost_dev *dev,
2225                                           MemoryRegionSection *section)
2226 {
2227     bool result;
2228 
2229     result = memory_region_get_fd(section->mr) >= 0;
2230 
2231     return result;
2232 }
2233 
2234 static int vhost_user_get_inflight_fd(struct vhost_dev *dev,
2235                                       uint16_t queue_size,
2236                                       struct vhost_inflight *inflight)
2237 {
2238     void *addr;
2239     int fd;
2240     struct vhost_user *u = dev->opaque;
2241     CharBackend *chr = u->user->chr;
2242     VhostUserMsg msg = {
2243         .hdr.request = VHOST_USER_GET_INFLIGHT_FD,
2244         .hdr.flags = VHOST_USER_VERSION,
2245         .payload.inflight.num_queues = dev->nvqs,
2246         .payload.inflight.queue_size = queue_size,
2247         .hdr.size = sizeof(msg.payload.inflight),
2248     };
2249 
2250     if (!virtio_has_feature(dev->protocol_features,
2251                             VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2252         return 0;
2253     }
2254 
2255     if (vhost_user_write(dev, &msg, NULL, 0) < 0) {
2256         return -1;
2257     }
2258 
2259     if (vhost_user_read(dev, &msg) < 0) {
2260         return -1;
2261     }
2262 
2263     if (msg.hdr.request != VHOST_USER_GET_INFLIGHT_FD) {
2264         error_report("Received unexpected msg type. "
2265                      "Expected %d received %d",
2266                      VHOST_USER_GET_INFLIGHT_FD, msg.hdr.request);
2267         return -1;
2268     }
2269 
2270     if (msg.hdr.size != sizeof(msg.payload.inflight)) {
2271         error_report("Received bad msg size.");
2272         return -1;
2273     }
2274 
2275     if (!msg.payload.inflight.mmap_size) {
2276         return 0;
2277     }
2278 
2279     fd = qemu_chr_fe_get_msgfd(chr);
2280     if (fd < 0) {
2281         error_report("Failed to get mem fd");
2282         return -1;
2283     }
2284 
2285     addr = mmap(0, msg.payload.inflight.mmap_size, PROT_READ | PROT_WRITE,
2286                 MAP_SHARED, fd, msg.payload.inflight.mmap_offset);
2287 
2288     if (addr == MAP_FAILED) {
2289         error_report("Failed to mmap mem fd");
2290         close(fd);
2291         return -1;
2292     }
2293 
2294     inflight->addr = addr;
2295     inflight->fd = fd;
2296     inflight->size = msg.payload.inflight.mmap_size;
2297     inflight->offset = msg.payload.inflight.mmap_offset;
2298     inflight->queue_size = queue_size;
2299 
2300     return 0;
2301 }
2302 
2303 static int vhost_user_set_inflight_fd(struct vhost_dev *dev,
2304                                       struct vhost_inflight *inflight)
2305 {
2306     VhostUserMsg msg = {
2307         .hdr.request = VHOST_USER_SET_INFLIGHT_FD,
2308         .hdr.flags = VHOST_USER_VERSION,
2309         .payload.inflight.mmap_size = inflight->size,
2310         .payload.inflight.mmap_offset = inflight->offset,
2311         .payload.inflight.num_queues = dev->nvqs,
2312         .payload.inflight.queue_size = inflight->queue_size,
2313         .hdr.size = sizeof(msg.payload.inflight),
2314     };
2315 
2316     if (!virtio_has_feature(dev->protocol_features,
2317                             VHOST_USER_PROTOCOL_F_INFLIGHT_SHMFD)) {
2318         return 0;
2319     }
2320 
2321     if (vhost_user_write(dev, &msg, &inflight->fd, 1) < 0) {
2322         return -1;
2323     }
2324 
2325     return 0;
2326 }
2327 
2328 bool vhost_user_init(VhostUserState *user, CharBackend *chr, Error **errp)
2329 {
2330     if (user->chr) {
2331         error_setg(errp, "Cannot initialize vhost-user state");
2332         return false;
2333     }
2334     user->chr = chr;
2335     user->memory_slots = 0;
2336     return true;
2337 }
2338 
2339 void vhost_user_cleanup(VhostUserState *user)
2340 {
2341     int i;
2342 
2343     if (!user->chr) {
2344         return;
2345     }
2346 
2347     for (i = 0; i < VIRTIO_QUEUE_MAX; i++) {
2348         if (user->notifier[i].addr) {
2349             object_unparent(OBJECT(&user->notifier[i].mr));
2350             munmap(user->notifier[i].addr, qemu_real_host_page_size);
2351             user->notifier[i].addr = NULL;
2352         }
2353     }
2354     user->chr = NULL;
2355 }
2356 
2357 const VhostOps user_ops = {
2358         .backend_type = VHOST_BACKEND_TYPE_USER,
2359         .vhost_backend_init = vhost_user_backend_init,
2360         .vhost_backend_cleanup = vhost_user_backend_cleanup,
2361         .vhost_backend_memslots_limit = vhost_user_memslots_limit,
2362         .vhost_set_log_base = vhost_user_set_log_base,
2363         .vhost_set_mem_table = vhost_user_set_mem_table,
2364         .vhost_set_vring_addr = vhost_user_set_vring_addr,
2365         .vhost_set_vring_endian = vhost_user_set_vring_endian,
2366         .vhost_set_vring_num = vhost_user_set_vring_num,
2367         .vhost_set_vring_base = vhost_user_set_vring_base,
2368         .vhost_get_vring_base = vhost_user_get_vring_base,
2369         .vhost_set_vring_kick = vhost_user_set_vring_kick,
2370         .vhost_set_vring_call = vhost_user_set_vring_call,
2371         .vhost_set_features = vhost_user_set_features,
2372         .vhost_get_features = vhost_user_get_features,
2373         .vhost_set_owner = vhost_user_set_owner,
2374         .vhost_reset_device = vhost_user_reset_device,
2375         .vhost_get_vq_index = vhost_user_get_vq_index,
2376         .vhost_set_vring_enable = vhost_user_set_vring_enable,
2377         .vhost_requires_shm_log = vhost_user_requires_shm_log,
2378         .vhost_migration_done = vhost_user_migration_done,
2379         .vhost_backend_can_merge = vhost_user_can_merge,
2380         .vhost_net_set_mtu = vhost_user_net_set_mtu,
2381         .vhost_set_iotlb_callback = vhost_user_set_iotlb_callback,
2382         .vhost_send_device_iotlb_msg = vhost_user_send_device_iotlb_msg,
2383         .vhost_get_config = vhost_user_get_config,
2384         .vhost_set_config = vhost_user_set_config,
2385         .vhost_crypto_create_session = vhost_user_crypto_create_session,
2386         .vhost_crypto_close_session = vhost_user_crypto_close_session,
2387         .vhost_backend_mem_section_filter = vhost_user_mem_section_filter,
2388         .vhost_get_inflight_fd = vhost_user_get_inflight_fd,
2389         .vhost_set_inflight_fd = vhost_user_set_inflight_fd,
2390 };
2391