xref: /openbmc/qemu/hw/virtio/virtio-iommu.c (revision 38472890)
1 /*
2  * virtio-iommu device
3  *
4  * Copyright (c) 2020 Red Hat, Inc.
5  *
6  * This program is free software; you can redistribute it and/or modify it
7  * under the terms and conditions of the GNU General Public License,
8  * version 2 or later, as published by the Free Software Foundation.
9  *
10  * This program is distributed in the hope it will be useful, but WITHOUT
11  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
12  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
13  * more details.
14  *
15  * You should have received a copy of the GNU General Public License along with
16  * this program.  If not, see <http://www.gnu.org/licenses/>.
17  *
18  */
19 
20 #include "qemu/osdep.h"
21 #include "qemu/log.h"
22 #include "qemu/iov.h"
23 #include "qemu-common.h"
24 #include "hw/qdev-properties.h"
25 #include "hw/virtio/virtio.h"
26 #include "sysemu/kvm.h"
27 #include "qapi/error.h"
28 #include "qemu/error-report.h"
29 #include "trace.h"
30 
31 #include "standard-headers/linux/virtio_ids.h"
32 
33 #include "hw/virtio/virtio-bus.h"
34 #include "hw/virtio/virtio-access.h"
35 #include "hw/virtio/virtio-iommu.h"
36 #include "hw/pci/pci_bus.h"
37 #include "hw/pci/pci.h"
38 
/* Number of descriptors in each of the request and event virtqueues. */
#define VIOMMU_DEFAULT_QUEUE_SIZE 256

/*
 * A translation context shared by the endpoints attached to it. Created by
 * virtio_iommu_get_domain() on ATTACH and removed from s->domains once its
 * last endpoint is detached.
 */
typedef struct VirtIOIOMMUDomain {
    uint32_t id;        /* domain ID chosen by the driver */
    GTree *mappings;    /* VirtIOIOMMUInterval key -> VirtIOIOMMUMapping */
    QLIST_HEAD(, VirtIOIOMMUEndpoint) endpoint_list; /* attached endpoints */
} VirtIOIOMMUDomain;

/*
 * A device behind the IOMMU. Its ID is the PCI requester ID (bus:devfn)
 * of the device; translate looks endpoints up by that value.
 */
typedef struct VirtIOIOMMUEndpoint {
    uint32_t id;
    VirtIOIOMMUDomain *domain;  /* NULL while not attached to a domain */
    QLIST_ENTRY(VirtIOIOMMUEndpoint) next; /* link in domain->endpoint_list */
} VirtIOIOMMUEndpoint;

/*
 * IOVA range [low, high], both bounds inclusive; key of a domain's
 * mappings tree (ordered by interval_cmp()).
 */
typedef struct VirtIOIOMMUInterval {
    uint64_t low;
    uint64_t high;
} VirtIOIOMMUInterval;

/* Target of a mapping: physical address of the interval's low end. */
typedef struct VirtIOIOMMUMapping {
    uint64_t phys_addr;
    uint32_t flags;     /* VIRTIO_IOMMU_MAP_F_* permission bits */
} VirtIOIOMMUMapping;
63 
64 static inline uint16_t virtio_iommu_get_bdf(IOMMUDevice *dev)
65 {
66     return PCI_BUILD_BDF(pci_bus_num(dev->bus), dev->devfn);
67 }
68 
69 /**
70  * The bus number is used for lookup when SID based operations occur.
71  * In that case we lazily populate the IOMMUPciBus array from the bus hash
72  * table. At the time the IOMMUPciBus is created (iommu_find_add_as), the bus
73  * numbers may not be always initialized yet.
74  */
75 static IOMMUPciBus *iommu_find_iommu_pcibus(VirtIOIOMMU *s, uint8_t bus_num)
76 {
77     IOMMUPciBus *iommu_pci_bus = s->iommu_pcibus_by_bus_num[bus_num];
78 
79     if (!iommu_pci_bus) {
80         GHashTableIter iter;
81 
82         g_hash_table_iter_init(&iter, s->as_by_busptr);
83         while (g_hash_table_iter_next(&iter, NULL, (void **)&iommu_pci_bus)) {
84             if (pci_bus_num(iommu_pci_bus->bus) == bus_num) {
85                 s->iommu_pcibus_by_bus_num[bus_num] = iommu_pci_bus;
86                 return iommu_pci_bus;
87             }
88         }
89         return NULL;
90     }
91     return iommu_pci_bus;
92 }
93 
94 static IOMMUMemoryRegion *virtio_iommu_mr(VirtIOIOMMU *s, uint32_t sid)
95 {
96     uint8_t bus_n, devfn;
97     IOMMUPciBus *iommu_pci_bus;
98     IOMMUDevice *dev;
99 
100     bus_n = PCI_BUS_NUM(sid);
101     iommu_pci_bus = iommu_find_iommu_pcibus(s, bus_n);
102     if (iommu_pci_bus) {
103         devfn = sid & PCI_DEVFN_MAX;
104         dev = iommu_pci_bus->pbdev[devfn];
105         if (dev) {
106             return &dev->iommu_mr;
107         }
108     }
109     return NULL;
110 }
111 
112 static gint interval_cmp(gconstpointer a, gconstpointer b, gpointer user_data)
113 {
114     VirtIOIOMMUInterval *inta = (VirtIOIOMMUInterval *)a;
115     VirtIOIOMMUInterval *intb = (VirtIOIOMMUInterval *)b;
116 
117     if (inta->high < intb->low) {
118         return -1;
119     } else if (intb->high < inta->low) {
120         return 1;
121     } else {
122         return 0;
123     }
124 }
125 
126 static void virtio_iommu_detach_endpoint_from_domain(VirtIOIOMMUEndpoint *ep)
127 {
128     if (!ep->domain) {
129         return;
130     }
131     QLIST_REMOVE(ep, next);
132     ep->domain = NULL;
133 }
134 
135 static VirtIOIOMMUEndpoint *virtio_iommu_get_endpoint(VirtIOIOMMU *s,
136                                                       uint32_t ep_id)
137 {
138     VirtIOIOMMUEndpoint *ep;
139 
140     ep = g_tree_lookup(s->endpoints, GUINT_TO_POINTER(ep_id));
141     if (ep) {
142         return ep;
143     }
144     if (!virtio_iommu_mr(s, ep_id)) {
145         return NULL;
146     }
147     ep = g_malloc0(sizeof(*ep));
148     ep->id = ep_id;
149     trace_virtio_iommu_get_endpoint(ep_id);
150     g_tree_insert(s->endpoints, GUINT_TO_POINTER(ep_id), ep);
151     return ep;
152 }
153 
154 static void virtio_iommu_put_endpoint(gpointer data)
155 {
156     VirtIOIOMMUEndpoint *ep = (VirtIOIOMMUEndpoint *)data;
157 
158     if (ep->domain) {
159         virtio_iommu_detach_endpoint_from_domain(ep);
160     }
161 
162     trace_virtio_iommu_put_endpoint(ep->id);
163     g_free(ep);
164 }
165 
166 static VirtIOIOMMUDomain *virtio_iommu_get_domain(VirtIOIOMMU *s,
167                                                   uint32_t domain_id)
168 {
169     VirtIOIOMMUDomain *domain;
170 
171     domain = g_tree_lookup(s->domains, GUINT_TO_POINTER(domain_id));
172     if (domain) {
173         return domain;
174     }
175     domain = g_malloc0(sizeof(*domain));
176     domain->id = domain_id;
177     domain->mappings = g_tree_new_full((GCompareDataFunc)interval_cmp,
178                                    NULL, (GDestroyNotify)g_free,
179                                    (GDestroyNotify)g_free);
180     g_tree_insert(s->domains, GUINT_TO_POINTER(domain_id), domain);
181     QLIST_INIT(&domain->endpoint_list);
182     trace_virtio_iommu_get_domain(domain_id);
183     return domain;
184 }
185 
186 static void virtio_iommu_put_domain(gpointer data)
187 {
188     VirtIOIOMMUDomain *domain = (VirtIOIOMMUDomain *)data;
189     VirtIOIOMMUEndpoint *iter, *tmp;
190 
191     QLIST_FOREACH_SAFE(iter, &domain->endpoint_list, next, tmp) {
192         virtio_iommu_detach_endpoint_from_domain(iter);
193     }
194     g_tree_destroy(domain->mappings);
195     trace_virtio_iommu_put_domain(domain->id);
196     g_free(domain);
197 }
198 
199 static AddressSpace *virtio_iommu_find_add_as(PCIBus *bus, void *opaque,
200                                               int devfn)
201 {
202     VirtIOIOMMU *s = opaque;
203     IOMMUPciBus *sbus = g_hash_table_lookup(s->as_by_busptr, bus);
204     static uint32_t mr_index;
205     IOMMUDevice *sdev;
206 
207     if (!sbus) {
208         sbus = g_malloc0(sizeof(IOMMUPciBus) +
209                          sizeof(IOMMUDevice *) * PCI_DEVFN_MAX);
210         sbus->bus = bus;
211         g_hash_table_insert(s->as_by_busptr, bus, sbus);
212     }
213 
214     sdev = sbus->pbdev[devfn];
215     if (!sdev) {
216         char *name = g_strdup_printf("%s-%d-%d",
217                                      TYPE_VIRTIO_IOMMU_MEMORY_REGION,
218                                      mr_index++, devfn);
219         sdev = sbus->pbdev[devfn] = g_malloc0(sizeof(IOMMUDevice));
220 
221         sdev->viommu = s;
222         sdev->bus = bus;
223         sdev->devfn = devfn;
224 
225         trace_virtio_iommu_init_iommu_mr(name);
226 
227         memory_region_init_iommu(&sdev->iommu_mr, sizeof(sdev->iommu_mr),
228                                  TYPE_VIRTIO_IOMMU_MEMORY_REGION,
229                                  OBJECT(s), name,
230                                  UINT64_MAX);
231         address_space_init(&sdev->as,
232                            MEMORY_REGION(&sdev->iommu_mr), TYPE_VIRTIO_IOMMU);
233         g_free(name);
234     }
235     return &sdev->as;
236 }
237 
238 static int virtio_iommu_attach(VirtIOIOMMU *s,
239                                struct virtio_iommu_req_attach *req)
240 {
241     uint32_t domain_id = le32_to_cpu(req->domain);
242     uint32_t ep_id = le32_to_cpu(req->endpoint);
243     VirtIOIOMMUDomain *domain;
244     VirtIOIOMMUEndpoint *ep;
245 
246     trace_virtio_iommu_attach(domain_id, ep_id);
247 
248     ep = virtio_iommu_get_endpoint(s, ep_id);
249     if (!ep) {
250         return VIRTIO_IOMMU_S_NOENT;
251     }
252 
253     if (ep->domain) {
254         VirtIOIOMMUDomain *previous_domain = ep->domain;
255         /*
256          * the device is already attached to a domain,
257          * detach it first
258          */
259         virtio_iommu_detach_endpoint_from_domain(ep);
260         if (QLIST_EMPTY(&previous_domain->endpoint_list)) {
261             g_tree_remove(s->domains, GUINT_TO_POINTER(previous_domain->id));
262         }
263     }
264 
265     domain = virtio_iommu_get_domain(s, domain_id);
266     QLIST_INSERT_HEAD(&domain->endpoint_list, ep, next);
267 
268     ep->domain = domain;
269 
270     return VIRTIO_IOMMU_S_OK;
271 }
272 
273 static int virtio_iommu_detach(VirtIOIOMMU *s,
274                                struct virtio_iommu_req_detach *req)
275 {
276     uint32_t domain_id = le32_to_cpu(req->domain);
277     uint32_t ep_id = le32_to_cpu(req->endpoint);
278     VirtIOIOMMUDomain *domain;
279     VirtIOIOMMUEndpoint *ep;
280 
281     trace_virtio_iommu_detach(domain_id, ep_id);
282 
283     ep = g_tree_lookup(s->endpoints, GUINT_TO_POINTER(ep_id));
284     if (!ep) {
285         return VIRTIO_IOMMU_S_NOENT;
286     }
287 
288     domain = ep->domain;
289 
290     if (!domain || domain->id != domain_id) {
291         return VIRTIO_IOMMU_S_INVAL;
292     }
293 
294     virtio_iommu_detach_endpoint_from_domain(ep);
295 
296     if (QLIST_EMPTY(&domain->endpoint_list)) {
297         g_tree_remove(s->domains, GUINT_TO_POINTER(domain->id));
298     }
299     return VIRTIO_IOMMU_S_OK;
300 }
301 
302 static int virtio_iommu_map(VirtIOIOMMU *s,
303                             struct virtio_iommu_req_map *req)
304 {
305     uint32_t domain_id = le32_to_cpu(req->domain);
306     uint64_t phys_start = le64_to_cpu(req->phys_start);
307     uint64_t virt_start = le64_to_cpu(req->virt_start);
308     uint64_t virt_end = le64_to_cpu(req->virt_end);
309     uint32_t flags = le32_to_cpu(req->flags);
310     VirtIOIOMMUDomain *domain;
311     VirtIOIOMMUInterval *interval;
312     VirtIOIOMMUMapping *mapping;
313 
314     if (flags & ~VIRTIO_IOMMU_MAP_F_MASK) {
315         return VIRTIO_IOMMU_S_INVAL;
316     }
317 
318     domain = g_tree_lookup(s->domains, GUINT_TO_POINTER(domain_id));
319     if (!domain) {
320         return VIRTIO_IOMMU_S_NOENT;
321     }
322 
323     interval = g_malloc0(sizeof(*interval));
324 
325     interval->low = virt_start;
326     interval->high = virt_end;
327 
328     mapping = g_tree_lookup(domain->mappings, (gpointer)interval);
329     if (mapping) {
330         g_free(interval);
331         return VIRTIO_IOMMU_S_INVAL;
332     }
333 
334     trace_virtio_iommu_map(domain_id, virt_start, virt_end, phys_start, flags);
335 
336     mapping = g_malloc0(sizeof(*mapping));
337     mapping->phys_addr = phys_start;
338     mapping->flags = flags;
339 
340     g_tree_insert(domain->mappings, interval, mapping);
341 
342     return VIRTIO_IOMMU_S_OK;
343 }
344 
345 static int virtio_iommu_unmap(VirtIOIOMMU *s,
346                               struct virtio_iommu_req_unmap *req)
347 {
348     uint32_t domain_id = le32_to_cpu(req->domain);
349     uint64_t virt_start = le64_to_cpu(req->virt_start);
350     uint64_t virt_end = le64_to_cpu(req->virt_end);
351     VirtIOIOMMUMapping *iter_val;
352     VirtIOIOMMUInterval interval, *iter_key;
353     VirtIOIOMMUDomain *domain;
354     int ret = VIRTIO_IOMMU_S_OK;
355 
356     trace_virtio_iommu_unmap(domain_id, virt_start, virt_end);
357 
358     domain = g_tree_lookup(s->domains, GUINT_TO_POINTER(domain_id));
359     if (!domain) {
360         return VIRTIO_IOMMU_S_NOENT;
361     }
362     interval.low = virt_start;
363     interval.high = virt_end;
364 
365     while (g_tree_lookup_extended(domain->mappings, &interval,
366                                   (void **)&iter_key, (void**)&iter_val)) {
367         uint64_t current_low = iter_key->low;
368         uint64_t current_high = iter_key->high;
369 
370         if (interval.low <= current_low && interval.high >= current_high) {
371             g_tree_remove(domain->mappings, iter_key);
372             trace_virtio_iommu_unmap_done(domain_id, current_low, current_high);
373         } else {
374             ret = VIRTIO_IOMMU_S_RANGE;
375             break;
376         }
377     }
378     return ret;
379 }
380 
381 static int virtio_iommu_iov_to_req(struct iovec *iov,
382                                    unsigned int iov_cnt,
383                                    void *req, size_t req_sz)
384 {
385     size_t sz, payload_sz = req_sz - sizeof(struct virtio_iommu_req_tail);
386 
387     sz = iov_to_buf(iov, iov_cnt, 0, req, payload_sz);
388     if (unlikely(sz != payload_sz)) {
389         return VIRTIO_IOMMU_S_INVAL;
390     }
391     return 0;
392 }
393 
394 #define virtio_iommu_handle_req(__req)                                  \
395 static int virtio_iommu_handle_ ## __req(VirtIOIOMMU *s,                \
396                                          struct iovec *iov,             \
397                                          unsigned int iov_cnt)          \
398 {                                                                       \
399     struct virtio_iommu_req_ ## __req req;                              \
400     int ret = virtio_iommu_iov_to_req(iov, iov_cnt, &req, sizeof(req)); \
401                                                                         \
402     return ret ? ret : virtio_iommu_ ## __req(s, &req);                 \
403 }
404 
405 virtio_iommu_handle_req(attach)
406 virtio_iommu_handle_req(detach)
407 virtio_iommu_handle_req(map)
408 virtio_iommu_handle_req(unmap)
409 
410 static void virtio_iommu_handle_command(VirtIODevice *vdev, VirtQueue *vq)
411 {
412     VirtIOIOMMU *s = VIRTIO_IOMMU(vdev);
413     struct virtio_iommu_req_head head;
414     struct virtio_iommu_req_tail tail = {};
415     VirtQueueElement *elem;
416     unsigned int iov_cnt;
417     struct iovec *iov;
418     size_t sz;
419 
420     for (;;) {
421         elem = virtqueue_pop(vq, sizeof(VirtQueueElement));
422         if (!elem) {
423             return;
424         }
425 
426         if (iov_size(elem->in_sg, elem->in_num) < sizeof(tail) ||
427             iov_size(elem->out_sg, elem->out_num) < sizeof(head)) {
428             virtio_error(vdev, "virtio-iommu bad head/tail size");
429             virtqueue_detach_element(vq, elem, 0);
430             g_free(elem);
431             break;
432         }
433 
434         iov_cnt = elem->out_num;
435         iov = elem->out_sg;
436         sz = iov_to_buf(iov, iov_cnt, 0, &head, sizeof(head));
437         if (unlikely(sz != sizeof(head))) {
438             tail.status = VIRTIO_IOMMU_S_DEVERR;
439             goto out;
440         }
441         qemu_mutex_lock(&s->mutex);
442         switch (head.type) {
443         case VIRTIO_IOMMU_T_ATTACH:
444             tail.status = virtio_iommu_handle_attach(s, iov, iov_cnt);
445             break;
446         case VIRTIO_IOMMU_T_DETACH:
447             tail.status = virtio_iommu_handle_detach(s, iov, iov_cnt);
448             break;
449         case VIRTIO_IOMMU_T_MAP:
450             tail.status = virtio_iommu_handle_map(s, iov, iov_cnt);
451             break;
452         case VIRTIO_IOMMU_T_UNMAP:
453             tail.status = virtio_iommu_handle_unmap(s, iov, iov_cnt);
454             break;
455         default:
456             tail.status = VIRTIO_IOMMU_S_UNSUPP;
457         }
458         qemu_mutex_unlock(&s->mutex);
459 
460 out:
461         sz = iov_from_buf(elem->in_sg, elem->in_num, 0,
462                           &tail, sizeof(tail));
463         assert(sz == sizeof(tail));
464 
465         virtqueue_push(vq, elem, sizeof(tail));
466         virtio_notify(vdev, vq);
467         g_free(elem);
468     }
469 }
470 
471 static void virtio_iommu_report_fault(VirtIOIOMMU *viommu, uint8_t reason,
472                                       int flags, uint32_t endpoint,
473                                       uint64_t address)
474 {
475     VirtIODevice *vdev = &viommu->parent_obj;
476     VirtQueue *vq = viommu->event_vq;
477     struct virtio_iommu_fault fault;
478     VirtQueueElement *elem;
479     size_t sz;
480 
481     memset(&fault, 0, sizeof(fault));
482     fault.reason = reason;
483     fault.flags = cpu_to_le32(flags);
484     fault.endpoint = cpu_to_le32(endpoint);
485     fault.address = cpu_to_le64(address);
486 
487     elem = virtqueue_pop(vq, sizeof(VirtQueueElement));
488 
489     if (!elem) {
490         error_report_once(
491             "no buffer available in event queue to report event");
492         return;
493     }
494 
495     if (iov_size(elem->in_sg, elem->in_num) < sizeof(fault)) {
496         virtio_error(vdev, "error buffer of wrong size");
497         virtqueue_detach_element(vq, elem, 0);
498         g_free(elem);
499         return;
500     }
501 
502     sz = iov_from_buf(elem->in_sg, elem->in_num, 0,
503                       &fault, sizeof(fault));
504     assert(sz == sizeof(fault));
505 
506     trace_virtio_iommu_report_fault(reason, flags, endpoint, address);
507     virtqueue_push(vq, elem, sz);
508     virtio_notify(vdev, vq);
509     g_free(elem);
510 
511 }
512 
513 static IOMMUTLBEntry virtio_iommu_translate(IOMMUMemoryRegion *mr, hwaddr addr,
514                                             IOMMUAccessFlags flag,
515                                             int iommu_idx)
516 {
517     IOMMUDevice *sdev = container_of(mr, IOMMUDevice, iommu_mr);
518     VirtIOIOMMUInterval interval, *mapping_key;
519     VirtIOIOMMUMapping *mapping_value;
520     VirtIOIOMMU *s = sdev->viommu;
521     bool read_fault, write_fault;
522     VirtIOIOMMUEndpoint *ep;
523     uint32_t sid, flags;
524     bool bypass_allowed;
525     bool found;
526 
527     interval.low = addr;
528     interval.high = addr + 1;
529 
530     IOMMUTLBEntry entry = {
531         .target_as = &address_space_memory,
532         .iova = addr,
533         .translated_addr = addr,
534         .addr_mask = (1 << ctz32(s->config.page_size_mask)) - 1,
535         .perm = IOMMU_NONE,
536     };
537 
538     bypass_allowed = virtio_vdev_has_feature(&s->parent_obj,
539                                              VIRTIO_IOMMU_F_BYPASS);
540 
541     sid = virtio_iommu_get_bdf(sdev);
542 
543     trace_virtio_iommu_translate(mr->parent_obj.name, sid, addr, flag);
544     qemu_mutex_lock(&s->mutex);
545 
546     ep = g_tree_lookup(s->endpoints, GUINT_TO_POINTER(sid));
547     if (!ep) {
548         if (!bypass_allowed) {
549             error_report_once("%s sid=%d is not known!!", __func__, sid);
550             virtio_iommu_report_fault(s, VIRTIO_IOMMU_FAULT_R_UNKNOWN,
551                                       VIRTIO_IOMMU_FAULT_F_ADDRESS,
552                                       sid, addr);
553         } else {
554             entry.perm = flag;
555         }
556         goto unlock;
557     }
558 
559     if (!ep->domain) {
560         if (!bypass_allowed) {
561             error_report_once("%s %02x:%02x.%01x not attached to any domain",
562                               __func__, PCI_BUS_NUM(sid),
563                               PCI_SLOT(sid), PCI_FUNC(sid));
564             virtio_iommu_report_fault(s, VIRTIO_IOMMU_FAULT_R_DOMAIN,
565                                       VIRTIO_IOMMU_FAULT_F_ADDRESS,
566                                       sid, addr);
567         } else {
568             entry.perm = flag;
569         }
570         goto unlock;
571     }
572 
573     found = g_tree_lookup_extended(ep->domain->mappings, (gpointer)(&interval),
574                                    (void **)&mapping_key,
575                                    (void **)&mapping_value);
576     if (!found) {
577         error_report_once("%s no mapping for 0x%"PRIx64" for sid=%d",
578                           __func__, addr, sid);
579         virtio_iommu_report_fault(s, VIRTIO_IOMMU_FAULT_R_MAPPING,
580                                   VIRTIO_IOMMU_FAULT_F_ADDRESS,
581                                   sid, addr);
582         goto unlock;
583     }
584 
585     read_fault = (flag & IOMMU_RO) &&
586                     !(mapping_value->flags & VIRTIO_IOMMU_MAP_F_READ);
587     write_fault = (flag & IOMMU_WO) &&
588                     !(mapping_value->flags & VIRTIO_IOMMU_MAP_F_WRITE);
589 
590     flags = read_fault ? VIRTIO_IOMMU_FAULT_F_READ : 0;
591     flags |= write_fault ? VIRTIO_IOMMU_FAULT_F_WRITE : 0;
592     if (flags) {
593         error_report_once("%s permission error on 0x%"PRIx64"(%d): allowed=%d",
594                           __func__, addr, flag, mapping_value->flags);
595         flags |= VIRTIO_IOMMU_FAULT_F_ADDRESS;
596         virtio_iommu_report_fault(s, VIRTIO_IOMMU_FAULT_R_MAPPING,
597                                   flags | VIRTIO_IOMMU_FAULT_F_ADDRESS,
598                                   sid, addr);
599         goto unlock;
600     }
601     entry.translated_addr = addr - mapping_key->low + mapping_value->phys_addr;
602     entry.perm = flag;
603     trace_virtio_iommu_translate_out(addr, entry.translated_addr, sid);
604 
605 unlock:
606     qemu_mutex_unlock(&s->mutex);
607     return entry;
608 }
609 
610 static void virtio_iommu_get_config(VirtIODevice *vdev, uint8_t *config_data)
611 {
612     VirtIOIOMMU *dev = VIRTIO_IOMMU(vdev);
613     struct virtio_iommu_config *config = &dev->config;
614 
615     trace_virtio_iommu_get_config(config->page_size_mask,
616                                   config->input_range.start,
617                                   config->input_range.end,
618                                   config->domain_range.end,
619                                   config->probe_size);
620     memcpy(config_data, &dev->config, sizeof(struct virtio_iommu_config));
621 }
622 
623 static void virtio_iommu_set_config(VirtIODevice *vdev,
624                                       const uint8_t *config_data)
625 {
626     struct virtio_iommu_config config;
627 
628     memcpy(&config, config_data, sizeof(struct virtio_iommu_config));
629     trace_virtio_iommu_set_config(config.page_size_mask,
630                                   config.input_range.start,
631                                   config.input_range.end,
632                                   config.domain_range.end,
633                                   config.probe_size);
634 }
635 
636 static uint64_t virtio_iommu_get_features(VirtIODevice *vdev, uint64_t f,
637                                           Error **errp)
638 {
639     VirtIOIOMMU *dev = VIRTIO_IOMMU(vdev);
640 
641     f |= dev->features;
642     trace_virtio_iommu_get_features(f);
643     return f;
644 }
645 
646 static gint int_cmp(gconstpointer a, gconstpointer b, gpointer user_data)
647 {
648     guint ua = GPOINTER_TO_UINT(a);
649     guint ub = GPOINTER_TO_UINT(b);
650     return (ua > ub) - (ua < ub);
651 }
652 
653 static void virtio_iommu_device_realize(DeviceState *dev, Error **errp)
654 {
655     VirtIODevice *vdev = VIRTIO_DEVICE(dev);
656     VirtIOIOMMU *s = VIRTIO_IOMMU(dev);
657 
658     virtio_init(vdev, "virtio-iommu", VIRTIO_ID_IOMMU,
659                 sizeof(struct virtio_iommu_config));
660 
661     memset(s->iommu_pcibus_by_bus_num, 0, sizeof(s->iommu_pcibus_by_bus_num));
662 
663     s->req_vq = virtio_add_queue(vdev, VIOMMU_DEFAULT_QUEUE_SIZE,
664                              virtio_iommu_handle_command);
665     s->event_vq = virtio_add_queue(vdev, VIOMMU_DEFAULT_QUEUE_SIZE, NULL);
666 
667     s->config.page_size_mask = TARGET_PAGE_MASK;
668     s->config.input_range.end = -1UL;
669     s->config.domain_range.end = 32;
670 
671     virtio_add_feature(&s->features, VIRTIO_RING_F_EVENT_IDX);
672     virtio_add_feature(&s->features, VIRTIO_RING_F_INDIRECT_DESC);
673     virtio_add_feature(&s->features, VIRTIO_F_VERSION_1);
674     virtio_add_feature(&s->features, VIRTIO_IOMMU_F_INPUT_RANGE);
675     virtio_add_feature(&s->features, VIRTIO_IOMMU_F_DOMAIN_RANGE);
676     virtio_add_feature(&s->features, VIRTIO_IOMMU_F_MAP_UNMAP);
677     virtio_add_feature(&s->features, VIRTIO_IOMMU_F_BYPASS);
678     virtio_add_feature(&s->features, VIRTIO_IOMMU_F_MMIO);
679 
680     qemu_mutex_init(&s->mutex);
681 
682     s->as_by_busptr = g_hash_table_new_full(NULL, NULL, NULL, g_free);
683 
684     if (s->primary_bus) {
685         pci_setup_iommu(s->primary_bus, virtio_iommu_find_add_as, s);
686     } else {
687         error_setg(errp, "VIRTIO-IOMMU is not attached to any PCI bus!");
688     }
689 }
690 
691 static void virtio_iommu_device_unrealize(DeviceState *dev, Error **errp)
692 {
693     VirtIODevice *vdev = VIRTIO_DEVICE(dev);
694     VirtIOIOMMU *s = VIRTIO_IOMMU(dev);
695 
696     g_hash_table_destroy(s->as_by_busptr);
697     g_tree_destroy(s->domains);
698     g_tree_destroy(s->endpoints);
699 
700     virtio_delete_queue(s->req_vq);
701     virtio_delete_queue(s->event_vq);
702     virtio_cleanup(vdev);
703 }
704 
705 static void virtio_iommu_device_reset(VirtIODevice *vdev)
706 {
707     VirtIOIOMMU *s = VIRTIO_IOMMU(vdev);
708 
709     trace_virtio_iommu_device_reset();
710 
711     if (s->domains) {
712         g_tree_destroy(s->domains);
713     }
714     if (s->endpoints) {
715         g_tree_destroy(s->endpoints);
716     }
717     s->domains = g_tree_new_full((GCompareDataFunc)int_cmp,
718                                  NULL, NULL, virtio_iommu_put_domain);
719     s->endpoints = g_tree_new_full((GCompareDataFunc)int_cmp,
720                                    NULL, NULL, virtio_iommu_put_endpoint);
721 }
722 
/* Device status changes are only traced; nothing else is done with them here. */
static void virtio_iommu_set_status(VirtIODevice *vdev, uint8_t status)
{
    trace_virtio_iommu_device_status(status);
}

/*
 * Nothing to initialize per instance: device state is set up in realize,
 * and the domain/endpoint trees are created in reset.
 */
static void virtio_iommu_instance_init(Object *obj)
{
}
731 
/*
 * Migration field lists for one entry of a domain's mappings tree:
 * VMSTATE_INTERVAL describes the VirtIOIOMMUInterval key, and
 * VMSTATE_MAPPING the VirtIOIOMMUMapping value.
 */
#define VMSTATE_INTERVAL                               \
{                                                      \
    .name = "interval",                                \
    .version_id = 1,                                   \
    .minimum_version_id = 1,                           \
    .fields = (VMStateField[]) {                       \
        VMSTATE_UINT64(low, VirtIOIOMMUInterval),      \
        VMSTATE_UINT64(high, VirtIOIOMMUInterval),     \
        VMSTATE_END_OF_LIST()                          \
    }                                                  \
}

#define VMSTATE_MAPPING                               \
{                                                     \
    .name = "mapping",                                \
    .version_id = 1,                                  \
    .minimum_version_id = 1,                          \
    .fields = (VMStateField[]) {                      \
        VMSTATE_UINT64(phys_addr, VirtIOIOMMUMapping),\
        VMSTATE_UINT32(flags, VirtIOIOMMUMapping),    \
        VMSTATE_END_OF_LIST()                         \
    },                                                \
}

/*
 * Value/key descriptions handed to VMSTATE_GTREE_V for domain->mappings.
 * Position matters: the value description comes first, the key second.
 */
static const VMStateDescription vmstate_interval_mapping[2] = {
    VMSTATE_MAPPING,   /* value */
    VMSTATE_INTERVAL   /* key   */
};
760 
761 static int domain_preload(void *opaque)
762 {
763     VirtIOIOMMUDomain *domain = opaque;
764 
765     domain->mappings = g_tree_new_full((GCompareDataFunc)interval_cmp,
766                                        NULL, g_free, g_free);
767     return 0;
768 }
769 
/*
 * Only the endpoint ID is migrated. Each endpoint's domain pointer and
 * its entry in s->endpoints are rebuilt after load by iommu_post_load().
 */
static const VMStateDescription vmstate_endpoint = {
    .name = "endpoint",
    .version_id = 1,
    .minimum_version_id = 1,
    .fields = (VMStateField[]) {
        VMSTATE_UINT32(id, VirtIOIOMMUEndpoint),
        VMSTATE_END_OF_LIST()
    }
};

/*
 * A domain: its ID, its mappings tree, and the list of attached
 * endpoints. domain_preload() creates the tree empty before the stream
 * fills it.
 */
static const VMStateDescription vmstate_domain = {
    .name = "domain",
    .version_id = 1,
    .minimum_version_id = 1,
    .pre_load = domain_preload,
    .fields = (VMStateField[]) {
        VMSTATE_UINT32(id, VirtIOIOMMUDomain),
        VMSTATE_GTREE_V(mappings, VirtIOIOMMUDomain, 1,
                        vmstate_interval_mapping,
                        VirtIOIOMMUInterval, VirtIOIOMMUMapping),
        VMSTATE_QLIST_V(endpoint_list, VirtIOIOMMUDomain, 1,
                        vmstate_endpoint, VirtIOIOMMUEndpoint, next),
        VMSTATE_END_OF_LIST()
    }
};
795 
796 static gboolean reconstruct_endpoints(gpointer key, gpointer value,
797                                       gpointer data)
798 {
799     VirtIOIOMMU *s = (VirtIOIOMMU *)data;
800     VirtIOIOMMUDomain *d = (VirtIOIOMMUDomain *)value;
801     VirtIOIOMMUEndpoint *iter;
802 
803     QLIST_FOREACH(iter, &d->endpoint_list, next) {
804         iter->domain = d;
805         g_tree_insert(s->endpoints, GUINT_TO_POINTER(iter->id), iter);
806     }
807     return false; /* continue the domain traversal */
808 }
809 
810 static int iommu_post_load(void *opaque, int version_id)
811 {
812     VirtIOIOMMU *s = opaque;
813 
814     g_tree_foreach(s->domains, reconstruct_endpoints, s);
815     return 0;
816 }
817 
/*
 * Device-specific migration state: the domains tree, keyed directly by
 * domain ID. s->endpoints is not migrated; iommu_post_load() rebuilds it
 * from each domain's endpoint list.
 */
static const VMStateDescription vmstate_virtio_iommu_device = {
    .name = "virtio-iommu-device",
    .minimum_version_id = 1,
    .version_id = 1,
    .post_load = iommu_post_load,
    .fields = (VMStateField[]) {
        VMSTATE_GTREE_DIRECT_KEY_V(domains, VirtIOIOMMU, 1,
                                   &vmstate_domain, VirtIOIOMMUDomain),
        VMSTATE_END_OF_LIST()
    },
};

/*
 * Top-level migration description: the generic virtio device state. The
 * device-specific part above is registered separately as vdc->vmsd in
 * virtio_iommu_class_init().
 * NOTE(review): MIG_PRI_IOMMU presumably orders this state before that of
 * devices translating through the IOMMU; confirm against MigrationPriority.
 */
static const VMStateDescription vmstate_virtio_iommu = {
    .name = "virtio-iommu",
    .minimum_version_id = 1,
    .priority = MIG_PRI_IOMMU,
    .version_id = 1,
    .fields = (VMStateField[]) {
        VMSTATE_VIRTIO_DEVICE,
        VMSTATE_END_OF_LIST()
    },
};
840 
/*
 * "primary-bus": link to the PCI bus whose devices are placed behind this
 * IOMMU. Realize fails if it is not set.
 */
static Property virtio_iommu_properties[] = {
    DEFINE_PROP_LINK("primary-bus", VirtIOIOMMU, primary_bus, "PCI", PCIBus *),
    DEFINE_PROP_END_OF_LIST(),
};
845 
846 static void virtio_iommu_class_init(ObjectClass *klass, void *data)
847 {
848     DeviceClass *dc = DEVICE_CLASS(klass);
849     VirtioDeviceClass *vdc = VIRTIO_DEVICE_CLASS(klass);
850 
851     device_class_set_props(dc, virtio_iommu_properties);
852     dc->vmsd = &vmstate_virtio_iommu;
853 
854     set_bit(DEVICE_CATEGORY_MISC, dc->categories);
855     vdc->realize = virtio_iommu_device_realize;
856     vdc->unrealize = virtio_iommu_device_unrealize;
857     vdc->reset = virtio_iommu_device_reset;
858     vdc->get_config = virtio_iommu_get_config;
859     vdc->set_config = virtio_iommu_set_config;
860     vdc->get_features = virtio_iommu_get_features;
861     vdc->set_status = virtio_iommu_set_status;
862     vdc->vmsd = &vmstate_virtio_iommu_device;
863 }
864 
865 static void virtio_iommu_memory_region_class_init(ObjectClass *klass,
866                                                   void *data)
867 {
868     IOMMUMemoryRegionClass *imrc = IOMMU_MEMORY_REGION_CLASS(klass);
869 
870     imrc->translate = virtio_iommu_translate;
871 }
872 
/* QOM type of the virtio-iommu device itself. */
static const TypeInfo virtio_iommu_info = {
    .name = TYPE_VIRTIO_IOMMU,
    .parent = TYPE_VIRTIO_DEVICE,
    .instance_size = sizeof(VirtIOIOMMU),
    .instance_init = virtio_iommu_instance_init,
    .class_init = virtio_iommu_class_init,
};

/*
 * QOM type of the per-endpoint IOMMU memory regions created in
 * virtio_iommu_find_add_as().
 */
static const TypeInfo virtio_iommu_memory_region_info = {
    .parent = TYPE_IOMMU_MEMORY_REGION,
    .name = TYPE_VIRTIO_IOMMU_MEMORY_REGION,
    .class_init = virtio_iommu_memory_region_class_init,
};

/* Register both QOM types; hooked up by type_init() below. */
static void virtio_register_types(void)
{
    type_register_static(&virtio_iommu_info);
    type_register_static(&virtio_iommu_memory_region_info);
}

type_init(virtio_register_types)
894