xref: /openbmc/qemu/hw/virtio/virtio-iommu.c (revision 4c4465ff)
1 /*
2  * virtio-iommu device
3  *
4  * Copyright (c) 2020 Red Hat, Inc.
5  *
6  * This program is free software; you can redistribute it and/or modify it
7  * under the terms and conditions of the GNU General Public License,
8  * version 2 or later, as published by the Free Software Foundation.
9  *
10  * This program is distributed in the hope it will be useful, but WITHOUT
11  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
12  * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
13  * more details.
14  *
15  * You should have received a copy of the GNU General Public License along with
16  * this program.  If not, see <http://www.gnu.org/licenses/>.
17  *
18  */
19 
20 #include "qemu/osdep.h"
21 #include "qemu/log.h"
22 #include "qemu/iov.h"
23 #include "qemu-common.h"
24 #include "hw/qdev-properties.h"
25 #include "hw/virtio/virtio.h"
26 #include "sysemu/kvm.h"
27 #include "qapi/error.h"
28 #include "qemu/error-report.h"
29 #include "trace.h"
30 
31 #include "standard-headers/linux/virtio_ids.h"
32 
33 #include "hw/virtio/virtio-bus.h"
34 #include "hw/virtio/virtio-access.h"
35 #include "hw/virtio/virtio-iommu.h"
36 #include "hw/pci/pci_bus.h"
37 #include "hw/pci/pci.h"
38 
/*
 * Size limits used by the device.
 *
 * VIOMMU_DEFAULT_QUEUE_SIZE: presumably the virtqueue depth; its use is not
 * visible in this part of the file -- TODO confirm.
 * VIOMMU_PROBE_SIZE: number of bytes virtio_iommu_probe() considers available
 * for the properties returned by a PROBE request.
 */
#define VIOMMU_DEFAULT_QUEUE_SIZE 256
#define VIOMMU_PROBE_SIZE 512
42 
/*
 * A translation domain: a set of mappings shared by every endpoint attached
 * to it.
 */
typedef struct VirtIOIOMMUDomain {
    /* Domain ID, as carried by the guest requests */
    uint32_t id;
    /* Tree of VirtIOIOMMUInterval keys -> VirtIOIOMMUMapping values */
    GTree *mappings;
    /* Endpoints currently attached to this domain */
    QLIST_HEAD(, VirtIOIOMMUEndpoint) endpoint_list;
} VirtIOIOMMUDomain;
48 
/* An endpoint (device) known to the virtio-iommu, keyed by its endpoint ID. */
typedef struct VirtIOIOMMUEndpoint {
    /* Endpoint ID; also used to look up the device's IOMMU memory region */
    uint32_t id;
    /* Domain the endpoint is attached to, NULL when detached */
    VirtIOIOMMUDomain *domain;
    /* IOMMU memory region notified on map/unmap for this endpoint */
    IOMMUMemoryRegion *iommu_mr;
    /* Linkage in the owning domain's endpoint_list */
    QLIST_ENTRY(VirtIOIOMMUEndpoint) next;
} VirtIOIOMMUEndpoint;
55 
/*
 * Address interval used as key of a domain's mapping tree. Both bounds are
 * inclusive: interval_cmp() treats two intervals sharing a bound as
 * overlapping.
 */
typedef struct VirtIOIOMMUInterval {
    uint64_t low;   /* first address of the interval */
    uint64_t high;  /* last address of the interval */
} VirtIOIOMMUInterval;
60 
/* Value stored in a domain's mapping tree for a given interval. */
typedef struct VirtIOIOMMUMapping {
    /* Physical address the low bound of the interval translates to */
    uint64_t phys_addr;
    /* VIRTIO_IOMMU_MAP_F_* flags supplied with the MAP request */
    uint32_t flags;
} VirtIOIOMMUMapping;
65 
66 static inline uint16_t virtio_iommu_get_bdf(IOMMUDevice *dev)
67 {
68     return PCI_BUILD_BDF(pci_bus_num(dev->bus), dev->devfn);
69 }
70 
71 /**
72  * The bus number is used for lookup when SID based operations occur.
73  * In that case we lazily populate the IOMMUPciBus array from the bus hash
74  * table. At the time the IOMMUPciBus is created (iommu_find_add_as), the bus
75  * numbers may not be always initialized yet.
76  */
77 static IOMMUPciBus *iommu_find_iommu_pcibus(VirtIOIOMMU *s, uint8_t bus_num)
78 {
79     IOMMUPciBus *iommu_pci_bus = s->iommu_pcibus_by_bus_num[bus_num];
80 
81     if (!iommu_pci_bus) {
82         GHashTableIter iter;
83 
84         g_hash_table_iter_init(&iter, s->as_by_busptr);
85         while (g_hash_table_iter_next(&iter, NULL, (void **)&iommu_pci_bus)) {
86             if (pci_bus_num(iommu_pci_bus->bus) == bus_num) {
87                 s->iommu_pcibus_by_bus_num[bus_num] = iommu_pci_bus;
88                 return iommu_pci_bus;
89             }
90         }
91         return NULL;
92     }
93     return iommu_pci_bus;
94 }
95 
96 static IOMMUMemoryRegion *virtio_iommu_mr(VirtIOIOMMU *s, uint32_t sid)
97 {
98     uint8_t bus_n, devfn;
99     IOMMUPciBus *iommu_pci_bus;
100     IOMMUDevice *dev;
101 
102     bus_n = PCI_BUS_NUM(sid);
103     iommu_pci_bus = iommu_find_iommu_pcibus(s, bus_n);
104     if (iommu_pci_bus) {
105         devfn = sid & (PCI_DEVFN_MAX - 1);
106         dev = iommu_pci_bus->pbdev[devfn];
107         if (dev) {
108             return &dev->iommu_mr;
109         }
110     }
111     return NULL;
112 }
113 
114 static gint interval_cmp(gconstpointer a, gconstpointer b, gpointer user_data)
115 {
116     VirtIOIOMMUInterval *inta = (VirtIOIOMMUInterval *)a;
117     VirtIOIOMMUInterval *intb = (VirtIOIOMMUInterval *)b;
118 
119     if (inta->high < intb->low) {
120         return -1;
121     } else if (intb->high < inta->low) {
122         return 1;
123     } else {
124         return 0;
125     }
126 }
127 
128 static void virtio_iommu_notify_map(IOMMUMemoryRegion *mr, hwaddr virt_start,
129                                     hwaddr virt_end, hwaddr paddr,
130                                     uint32_t flags)
131 {
132     IOMMUTLBEvent event;
133     IOMMUAccessFlags perm = IOMMU_ACCESS_FLAG(flags & VIRTIO_IOMMU_MAP_F_READ,
134                                               flags & VIRTIO_IOMMU_MAP_F_WRITE);
135 
136     if (!(mr->iommu_notify_flags & IOMMU_NOTIFIER_MAP) ||
137         (flags & VIRTIO_IOMMU_MAP_F_MMIO) || !perm) {
138         return;
139     }
140 
141     trace_virtio_iommu_notify_map(mr->parent_obj.name, virt_start, virt_end,
142                                   paddr, perm);
143 
144     event.type = IOMMU_NOTIFIER_MAP;
145     event.entry.target_as = &address_space_memory;
146     event.entry.addr_mask = virt_end - virt_start;
147     event.entry.iova = virt_start;
148     event.entry.perm = perm;
149     event.entry.translated_addr = paddr;
150 
151     memory_region_notify_iommu(mr, 0, event);
152 }
153 
154 static void virtio_iommu_notify_unmap(IOMMUMemoryRegion *mr, hwaddr virt_start,
155                                       hwaddr virt_end)
156 {
157     IOMMUTLBEvent event;
158 
159     if (!(mr->iommu_notify_flags & IOMMU_NOTIFIER_UNMAP)) {
160         return;
161     }
162 
163     trace_virtio_iommu_notify_unmap(mr->parent_obj.name, virt_start, virt_end);
164 
165     event.type = IOMMU_NOTIFIER_UNMAP;
166     event.entry.target_as = &address_space_memory;
167     event.entry.addr_mask = virt_end - virt_start;
168     event.entry.iova = virt_start;
169     event.entry.perm = IOMMU_NONE;
170     event.entry.translated_addr = 0;
171 
172     memory_region_notify_iommu(mr, 0, event);
173 }
174 
175 static gboolean virtio_iommu_notify_unmap_cb(gpointer key, gpointer value,
176                                              gpointer data)
177 {
178     VirtIOIOMMUInterval *interval = (VirtIOIOMMUInterval *) key;
179     IOMMUMemoryRegion *mr = (IOMMUMemoryRegion *) data;
180 
181     virtio_iommu_notify_unmap(mr, interval->low, interval->high);
182 
183     return false;
184 }
185 
186 static gboolean virtio_iommu_notify_map_cb(gpointer key, gpointer value,
187                                            gpointer data)
188 {
189     VirtIOIOMMUMapping *mapping = (VirtIOIOMMUMapping *) value;
190     VirtIOIOMMUInterval *interval = (VirtIOIOMMUInterval *) key;
191     IOMMUMemoryRegion *mr = (IOMMUMemoryRegion *) data;
192 
193     virtio_iommu_notify_map(mr, interval->low, interval->high,
194                             mapping->phys_addr, mapping->flags);
195 
196     return false;
197 }
198 
199 static void virtio_iommu_detach_endpoint_from_domain(VirtIOIOMMUEndpoint *ep)
200 {
201     VirtIOIOMMUDomain *domain = ep->domain;
202 
203     if (!ep->domain) {
204         return;
205     }
206     g_tree_foreach(domain->mappings, virtio_iommu_notify_unmap_cb,
207                    ep->iommu_mr);
208     QLIST_REMOVE(ep, next);
209     ep->domain = NULL;
210 }
211 
212 static VirtIOIOMMUEndpoint *virtio_iommu_get_endpoint(VirtIOIOMMU *s,
213                                                       uint32_t ep_id)
214 {
215     VirtIOIOMMUEndpoint *ep;
216     IOMMUMemoryRegion *mr;
217 
218     ep = g_tree_lookup(s->endpoints, GUINT_TO_POINTER(ep_id));
219     if (ep) {
220         return ep;
221     }
222     mr = virtio_iommu_mr(s, ep_id);
223     if (!mr) {
224         return NULL;
225     }
226     ep = g_malloc0(sizeof(*ep));
227     ep->id = ep_id;
228     ep->iommu_mr = mr;
229     trace_virtio_iommu_get_endpoint(ep_id);
230     g_tree_insert(s->endpoints, GUINT_TO_POINTER(ep_id), ep);
231     return ep;
232 }
233 
234 static void virtio_iommu_put_endpoint(gpointer data)
235 {
236     VirtIOIOMMUEndpoint *ep = (VirtIOIOMMUEndpoint *)data;
237 
238     if (ep->domain) {
239         virtio_iommu_detach_endpoint_from_domain(ep);
240     }
241 
242     trace_virtio_iommu_put_endpoint(ep->id);
243     g_free(ep);
244 }
245 
246 static VirtIOIOMMUDomain *virtio_iommu_get_domain(VirtIOIOMMU *s,
247                                                   uint32_t domain_id)
248 {
249     VirtIOIOMMUDomain *domain;
250 
251     domain = g_tree_lookup(s->domains, GUINT_TO_POINTER(domain_id));
252     if (domain) {
253         return domain;
254     }
255     domain = g_malloc0(sizeof(*domain));
256     domain->id = domain_id;
257     domain->mappings = g_tree_new_full((GCompareDataFunc)interval_cmp,
258                                    NULL, (GDestroyNotify)g_free,
259                                    (GDestroyNotify)g_free);
260     g_tree_insert(s->domains, GUINT_TO_POINTER(domain_id), domain);
261     QLIST_INIT(&domain->endpoint_list);
262     trace_virtio_iommu_get_domain(domain_id);
263     return domain;
264 }
265 
266 static void virtio_iommu_put_domain(gpointer data)
267 {
268     VirtIOIOMMUDomain *domain = (VirtIOIOMMUDomain *)data;
269     VirtIOIOMMUEndpoint *iter, *tmp;
270 
271     QLIST_FOREACH_SAFE(iter, &domain->endpoint_list, next, tmp) {
272         virtio_iommu_detach_endpoint_from_domain(iter);
273     }
274     g_tree_destroy(domain->mappings);
275     trace_virtio_iommu_put_domain(domain->id);
276     g_free(domain);
277 }
278 
279 static AddressSpace *virtio_iommu_find_add_as(PCIBus *bus, void *opaque,
280                                               int devfn)
281 {
282     VirtIOIOMMU *s = opaque;
283     IOMMUPciBus *sbus = g_hash_table_lookup(s->as_by_busptr, bus);
284     static uint32_t mr_index;
285     IOMMUDevice *sdev;
286 
287     if (!sbus) {
288         sbus = g_malloc0(sizeof(IOMMUPciBus) +
289                          sizeof(IOMMUDevice *) * PCI_DEVFN_MAX);
290         sbus->bus = bus;
291         g_hash_table_insert(s->as_by_busptr, bus, sbus);
292     }
293 
294     sdev = sbus->pbdev[devfn];
295     if (!sdev) {
296         char *name = g_strdup_printf("%s-%d-%d",
297                                      TYPE_VIRTIO_IOMMU_MEMORY_REGION,
298                                      mr_index++, devfn);
299         sdev = sbus->pbdev[devfn] = g_malloc0(sizeof(IOMMUDevice));
300 
301         sdev->viommu = s;
302         sdev->bus = bus;
303         sdev->devfn = devfn;
304 
305         trace_virtio_iommu_init_iommu_mr(name);
306 
307         memory_region_init_iommu(&sdev->iommu_mr, sizeof(sdev->iommu_mr),
308                                  TYPE_VIRTIO_IOMMU_MEMORY_REGION,
309                                  OBJECT(s), name,
310                                  UINT64_MAX);
311         address_space_init(&sdev->as,
312                            MEMORY_REGION(&sdev->iommu_mr), TYPE_VIRTIO_IOMMU);
313         g_free(name);
314     }
315     return &sdev->as;
316 }
317 
318 static int virtio_iommu_attach(VirtIOIOMMU *s,
319                                struct virtio_iommu_req_attach *req)
320 {
321     uint32_t domain_id = le32_to_cpu(req->domain);
322     uint32_t ep_id = le32_to_cpu(req->endpoint);
323     VirtIOIOMMUDomain *domain;
324     VirtIOIOMMUEndpoint *ep;
325 
326     trace_virtio_iommu_attach(domain_id, ep_id);
327 
328     ep = virtio_iommu_get_endpoint(s, ep_id);
329     if (!ep) {
330         return VIRTIO_IOMMU_S_NOENT;
331     }
332 
333     if (ep->domain) {
334         VirtIOIOMMUDomain *previous_domain = ep->domain;
335         /*
336          * the device is already attached to a domain,
337          * detach it first
338          */
339         virtio_iommu_detach_endpoint_from_domain(ep);
340         if (QLIST_EMPTY(&previous_domain->endpoint_list)) {
341             g_tree_remove(s->domains, GUINT_TO_POINTER(previous_domain->id));
342         }
343     }
344 
345     domain = virtio_iommu_get_domain(s, domain_id);
346     QLIST_INSERT_HEAD(&domain->endpoint_list, ep, next);
347 
348     ep->domain = domain;
349 
350     /* Replay domain mappings on the associated memory region */
351     g_tree_foreach(domain->mappings, virtio_iommu_notify_map_cb,
352                    ep->iommu_mr);
353 
354     return VIRTIO_IOMMU_S_OK;
355 }
356 
357 static int virtio_iommu_detach(VirtIOIOMMU *s,
358                                struct virtio_iommu_req_detach *req)
359 {
360     uint32_t domain_id = le32_to_cpu(req->domain);
361     uint32_t ep_id = le32_to_cpu(req->endpoint);
362     VirtIOIOMMUDomain *domain;
363     VirtIOIOMMUEndpoint *ep;
364 
365     trace_virtio_iommu_detach(domain_id, ep_id);
366 
367     ep = g_tree_lookup(s->endpoints, GUINT_TO_POINTER(ep_id));
368     if (!ep) {
369         return VIRTIO_IOMMU_S_NOENT;
370     }
371 
372     domain = ep->domain;
373 
374     if (!domain || domain->id != domain_id) {
375         return VIRTIO_IOMMU_S_INVAL;
376     }
377 
378     virtio_iommu_detach_endpoint_from_domain(ep);
379 
380     if (QLIST_EMPTY(&domain->endpoint_list)) {
381         g_tree_remove(s->domains, GUINT_TO_POINTER(domain->id));
382     }
383     return VIRTIO_IOMMU_S_OK;
384 }
385 
386 static int virtio_iommu_map(VirtIOIOMMU *s,
387                             struct virtio_iommu_req_map *req)
388 {
389     uint32_t domain_id = le32_to_cpu(req->domain);
390     uint64_t phys_start = le64_to_cpu(req->phys_start);
391     uint64_t virt_start = le64_to_cpu(req->virt_start);
392     uint64_t virt_end = le64_to_cpu(req->virt_end);
393     uint32_t flags = le32_to_cpu(req->flags);
394     VirtIOIOMMUDomain *domain;
395     VirtIOIOMMUInterval *interval;
396     VirtIOIOMMUMapping *mapping;
397     VirtIOIOMMUEndpoint *ep;
398 
399     if (flags & ~VIRTIO_IOMMU_MAP_F_MASK) {
400         return VIRTIO_IOMMU_S_INVAL;
401     }
402 
403     domain = g_tree_lookup(s->domains, GUINT_TO_POINTER(domain_id));
404     if (!domain) {
405         return VIRTIO_IOMMU_S_NOENT;
406     }
407 
408     interval = g_malloc0(sizeof(*interval));
409 
410     interval->low = virt_start;
411     interval->high = virt_end;
412 
413     mapping = g_tree_lookup(domain->mappings, (gpointer)interval);
414     if (mapping) {
415         g_free(interval);
416         return VIRTIO_IOMMU_S_INVAL;
417     }
418 
419     trace_virtio_iommu_map(domain_id, virt_start, virt_end, phys_start, flags);
420 
421     mapping = g_malloc0(sizeof(*mapping));
422     mapping->phys_addr = phys_start;
423     mapping->flags = flags;
424 
425     g_tree_insert(domain->mappings, interval, mapping);
426 
427     QLIST_FOREACH(ep, &domain->endpoint_list, next) {
428         virtio_iommu_notify_map(ep->iommu_mr, virt_start, virt_end, phys_start,
429                                 flags);
430     }
431 
432     return VIRTIO_IOMMU_S_OK;
433 }
434 
435 static int virtio_iommu_unmap(VirtIOIOMMU *s,
436                               struct virtio_iommu_req_unmap *req)
437 {
438     uint32_t domain_id = le32_to_cpu(req->domain);
439     uint64_t virt_start = le64_to_cpu(req->virt_start);
440     uint64_t virt_end = le64_to_cpu(req->virt_end);
441     VirtIOIOMMUMapping *iter_val;
442     VirtIOIOMMUInterval interval, *iter_key;
443     VirtIOIOMMUDomain *domain;
444     VirtIOIOMMUEndpoint *ep;
445     int ret = VIRTIO_IOMMU_S_OK;
446 
447     trace_virtio_iommu_unmap(domain_id, virt_start, virt_end);
448 
449     domain = g_tree_lookup(s->domains, GUINT_TO_POINTER(domain_id));
450     if (!domain) {
451         return VIRTIO_IOMMU_S_NOENT;
452     }
453     interval.low = virt_start;
454     interval.high = virt_end;
455 
456     while (g_tree_lookup_extended(domain->mappings, &interval,
457                                   (void **)&iter_key, (void**)&iter_val)) {
458         uint64_t current_low = iter_key->low;
459         uint64_t current_high = iter_key->high;
460 
461         if (interval.low <= current_low && interval.high >= current_high) {
462             QLIST_FOREACH(ep, &domain->endpoint_list, next) {
463                 virtio_iommu_notify_unmap(ep->iommu_mr, current_low,
464                                           current_high);
465             }
466             g_tree_remove(domain->mappings, iter_key);
467             trace_virtio_iommu_unmap_done(domain_id, current_low, current_high);
468         } else {
469             ret = VIRTIO_IOMMU_S_RANGE;
470             break;
471         }
472     }
473     return ret;
474 }
475 
476 static ssize_t virtio_iommu_fill_resv_mem_prop(VirtIOIOMMU *s, uint32_t ep,
477                                                uint8_t *buf, size_t free)
478 {
479     struct virtio_iommu_probe_resv_mem prop = {};
480     size_t size = sizeof(prop), length = size - sizeof(prop.head), total;
481     int i;
482 
483     total = size * s->nb_reserved_regions;
484 
485     if (total > free) {
486         return -ENOSPC;
487     }
488 
489     for (i = 0; i < s->nb_reserved_regions; i++) {
490         unsigned subtype = s->reserved_regions[i].type;
491 
492         assert(subtype == VIRTIO_IOMMU_RESV_MEM_T_RESERVED ||
493                subtype == VIRTIO_IOMMU_RESV_MEM_T_MSI);
494         prop.head.type = cpu_to_le16(VIRTIO_IOMMU_PROBE_T_RESV_MEM);
495         prop.head.length = cpu_to_le16(length);
496         prop.subtype = subtype;
497         prop.start = cpu_to_le64(s->reserved_regions[i].low);
498         prop.end = cpu_to_le64(s->reserved_regions[i].high);
499 
500         memcpy(buf, &prop, size);
501 
502         trace_virtio_iommu_fill_resv_property(ep, prop.subtype,
503                                               prop.start, prop.end);
504         buf += size;
505     }
506     return total;
507 }
508 
509 /**
510  * virtio_iommu_probe - Fill the probe request buffer with
511  * the properties the device is able to return
512  */
513 static int virtio_iommu_probe(VirtIOIOMMU *s,
514                               struct virtio_iommu_req_probe *req,
515                               uint8_t *buf)
516 {
517     uint32_t ep_id = le32_to_cpu(req->endpoint);
518     size_t free = VIOMMU_PROBE_SIZE;
519     ssize_t count;
520 
521     if (!virtio_iommu_mr(s, ep_id)) {
522         return VIRTIO_IOMMU_S_NOENT;
523     }
524 
525     count = virtio_iommu_fill_resv_mem_prop(s, ep_id, buf, free);
526     if (count < 0) {
527         return VIRTIO_IOMMU_S_INVAL;
528     }
529     buf += count;
530     free -= count;
531 
532     return VIRTIO_IOMMU_S_OK;
533 }
534 
535 static int virtio_iommu_iov_to_req(struct iovec *iov,
536                                    unsigned int iov_cnt,
537                                    void *req, size_t req_sz)
538 {
539     size_t sz, payload_sz = req_sz - sizeof(struct virtio_iommu_req_tail);
540 
541     sz = iov_to_buf(iov, iov_cnt, 0, req, payload_sz);
542     if (unlikely(sz != payload_sz)) {
543         return VIRTIO_IOMMU_S_INVAL;
544     }
545     return 0;
546 }
547 
/*
 * Generate virtio_iommu_handle_<req>(): copy the request out of the iovec
 * into a struct virtio_iommu_req_<req> with virtio_iommu_iov_to_req() and,
 * when the copy succeeds, dispatch it to virtio_iommu_<req>(). A non-zero
 * status from the copy is returned as is.
 */
#define virtio_iommu_handle_req(__req)                                  \
static int virtio_iommu_handle_ ## __req(VirtIOIOMMU *s,                \
                                         struct iovec *iov,             \
                                         unsigned int iov_cnt)          \
{                                                                       \
    struct virtio_iommu_req_ ## __req req;                              \
    int ret = virtio_iommu_iov_to_req(iov, iov_cnt, &req, sizeof(req)); \
                                                                        \
    return ret ? ret : virtio_iommu_ ## __req(s, &req);                 \
}

/* Handlers for the requests whose structure ends with a tail */
virtio_iommu_handle_req(attach)
virtio_iommu_handle_req(detach)
virtio_iommu_handle_req(map)
virtio_iommu_handle_req(unmap)
563 
564 static int virtio_iommu_handle_probe(VirtIOIOMMU *s,
565                                      struct iovec *iov,
566                                      unsigned int iov_cnt,
567                                      uint8_t *buf)
568 {
569     struct virtio_iommu_req_probe req;
570     int ret = virtio_iommu_iov_to_req(iov, iov_cnt, &req, sizeof(req));
571 
572     return ret ? ret : virtio_iommu_probe(s, &req, buf);
573 }
574 
575 static void virtio_iommu_handle_command(VirtIODevice *vdev, VirtQueue *vq)
576 {
577     VirtIOIOMMU *s = VIRTIO_IOMMU(vdev);
578     struct virtio_iommu_req_head head;
579     struct virtio_iommu_req_tail tail = {};
580     size_t output_size = sizeof(tail), sz;
581     VirtQueueElement *elem;
582     unsigned int iov_cnt;
583     struct iovec *iov;
584     void *buf = NULL;
585 
586     for (;;) {
587         elem = virtqueue_pop(vq, sizeof(VirtQueueElement));
588         if (!elem) {
589             return;
590         }
591 
592         if (iov_size(elem->in_sg, elem->in_num) < sizeof(tail) ||
593             iov_size(elem->out_sg, elem->out_num) < sizeof(head)) {
594             virtio_error(vdev, "virtio-iommu bad head/tail size");
595             virtqueue_detach_element(vq, elem, 0);
596             g_free(elem);
597             break;
598         }
599 
600         iov_cnt = elem->out_num;
601         iov = elem->out_sg;
602         sz = iov_to_buf(iov, iov_cnt, 0, &head, sizeof(head));
603         if (unlikely(sz != sizeof(head))) {
604             tail.status = VIRTIO_IOMMU_S_DEVERR;
605             goto out;
606         }
607         qemu_mutex_lock(&s->mutex);
608         switch (head.type) {
609         case VIRTIO_IOMMU_T_ATTACH:
610             tail.status = virtio_iommu_handle_attach(s, iov, iov_cnt);
611             break;
612         case VIRTIO_IOMMU_T_DETACH:
613             tail.status = virtio_iommu_handle_detach(s, iov, iov_cnt);
614             break;
615         case VIRTIO_IOMMU_T_MAP:
616             tail.status = virtio_iommu_handle_map(s, iov, iov_cnt);
617             break;
618         case VIRTIO_IOMMU_T_UNMAP:
619             tail.status = virtio_iommu_handle_unmap(s, iov, iov_cnt);
620             break;
621         case VIRTIO_IOMMU_T_PROBE:
622         {
623             struct virtio_iommu_req_tail *ptail;
624 
625             output_size = s->config.probe_size + sizeof(tail);
626             buf = g_malloc0(output_size);
627 
628             ptail = (struct virtio_iommu_req_tail *)
629                         (buf + s->config.probe_size);
630             ptail->status = virtio_iommu_handle_probe(s, iov, iov_cnt, buf);
631             break;
632         }
633         default:
634             tail.status = VIRTIO_IOMMU_S_UNSUPP;
635         }
636         qemu_mutex_unlock(&s->mutex);
637 
638 out:
639         sz = iov_from_buf(elem->in_sg, elem->in_num, 0,
640                           buf ? buf : &tail, output_size);
641         assert(sz == output_size);
642 
643         virtqueue_push(vq, elem, sz);
644         virtio_notify(vdev, vq);
645         g_free(elem);
646         g_free(buf);
647     }
648 }
649 
650 static void virtio_iommu_report_fault(VirtIOIOMMU *viommu, uint8_t reason,
651                                       int flags, uint32_t endpoint,
652                                       uint64_t address)
653 {
654     VirtIODevice *vdev = &viommu->parent_obj;
655     VirtQueue *vq = viommu->event_vq;
656     struct virtio_iommu_fault fault;
657     VirtQueueElement *elem;
658     size_t sz;
659 
660     memset(&fault, 0, sizeof(fault));
661     fault.reason = reason;
662     fault.flags = cpu_to_le32(flags);
663     fault.endpoint = cpu_to_le32(endpoint);
664     fault.address = cpu_to_le64(address);
665 
666     elem = virtqueue_pop(vq, sizeof(VirtQueueElement));
667 
668     if (!elem) {
669         error_report_once(
670             "no buffer available in event queue to report event");
671         return;
672     }
673 
674     if (iov_size(elem->in_sg, elem->in_num) < sizeof(fault)) {
675         virtio_error(vdev, "error buffer of wrong size");
676         virtqueue_detach_element(vq, elem, 0);
677         g_free(elem);
678         return;
679     }
680 
681     sz = iov_from_buf(elem->in_sg, elem->in_num, 0,
682                       &fault, sizeof(fault));
683     assert(sz == sizeof(fault));
684 
685     trace_virtio_iommu_report_fault(reason, flags, endpoint, address);
686     virtqueue_push(vq, elem, sz);
687     virtio_notify(vdev, vq);
688     g_free(elem);
689 
690 }
691 
692 static IOMMUTLBEntry virtio_iommu_translate(IOMMUMemoryRegion *mr, hwaddr addr,
693                                             IOMMUAccessFlags flag,
694                                             int iommu_idx)
695 {
696     IOMMUDevice *sdev = container_of(mr, IOMMUDevice, iommu_mr);
697     VirtIOIOMMUInterval interval, *mapping_key;
698     VirtIOIOMMUMapping *mapping_value;
699     VirtIOIOMMU *s = sdev->viommu;
700     bool read_fault, write_fault;
701     VirtIOIOMMUEndpoint *ep;
702     uint32_t sid, flags;
703     bool bypass_allowed;
704     bool found;
705     int i;
706 
707     interval.low = addr;
708     interval.high = addr + 1;
709 
710     IOMMUTLBEntry entry = {
711         .target_as = &address_space_memory,
712         .iova = addr,
713         .translated_addr = addr,
714         .addr_mask = (1 << ctz32(s->config.page_size_mask)) - 1,
715         .perm = IOMMU_NONE,
716     };
717 
718     bypass_allowed = virtio_vdev_has_feature(&s->parent_obj,
719                                              VIRTIO_IOMMU_F_BYPASS);
720 
721     sid = virtio_iommu_get_bdf(sdev);
722 
723     trace_virtio_iommu_translate(mr->parent_obj.name, sid, addr, flag);
724     qemu_mutex_lock(&s->mutex);
725 
726     ep = g_tree_lookup(s->endpoints, GUINT_TO_POINTER(sid));
727     if (!ep) {
728         if (!bypass_allowed) {
729             error_report_once("%s sid=%d is not known!!", __func__, sid);
730             virtio_iommu_report_fault(s, VIRTIO_IOMMU_FAULT_R_UNKNOWN,
731                                       VIRTIO_IOMMU_FAULT_F_ADDRESS,
732                                       sid, addr);
733         } else {
734             entry.perm = flag;
735         }
736         goto unlock;
737     }
738 
739     for (i = 0; i < s->nb_reserved_regions; i++) {
740         ReservedRegion *reg = &s->reserved_regions[i];
741 
742         if (addr >= reg->low && addr <= reg->high) {
743             switch (reg->type) {
744             case VIRTIO_IOMMU_RESV_MEM_T_MSI:
745                 entry.perm = flag;
746                 break;
747             case VIRTIO_IOMMU_RESV_MEM_T_RESERVED:
748             default:
749                 virtio_iommu_report_fault(s, VIRTIO_IOMMU_FAULT_R_MAPPING,
750                                           VIRTIO_IOMMU_FAULT_F_ADDRESS,
751                                           sid, addr);
752                 break;
753             }
754             goto unlock;
755         }
756     }
757 
758     if (!ep->domain) {
759         if (!bypass_allowed) {
760             error_report_once("%s %02x:%02x.%01x not attached to any domain",
761                               __func__, PCI_BUS_NUM(sid),
762                               PCI_SLOT(sid), PCI_FUNC(sid));
763             virtio_iommu_report_fault(s, VIRTIO_IOMMU_FAULT_R_DOMAIN,
764                                       VIRTIO_IOMMU_FAULT_F_ADDRESS,
765                                       sid, addr);
766         } else {
767             entry.perm = flag;
768         }
769         goto unlock;
770     }
771 
772     found = g_tree_lookup_extended(ep->domain->mappings, (gpointer)(&interval),
773                                    (void **)&mapping_key,
774                                    (void **)&mapping_value);
775     if (!found) {
776         error_report_once("%s no mapping for 0x%"PRIx64" for sid=%d",
777                           __func__, addr, sid);
778         virtio_iommu_report_fault(s, VIRTIO_IOMMU_FAULT_R_MAPPING,
779                                   VIRTIO_IOMMU_FAULT_F_ADDRESS,
780                                   sid, addr);
781         goto unlock;
782     }
783 
784     read_fault = (flag & IOMMU_RO) &&
785                     !(mapping_value->flags & VIRTIO_IOMMU_MAP_F_READ);
786     write_fault = (flag & IOMMU_WO) &&
787                     !(mapping_value->flags & VIRTIO_IOMMU_MAP_F_WRITE);
788 
789     flags = read_fault ? VIRTIO_IOMMU_FAULT_F_READ : 0;
790     flags |= write_fault ? VIRTIO_IOMMU_FAULT_F_WRITE : 0;
791     if (flags) {
792         error_report_once("%s permission error on 0x%"PRIx64"(%d): allowed=%d",
793                           __func__, addr, flag, mapping_value->flags);
794         flags |= VIRTIO_IOMMU_FAULT_F_ADDRESS;
795         virtio_iommu_report_fault(s, VIRTIO_IOMMU_FAULT_R_MAPPING,
796                                   flags | VIRTIO_IOMMU_FAULT_F_ADDRESS,
797                                   sid, addr);
798         goto unlock;
799     }
800     entry.translated_addr = addr - mapping_key->low + mapping_value->phys_addr;
801     entry.perm = flag;
802     trace_virtio_iommu_translate_out(addr, entry.translated_addr, sid);
803 
804 unlock:
805     qemu_mutex_unlock(&s->mutex);
806     return entry;
807 }
808 
809 static void virtio_iommu_get_config(VirtIODevice *vdev, uint8_t *config_data)
810 {
811     VirtIOIOMMU *dev = VIRTIO_IOMMU(vdev);
812     struct virtio_iommu_config *config = &dev->config;
813 
814     trace_virtio_iommu_get_config(config->page_size_mask,
815                                   config->input_range.start,
816                                   config->input_range.end,
817                                   config->domain_range.end,
818                                   config->probe_size);
819     memcpy(config_data, &dev->config, sizeof(struct virtio_iommu_config));
820 }
821 
822 static void virtio_iommu_set_config(VirtIODevice *vdev,
823                                       const uint8_t *config_data)
824 {
825     struct virtio_iommu_config config;
826 
827     memcpy(&config, config_data, sizeof(struct virtio_iommu_config));
828     trace_virtio_iommu_set_config(config.page_size_mask,
829                                   config.input_range.start,
830                                   config.input_range.end,
831                                   config.domain_range.end,
832                                   config.probe_size);
833 }
834 
835 static uint64_t virtio_iommu_get_features(VirtIODevice *vdev, uint64_t f,
836                                           Error **errp)
837 {
838     VirtIOIOMMU *dev = VIRTIO_IOMMU(vdev);
839 
840     f |= dev->features;
841     trace_virtio_iommu_get_features(f);
842     return f;
843 }
844 
845 static gint int_cmp(gconstpointer a, gconstpointer b, gpointer user_data)
846 {
847     guint ua = GPOINTER_TO_UINT(a);
848     guint ub = GPOINTER_TO_UINT(b);
849     return (ua > ub) - (ua < ub);
850 }
851 
852 static gboolean virtio_iommu_remap(gpointer key, gpointer value, gpointer data)
853 {
854     VirtIOIOMMUMapping *mapping = (VirtIOIOMMUMapping *) value;
855     VirtIOIOMMUInterval *interval = (VirtIOIOMMUInterval *) key;
856     IOMMUMemoryRegion *mr = (IOMMUMemoryRegion *) data;
857 
858     trace_virtio_iommu_remap(mr->parent_obj.name, interval->low, interval->high,
859                              mapping->phys_addr);
860     virtio_iommu_notify_map(mr, interval->low, interval->high,
861                             mapping->phys_addr, mapping->flags);
862     return false;
863 }
864 
865 static void virtio_iommu_replay(IOMMUMemoryRegion *mr, IOMMUNotifier *n)
866 {
867     IOMMUDevice *sdev = container_of(mr, IOMMUDevice, iommu_mr);
868     VirtIOIOMMU *s = sdev->viommu;
869     uint32_t sid;
870     VirtIOIOMMUEndpoint *ep;
871 
872     sid = virtio_iommu_get_bdf(sdev);
873 
874     qemu_mutex_lock(&s->mutex);
875 
876     if (!s->endpoints) {
877         goto unlock;
878     }
879 
880     ep = g_tree_lookup(s->endpoints, GUINT_TO_POINTER(sid));
881     if (!ep || !ep->domain) {
882         goto unlock;
883     }
884 
885     g_tree_foreach(ep->domain->mappings, virtio_iommu_remap, mr);
886 
887 unlock:
888     qemu_mutex_unlock(&s->mutex);
889 }
890 
891 static int virtio_iommu_notify_flag_changed(IOMMUMemoryRegion *iommu_mr,
892                                             IOMMUNotifierFlag old,
893                                             IOMMUNotifierFlag new,
894                                             Error **errp)
895 {
896     if (old == IOMMU_NOTIFIER_NONE) {
897         trace_virtio_iommu_notify_flag_add(iommu_mr->parent_obj.name);
898     } else if (new == IOMMU_NOTIFIER_NONE) {
899         trace_virtio_iommu_notify_flag_del(iommu_mr->parent_obj.name);
900     }
901     return 0;
902 }
903 
904 /*
905  * The default mask (TARGET_PAGE_MASK) is the smallest supported guest granule,
906  * for example 0xfffffffffffff000. When an assigned device has page size
907  * restrictions due to the hardware IOMMU configuration, apply this restriction
908  * to the mask.
909  */
910 static int virtio_iommu_set_page_size_mask(IOMMUMemoryRegion *mr,
911                                            uint64_t new_mask,
912                                            Error **errp)
913 {
914     IOMMUDevice *sdev = container_of(mr, IOMMUDevice, iommu_mr);
915     VirtIOIOMMU *s = sdev->viommu;
916     uint64_t cur_mask = s->config.page_size_mask;
917 
918     trace_virtio_iommu_set_page_size_mask(mr->parent_obj.name, cur_mask,
919                                           new_mask);
920 
921     if ((cur_mask & new_mask) == 0) {
922         error_setg(errp, "virtio-iommu page mask 0x%"PRIx64
923                    " is incompatible with mask 0x%"PRIx64, cur_mask, new_mask);
924         return -1;
925     }
926 
927     /*
928      * After the machine is finalized, we can't change the mask anymore. If by
929      * chance the hotplugged device supports the same granule, we can still
930      * accept it. Having a different masks is possible but the guest will use
931      * sub-optimal block sizes, so warn about it.
932      */
933     if (phase_check(PHASE_MACHINE_READY)) {
934         int new_granule = ctz64(new_mask);
935         int cur_granule = ctz64(cur_mask);
936 
937         if (new_granule != cur_granule) {
938             error_setg(errp, "virtio-iommu page mask 0x%"PRIx64
939                        " is incompatible with mask 0x%"PRIx64, cur_mask,
940                        new_mask);
941             return -1;
942         } else if (new_mask != cur_mask) {
943             warn_report("virtio-iommu page mask 0x%"PRIx64
944                         " does not match 0x%"PRIx64, cur_mask, new_mask);
945         }
946         return 0;
947     }
948 
949     s->config.page_size_mask &= new_mask;
950     return 0;
951 }
952 
953 static void virtio_iommu_device_realize(DeviceState *dev, Error **errp)
954 {
955     VirtIODevice *vdev = VIRTIO_DEVICE(dev);
956     VirtIOIOMMU *s = VIRTIO_IOMMU(dev);
957 
958     virtio_init(vdev, "virtio-iommu", VIRTIO_ID_IOMMU,
959                 sizeof(struct virtio_iommu_config));
960 
961     memset(s->iommu_pcibus_by_bus_num, 0, sizeof(s->iommu_pcibus_by_bus_num));
962 
963     s->req_vq = virtio_add_queue(vdev, VIOMMU_DEFAULT_QUEUE_SIZE,
964                              virtio_iommu_handle_command);
965     s->event_vq = virtio_add_queue(vdev, VIOMMU_DEFAULT_QUEUE_SIZE, NULL);
966 
967     s->config.page_size_mask = TARGET_PAGE_MASK;
968     s->config.input_range.end = -1UL;
969     s->config.domain_range.end = 32;
970     s->config.probe_size = VIOMMU_PROBE_SIZE;
971 
972     virtio_add_feature(&s->features, VIRTIO_RING_F_EVENT_IDX);
973     virtio_add_feature(&s->features, VIRTIO_RING_F_INDIRECT_DESC);
974     virtio_add_feature(&s->features, VIRTIO_F_VERSION_1);
975     virtio_add_feature(&s->features, VIRTIO_IOMMU_F_INPUT_RANGE);
976     virtio_add_feature(&s->features, VIRTIO_IOMMU_F_DOMAIN_RANGE);
977     virtio_add_feature(&s->features, VIRTIO_IOMMU_F_MAP_UNMAP);
978     virtio_add_feature(&s->features, VIRTIO_IOMMU_F_BYPASS);
979     virtio_add_feature(&s->features, VIRTIO_IOMMU_F_MMIO);
980     virtio_add_feature(&s->features, VIRTIO_IOMMU_F_PROBE);
981 
982     qemu_mutex_init(&s->mutex);
983 
984     s->as_by_busptr = g_hash_table_new_full(NULL, NULL, NULL, g_free);
985 
986     if (s->primary_bus) {
987         pci_setup_iommu(s->primary_bus, virtio_iommu_find_add_as, s);
988     } else {
989         error_setg(errp, "VIRTIO-IOMMU is not attached to any PCI bus!");
990     }
991 }
992 
993 static void virtio_iommu_device_unrealize(DeviceState *dev)
994 {
995     VirtIODevice *vdev = VIRTIO_DEVICE(dev);
996     VirtIOIOMMU *s = VIRTIO_IOMMU(dev);
997 
998     g_hash_table_destroy(s->as_by_busptr);
999     if (s->domains) {
1000         g_tree_destroy(s->domains);
1001     }
1002     if (s->endpoints) {
1003         g_tree_destroy(s->endpoints);
1004     }
1005 
1006     virtio_delete_queue(s->req_vq);
1007     virtio_delete_queue(s->event_vq);
1008     virtio_cleanup(vdev);
1009 }
1010 
1011 static void virtio_iommu_device_reset(VirtIODevice *vdev)
1012 {
1013     VirtIOIOMMU *s = VIRTIO_IOMMU(vdev);
1014 
1015     trace_virtio_iommu_device_reset();
1016 
1017     if (s->domains) {
1018         g_tree_destroy(s->domains);
1019     }
1020     if (s->endpoints) {
1021         g_tree_destroy(s->endpoints);
1022     }
1023     s->domains = g_tree_new_full((GCompareDataFunc)int_cmp,
1024                                  NULL, NULL, virtio_iommu_put_domain);
1025     s->endpoints = g_tree_new_full((GCompareDataFunc)int_cmp,
1026                                    NULL, NULL, virtio_iommu_put_endpoint);
1027 }
1028 
/*
 * Status change handler. It only traces the new @status value; no device
 * state is modified here and @vdev is not used.
 */
static void virtio_iommu_set_status(VirtIODevice *vdev, uint8_t status)
{
    trace_virtio_iommu_device_status(status);
}
1033 
/*
 * Instance initializer referenced by virtio_iommu_info. It is intentionally
 * empty: no per-instance setup is performed here.
 */
static void virtio_iommu_instance_init(Object *obj)
{
}
1037 
/*
 * Initializer for the VMStateDescription of a VirtIOIOMMUInterval:
 * its "low" and "high" bounds, in that order.
 */
#define VMSTATE_INTERVAL                               \
{                                                      \
    .name = "interval",                                \
    .minimum_version_id = 1,                           \
    .version_id = 1,                                   \
    .fields = (VMStateField[]) {                       \
        VMSTATE_UINT64(low, VirtIOIOMMUInterval),      \
        VMSTATE_UINT64(high, VirtIOIOMMUInterval),     \
        VMSTATE_END_OF_LIST()                          \
    },                                                 \
}
1049 
/*
 * Initializer for the VMStateDescription of a VirtIOIOMMUMapping:
 * its physical address followed by its flags.
 */
#define VMSTATE_MAPPING                               \
{                                                     \
    .name = "mapping",                                \
    .minimum_version_id = 1,                          \
    .version_id = 1,                                  \
    .fields = (VMStateField[]) {                      \
        VMSTATE_UINT64(phys_addr, VirtIOIOMMUMapping),\
        VMSTATE_UINT32(flags, VirtIOIOMMUMapping),    \
        VMSTATE_END_OF_LIST()                         \
    },                                                \
}
1061 
1062 static const VMStateDescription vmstate_interval_mapping[2] = {
1063     VMSTATE_MAPPING,   /* value */
1064     VMSTATE_INTERVAL   /* key   */
1065 };
1066 
1067 static int domain_preload(void *opaque)
1068 {
1069     VirtIOIOMMUDomain *domain = opaque;
1070 
1071     domain->mappings = g_tree_new_full((GCompareDataFunc)interval_cmp,
1072                                        NULL, g_free, g_free);
1073     return 0;
1074 }
1075 
1076 static const VMStateDescription vmstate_endpoint = {
1077     .name = "endpoint",
1078     .version_id = 1,
1079     .minimum_version_id = 1,
1080     .fields = (VMStateField[]) {
1081         VMSTATE_UINT32(id, VirtIOIOMMUEndpoint),
1082         VMSTATE_END_OF_LIST()
1083     }
1084 };
1085 
1086 static const VMStateDescription vmstate_domain = {
1087     .name = "domain",
1088     .version_id = 1,
1089     .minimum_version_id = 1,
1090     .pre_load = domain_preload,
1091     .fields = (VMStateField[]) {
1092         VMSTATE_UINT32(id, VirtIOIOMMUDomain),
1093         VMSTATE_GTREE_V(mappings, VirtIOIOMMUDomain, 1,
1094                         vmstate_interval_mapping,
1095                         VirtIOIOMMUInterval, VirtIOIOMMUMapping),
1096         VMSTATE_QLIST_V(endpoint_list, VirtIOIOMMUDomain, 1,
1097                         vmstate_endpoint, VirtIOIOMMUEndpoint, next),
1098         VMSTATE_END_OF_LIST()
1099     }
1100 };
1101 
1102 static gboolean reconstruct_endpoints(gpointer key, gpointer value,
1103                                       gpointer data)
1104 {
1105     VirtIOIOMMU *s = (VirtIOIOMMU *)data;
1106     VirtIOIOMMUDomain *d = (VirtIOIOMMUDomain *)value;
1107     VirtIOIOMMUEndpoint *iter;
1108     IOMMUMemoryRegion *mr;
1109 
1110     QLIST_FOREACH(iter, &d->endpoint_list, next) {
1111         mr = virtio_iommu_mr(s, iter->id);
1112         assert(mr);
1113 
1114         iter->domain = d;
1115         iter->iommu_mr = mr;
1116         g_tree_insert(s->endpoints, GUINT_TO_POINTER(iter->id), iter);
1117     }
1118     return false; /* continue the domain traversal */
1119 }
1120 
1121 static int iommu_post_load(void *opaque, int version_id)
1122 {
1123     VirtIOIOMMU *s = opaque;
1124 
1125     g_tree_foreach(s->domains, reconstruct_endpoints, s);
1126     return 0;
1127 }
1128 
1129 static const VMStateDescription vmstate_virtio_iommu_device = {
1130     .name = "virtio-iommu-device",
1131     .minimum_version_id = 1,
1132     .version_id = 1,
1133     .post_load = iommu_post_load,
1134     .fields = (VMStateField[]) {
1135         VMSTATE_GTREE_DIRECT_KEY_V(domains, VirtIOIOMMU, 1,
1136                                    &vmstate_domain, VirtIOIOMMUDomain),
1137         VMSTATE_END_OF_LIST()
1138     },
1139 };
1140 
1141 static const VMStateDescription vmstate_virtio_iommu = {
1142     .name = "virtio-iommu",
1143     .minimum_version_id = 1,
1144     .priority = MIG_PRI_IOMMU,
1145     .version_id = 1,
1146     .fields = (VMStateField[]) {
1147         VMSTATE_VIRTIO_DEVICE,
1148         VMSTATE_END_OF_LIST()
1149     },
1150 };
1151 
/*
 * qdev properties of the device.
 *
 * "primary-bus" is a link property stored in VirtIOIOMMU::primary_bus; the
 * realize handler reports an error when it has not been set.
 */
static Property virtio_iommu_properties[] = {
    DEFINE_PROP_LINK("primary-bus", VirtIOIOMMU, primary_bus, "PCI", PCIBus *),
    DEFINE_PROP_END_OF_LIST(),
};
1156 
1157 static void virtio_iommu_class_init(ObjectClass *klass, void *data)
1158 {
1159     DeviceClass *dc = DEVICE_CLASS(klass);
1160     VirtioDeviceClass *vdc = VIRTIO_DEVICE_CLASS(klass);
1161 
1162     device_class_set_props(dc, virtio_iommu_properties);
1163     dc->vmsd = &vmstate_virtio_iommu;
1164 
1165     set_bit(DEVICE_CATEGORY_MISC, dc->categories);
1166     vdc->realize = virtio_iommu_device_realize;
1167     vdc->unrealize = virtio_iommu_device_unrealize;
1168     vdc->reset = virtio_iommu_device_reset;
1169     vdc->get_config = virtio_iommu_get_config;
1170     vdc->set_config = virtio_iommu_set_config;
1171     vdc->get_features = virtio_iommu_get_features;
1172     vdc->set_status = virtio_iommu_set_status;
1173     vdc->vmsd = &vmstate_virtio_iommu_device;
1174 }
1175 
1176 static void virtio_iommu_memory_region_class_init(ObjectClass *klass,
1177                                                   void *data)
1178 {
1179     IOMMUMemoryRegionClass *imrc = IOMMU_MEMORY_REGION_CLASS(klass);
1180 
1181     imrc->translate = virtio_iommu_translate;
1182     imrc->replay = virtio_iommu_replay;
1183     imrc->notify_flag_changed = virtio_iommu_notify_flag_changed;
1184     imrc->iommu_set_page_size_mask = virtio_iommu_set_page_size_mask;
1185 }
1186 
1187 static const TypeInfo virtio_iommu_info = {
1188     .name = TYPE_VIRTIO_IOMMU,
1189     .parent = TYPE_VIRTIO_DEVICE,
1190     .instance_size = sizeof(VirtIOIOMMU),
1191     .instance_init = virtio_iommu_instance_init,
1192     .class_init = virtio_iommu_class_init,
1193 };
1194 
1195 static const TypeInfo virtio_iommu_memory_region_info = {
1196     .parent = TYPE_IOMMU_MEMORY_REGION,
1197     .name = TYPE_VIRTIO_IOMMU_MEMORY_REGION,
1198     .class_init = virtio_iommu_memory_region_class_init,
1199 };
1200 
1201 static void virtio_register_types(void)
1202 {
1203     type_register_static(&virtio_iommu_info);
1204     type_register_static(&virtio_iommu_memory_region_info);
1205 }
1206 
1207 type_init(virtio_register_types)
1208