1 /*
2  * VFIO: IOMMU DMA mapping support for Type1 IOMMU
3  *
4  * Copyright (C) 2012 Red Hat, Inc.  All rights reserved.
5  *     Author: Alex Williamson <alex.williamson@redhat.com>
6  *
7  * This program is free software; you can redistribute it and/or modify
8  * it under the terms of the GNU General Public License version 2 as
9  * published by the Free Software Foundation.
10  *
11  * Derived from original vfio:
12  * Copyright 2010 Cisco Systems, Inc.  All rights reserved.
13  * Author: Tom Lyon, pugs@cisco.com
14  *
15  * We arbitrarily define a Type1 IOMMU as one matching the below code.
16  * It could be called the x86 IOMMU as it's designed for AMD-Vi & Intel
17  * VT-d, but that makes it harder to re-use as theoretically anyone
18  * implementing a similar IOMMU could make use of this.  We expect the
19  * IOMMU to support the IOMMU API and have few to no restrictions around
20  * the IOVA range that can be mapped.  The Type1 IOMMU is currently
21  * optimized for relatively static mappings of a userspace process with
22  * userspace pages pinned into memory.  We also assume devices and IOMMU
23  * domains are PCI based as the IOMMU API is still centered around a
24  * device/bus interface rather than a group interface.
25  */
26 
27 #include <linux/compat.h>
28 #include <linux/device.h>
29 #include <linux/fs.h>
30 #include <linux/iommu.h>
31 #include <linux/module.h>
32 #include <linux/mm.h>
33 #include <linux/rbtree.h>
34 #include <linux/sched/signal.h>
35 #include <linux/sched/mm.h>
36 #include <linux/slab.h>
37 #include <linux/uaccess.h>
38 #include <linux/vfio.h>
39 #include <linux/workqueue.h>
40 #include <linux/mdev.h>
41 #include <linux/notifier.h>
42 #include <linux/dma-iommu.h>
43 #include <linux/irqdomain.h>
44 
45 #define DRIVER_VERSION  "0.2"
46 #define DRIVER_AUTHOR   "Alex Williamson <alex.williamson@redhat.com>"
47 #define DRIVER_DESC     "Type1 IOMMU driver for VFIO"
48 
49 static bool allow_unsafe_interrupts;
50 module_param_named(allow_unsafe_interrupts,
51 		   allow_unsafe_interrupts, bool, S_IRUGO | S_IWUSR);
52 MODULE_PARM_DESC(allow_unsafe_interrupts,
53 		 "Enable VFIO IOMMU support on platforms without interrupt remapping support.");
54 
55 static bool disable_hugepages;
56 module_param_named(disable_hugepages,
57 		   disable_hugepages, bool, S_IRUGO | S_IWUSR);
58 MODULE_PARM_DESC(disable_hugepages,
59 		 "Disable VFIO IOMMU support for IOMMU hugepages.");
60 
61 struct vfio_iommu {
62 	struct list_head	domain_list;
63 	struct vfio_domain	*external_domain; /* domain for external user */
64 	struct mutex		lock;
65 	struct rb_root		dma_list;
66 	struct blocking_notifier_head notifier;
67 	bool			v2;
68 	bool			nesting;
69 };
70 
71 struct vfio_domain {
72 	struct iommu_domain	*domain;
73 	struct list_head	next;
74 	struct list_head	group_list;
75 	int			prot;		/* IOMMU_CACHE */
76 	bool			fgsp;		/* Fine-grained super pages */
77 };
78 
79 struct vfio_dma {
80 	struct rb_node		node;
81 	dma_addr_t		iova;		/* Device address */
82 	unsigned long		vaddr;		/* Process virtual addr */
83 	size_t			size;		/* Map size (bytes) */
84 	int			prot;		/* IOMMU_READ/WRITE */
85 	bool			iommu_mapped;
86 	bool			lock_cap;	/* capable(CAP_IPC_LOCK) */
87 	struct task_struct	*task;
88 	struct rb_root		pfn_list;	/* Ex-user pinned pfn list */
89 };
90 
91 struct vfio_group {
92 	struct iommu_group	*iommu_group;
93 	struct list_head	next;
94 };
95 
96 /*
97  * Guest RAM pinning working set or DMA target
98  */
99 struct vfio_pfn {
100 	struct rb_node		node;
101 	dma_addr_t		iova;		/* Device address */
102 	unsigned long		pfn;		/* Host pfn */
103 	atomic_t		ref_count;
104 };
105 
106 struct vfio_regions {
107 	struct list_head list;
108 	dma_addr_t iova;
109 	phys_addr_t phys;
110 	size_t len;
111 };
112 
113 #define IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu)	\
114 					(!list_empty(&iommu->domain_list))
115 
116 static int put_pfn(unsigned long pfn, int prot);
117 
118 /*
119  * This code handles mapping and unmapping of user data buffers
120  * into DMA'ble space using the IOMMU
121  */
122 
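/*
 * Find any vfio_dma tracking structure that overlaps [start, start + size)
 * in the iova-sorted dma_list rb-tree, or NULL if nothing intersects.
 */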
123 static struct vfio_dma *vfio_find_dma(struct vfio_iommu *iommu,
124 				      dma_addr_t start, size_t size)
125 {
126 	struct rb_node *node = iommu->dma_list.rb_node;
127 
128 	while (node) {
129 		struct vfio_dma *dma = rb_entry(node, struct vfio_dma, node);
130 
131 		if (start + size <= dma->iova)
132 			node = node->rb_left;
133 		else if (start >= dma->iova + dma->size)
134 			node = node->rb_right;
135 		else
136 			return dma;
137 	}
138 
139 	return NULL;
140 }
141 
142 static void vfio_link_dma(struct vfio_iommu *iommu, struct vfio_dma *new)
143 {
144 	struct rb_node **link = &iommu->dma_list.rb_node, *parent = NULL;
145 	struct vfio_dma *dma;
146 
147 	while (*link) {
148 		parent = *link;
149 		dma = rb_entry(parent, struct vfio_dma, node);
150 
151 		if (new->iova + new->size <= dma->iova)
152 			link = &(*link)->rb_left;
153 		else
154 			link = &(*link)->rb_right;
155 	}
156 
157 	rb_link_node(&new->node, parent, link);
158 	rb_insert_color(&new->node, &iommu->dma_list);
159 }
160 
161 static void vfio_unlink_dma(struct vfio_iommu *iommu, struct vfio_dma *old)
162 {
163 	rb_erase(&old->node, &iommu->dma_list);
164 }
165 
166 /*
167  * Helper functions for the host iova-pfn list
168  */
169 static struct vfio_pfn *vfio_find_vpfn(struct vfio_dma *dma, dma_addr_t iova)
170 {
171 	struct vfio_pfn *vpfn;
172 	struct rb_node *node = dma->pfn_list.rb_node;
173 
174 	while (node) {
175 		vpfn = rb_entry(node, struct vfio_pfn, node);
176 
177 		if (iova < vpfn->iova)
178 			node = node->rb_left;
179 		else if (iova > vpfn->iova)
180 			node = node->rb_right;
181 		else
182 			return vpfn;
183 	}
184 	return NULL;
185 }
186 
187 static void vfio_link_pfn(struct vfio_dma *dma,
188 			  struct vfio_pfn *new)
189 {
190 	struct rb_node **link, *parent = NULL;
191 	struct vfio_pfn *vpfn;
192 
193 	link = &dma->pfn_list.rb_node;
194 	while (*link) {
195 		parent = *link;
196 		vpfn = rb_entry(parent, struct vfio_pfn, node);
197 
198 		if (new->iova < vpfn->iova)
199 			link = &(*link)->rb_left;
200 		else
201 			link = &(*link)->rb_right;
202 	}
203 
204 	rb_link_node(&new->node, parent, link);
205 	rb_insert_color(&new->node, &dma->pfn_list);
206 }
207 
208 static void vfio_unlink_pfn(struct vfio_dma *dma, struct vfio_pfn *old)
209 {
210 	rb_erase(&old->node, &dma->pfn_list);
211 }
212 
213 static int vfio_add_to_pfn_list(struct vfio_dma *dma, dma_addr_t iova,
214 				unsigned long pfn)
215 {
216 	struct vfio_pfn *vpfn;
217 
218 	vpfn = kzalloc(sizeof(*vpfn), GFP_KERNEL);
219 	if (!vpfn)
220 		return -ENOMEM;
221 
222 	vpfn->iova = iova;
223 	vpfn->pfn = pfn;
224 	atomic_set(&vpfn->ref_count, 1);
225 	vfio_link_pfn(dma, vpfn);
226 	return 0;
227 }
228 
229 static void vfio_remove_from_pfn_list(struct vfio_dma *dma,
230 				      struct vfio_pfn *vpfn)
231 {
232 	vfio_unlink_pfn(dma, vpfn);
233 	kfree(vpfn);
234 }
235 
236 static struct vfio_pfn *vfio_iova_get_vfio_pfn(struct vfio_dma *dma,
237 					       unsigned long iova)
238 {
239 	struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
240 
241 	if (vpfn)
242 		atomic_inc(&vpfn->ref_count);
243 	return vpfn;
244 }
245 
246 static int vfio_iova_put_vfio_pfn(struct vfio_dma *dma, struct vfio_pfn *vpfn)
247 {
248 	int ret = 0;
249 
250 	if (atomic_dec_and_test(&vpfn->ref_count)) {
251 		ret = put_pfn(vpfn->pfn, dma->prot);
252 		vfio_remove_from_pfn_list(dma, vpfn);
253 	}
254 	return ret;
255 }
256 
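/*
 * Charge (npage > 0) or uncharge (npage < 0) pinned pages against the
 * locked_vm of the task that created the mapping.  @async indicates we may
 * not be in that task's context and must take a reference on its mm; a
 * charge over RLIMIT_MEMLOCK fails with -ENOMEM unless the task had
 * CAP_IPC_LOCK when the mapping was created.
 */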
257 static int vfio_lock_acct(struct vfio_dma *dma, long npage, bool async)
258 {
259 	struct mm_struct *mm;
260 	int ret;
261 
262 	if (!npage)
263 		return 0;
264 
265 	mm = async ? get_task_mm(dma->task) : dma->task->mm;
266 	if (!mm)
267 		return -ESRCH; /* process exited */
268 
269 	ret = down_write_killable(&mm->mmap_sem);
270 	if (!ret) {
271 		if (npage > 0) {
272 			if (!dma->lock_cap) {
273 				unsigned long limit;
274 
275 				limit = task_rlimit(dma->task,
276 						RLIMIT_MEMLOCK) >> PAGE_SHIFT;
277 
278 				if (mm->locked_vm + npage > limit)
279 					ret = -ENOMEM;
280 			}
281 		}
282 
283 		if (!ret)
284 			mm->locked_vm += npage;
285 
286 		up_write(&mm->mmap_sem);
287 	}
288 
289 	if (async)
290 		mmput(mm);
291 
292 	return ret;
293 }
294 
295 /*
296  * Some mappings aren't backed by a struct page, for example an mmap'd
297  * MMIO range for our own or another device.  These use a different
298  * pfn conversion and shouldn't be tracked as locked pages.
299  */
300 static bool is_invalid_reserved_pfn(unsigned long pfn)
301 {
302 	if (pfn_valid(pfn)) {
303 		bool reserved;
304 		struct page *tail = pfn_to_page(pfn);
305 		struct page *head = compound_head(tail);
306 		reserved = !!(PageReserved(head));
307 		if (head != tail) {
308 			/*
309 			 * "head" is not a dangling pointer
310 			 * (compound_head takes care of that)
311 			 * but the hugepage may have been split
312 			 * from under us (and we may not hold a
313 			 * reference count on the head page so it can
314 			 * be reused before we run PageReserved), so
315 			 * we have to check PageTail before returning
316 			 * what we just read.
317 			 */
318 			smp_rmb();
319 			if (PageTail(tail))
320 				return reserved;
321 		}
322 		return PageReserved(tail);
323 	}
324 
325 	return true;
326 }
327 
328 static int put_pfn(unsigned long pfn, int prot)
329 {
330 	if (!is_invalid_reserved_pfn(pfn)) {
331 		struct page *page = pfn_to_page(pfn);
332 		if (prot & IOMMU_WRITE)
333 			SetPageDirty(page);
334 		put_page(page);
335 		return 1;
336 	}
337 	return 0;
338 }
339 
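/*
 * Translate a user virtual address in @mm into a host pfn.  Regular memory
 * is pinned via get_user_pages and holds a page reference; VM_PFNMAP vmas
 * (e.g. MMIO mappings) are resolved from the vma itself and take no
 * reference.
 */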
340 static int vaddr_get_pfn(struct mm_struct *mm, unsigned long vaddr,
341 			 int prot, unsigned long *pfn)
342 {
343 	struct page *page[1];
344 	struct vm_area_struct *vma;
345 	struct vm_area_struct *vmas[1];
346 	int ret;
347 
348 	if (mm == current->mm) {
349 		ret = get_user_pages_longterm(vaddr, 1, !!(prot & IOMMU_WRITE),
350 					      page, vmas);
351 	} else {
352 		unsigned int flags = 0;
353 
354 		if (prot & IOMMU_WRITE)
355 			flags |= FOLL_WRITE;
356 
357 		down_read(&mm->mmap_sem);
358 		ret = get_user_pages_remote(NULL, mm, vaddr, 1, flags, page,
359 					    vmas, NULL);
360 		/*
361 		 * The lifetime of a vaddr_get_pfn() page pin is
362 		 * userspace-controlled. In the fs-dax case this could
363 		 * lead to indefinite stalls in filesystem operations.
364 		 * Disallow attempts to pin fs-dax pages via this
365 		 * interface.
366 		 */
367 		if (ret > 0 && vma_is_fsdax(vmas[0])) {
368 			ret = -EOPNOTSUPP;
369 			put_page(page[0]);
370 		}
371 		up_read(&mm->mmap_sem);
372 	}
373 
374 	if (ret == 1) {
375 		*pfn = page_to_pfn(page[0]);
376 		return 0;
377 	}
378 
379 	down_read(&mm->mmap_sem);
380 
381 	vma = find_vma_intersection(mm, vaddr, vaddr + 1);
382 
383 	if (vma && vma->vm_flags & VM_PFNMAP) {
384 		*pfn = ((vaddr - vma->vm_start) >> PAGE_SHIFT) + vma->vm_pgoff;
385 		if (is_invalid_reserved_pfn(*pfn))
386 			ret = 0;
387 	}
388 
389 	up_read(&mm->mmap_sem);
390 	return ret;
391 }
392 
393 /*
394  * Attempt to pin pages.  We really don't want to track all the pfns and
395  * the iommu can only map chunks of consecutive pfns anyway, so get the
396  * first page and all consecutive pages with the same locking.
397  */
398 static long vfio_pin_pages_remote(struct vfio_dma *dma, unsigned long vaddr,
399 				  long npage, unsigned long *pfn_base,
400 				  unsigned long limit)
401 {
402 	unsigned long pfn = 0;
403 	long ret, pinned = 0, lock_acct = 0;
404 	bool rsvd;
405 	dma_addr_t iova = vaddr - dma->vaddr + dma->iova;
406 
407 	/* This code path is only user initiated */
408 	if (!current->mm)
409 		return -ENODEV;
410 
411 	ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, pfn_base);
412 	if (ret)
413 		return ret;
414 
415 	pinned++;
416 	rsvd = is_invalid_reserved_pfn(*pfn_base);
417 
418 	/*
419 	 * Reserved pages aren't counted against the user; externally pinned
420 	 * pages are already counted against the user.
421 	 */
422 	if (!rsvd && !vfio_find_vpfn(dma, iova)) {
423 		if (!dma->lock_cap && current->mm->locked_vm + 1 > limit) {
424 			put_pfn(*pfn_base, dma->prot);
425 			pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n", __func__,
426 					limit << PAGE_SHIFT);
427 			return -ENOMEM;
428 		}
429 		lock_acct++;
430 	}
431 
432 	if (unlikely(disable_hugepages))
433 		goto out;
434 
435 	/* Lock all the consecutive pages from pfn_base */
436 	for (vaddr += PAGE_SIZE, iova += PAGE_SIZE; pinned < npage;
437 	     pinned++, vaddr += PAGE_SIZE, iova += PAGE_SIZE) {
438 		ret = vaddr_get_pfn(current->mm, vaddr, dma->prot, &pfn);
439 		if (ret)
440 			break;
441 
442 		if (pfn != *pfn_base + pinned ||
443 		    rsvd != is_invalid_reserved_pfn(pfn)) {
444 			put_pfn(pfn, dma->prot);
445 			break;
446 		}
447 
448 		if (!rsvd && !vfio_find_vpfn(dma, iova)) {
449 			if (!dma->lock_cap &&
450 			    current->mm->locked_vm + lock_acct + 1 > limit) {
451 				put_pfn(pfn, dma->prot);
452 				pr_warn("%s: RLIMIT_MEMLOCK (%ld) exceeded\n",
453 					__func__, limit << PAGE_SHIFT);
454 				ret = -ENOMEM;
455 				goto unpin_out;
456 			}
457 			lock_acct++;
458 		}
459 	}
460 
461 out:
462 	ret = vfio_lock_acct(dma, lock_acct, false);
463 
464 unpin_out:
465 	if (ret) {
466 		if (!rsvd) {
467 			for (pfn = *pfn_base ; pinned ; pfn++, pinned--)
468 				put_pfn(pfn, dma->prot);
469 		}
470 
471 		return ret;
472 	}
473 
474 	return pinned;
475 }
476 
477 static long vfio_unpin_pages_remote(struct vfio_dma *dma, dma_addr_t iova,
478 				    unsigned long pfn, long npage,
479 				    bool do_accounting)
480 {
481 	long unlocked = 0, locked = 0;
482 	long i;
483 
484 	for (i = 0; i < npage; i++, iova += PAGE_SIZE) {
485 		if (put_pfn(pfn++, dma->prot)) {
486 			unlocked++;
487 			if (vfio_find_vpfn(dma, iova))
488 				locked++;
489 		}
490 	}
491 
492 	if (do_accounting)
493 		vfio_lock_acct(dma, locked - unlocked, true);
494 
495 	return unlocked;
496 }
497 
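/*
 * Pin a single page on behalf of an external (mdev) caller, using the mm of
 * the task that created the mapping rather than current, and optionally
 * charge it against that task's locked memory limit.
 */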
498 static int vfio_pin_page_external(struct vfio_dma *dma, unsigned long vaddr,
499 				  unsigned long *pfn_base, bool do_accounting)
500 {
501 	struct mm_struct *mm;
502 	int ret;
503 
504 	mm = get_task_mm(dma->task);
505 	if (!mm)
506 		return -ENODEV;
507 
508 	ret = vaddr_get_pfn(mm, vaddr, dma->prot, pfn_base);
509 	if (!ret && do_accounting && !is_invalid_reserved_pfn(*pfn_base)) {
510 		ret = vfio_lock_acct(dma, 1, true);
511 		if (ret) {
512 			put_pfn(*pfn_base, dma->prot);
513 			if (ret == -ENOMEM)
514 				pr_warn("%s: Task %s (%d) RLIMIT_MEMLOCK "
515 					"(%ld) exceeded\n", __func__,
516 					dma->task->comm, task_pid_nr(dma->task),
517 					task_rlimit(dma->task, RLIMIT_MEMLOCK));
518 		}
519 	}
520 
521 	mmput(mm);
522 	return ret;
523 }
524 
525 static int vfio_unpin_page_external(struct vfio_dma *dma, dma_addr_t iova,
526 				    bool do_accounting)
527 {
528 	int unlocked;
529 	struct vfio_pfn *vpfn = vfio_find_vpfn(dma, iova);
530 
531 	if (!vpfn)
532 		return 0;
533 
534 	unlocked = vfio_iova_put_vfio_pfn(dma, vpfn);
535 
536 	if (do_accounting)
537 		vfio_lock_acct(dma, -unlocked, true);
538 
539 	return unlocked;
540 }
541 
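/*
 * Backend for the vfio_pin_pages() interface used by mdev vendor drivers:
 * translate an array of user pfns (iova >> PAGE_SHIFT) into host pfns,
 * pinning and accounting each page and recording it on the owning
 * vfio_dma's pfn_list so it stays pinned until explicitly unpinned.
 *
 * Illustrative vendor-driver call (a sketch only; the gfns/hpfns array
 * names are hypothetical):
 *
 *	ret = vfio_pin_pages(mdev_dev(mdev), gfns, count,
 *			     IOMMU_READ | IOMMU_WRITE, hpfns);
 */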
542 static int vfio_iommu_type1_pin_pages(void *iommu_data,
543 				      unsigned long *user_pfn,
544 				      int npage, int prot,
545 				      unsigned long *phys_pfn)
546 {
547 	struct vfio_iommu *iommu = iommu_data;
548 	int i, j, ret;
549 	unsigned long remote_vaddr;
550 	struct vfio_dma *dma;
551 	bool do_accounting;
552 
553 	if (!iommu || !user_pfn || !phys_pfn)
554 		return -EINVAL;
555 
556 	/* Only supported by the v2 interface */
557 	if (!iommu->v2)
558 		return -EACCES;
559 
560 	mutex_lock(&iommu->lock);
561 
562 	/* Fail without an external (mdev) domain or a registered notifier */
563 	if ((!iommu->external_domain) || (!iommu->notifier.head)) {
564 		ret = -EINVAL;
565 		goto pin_done;
566 	}
567 
568 	/*
569 	 * If an iommu capable domain exists in the container, then all pages
570 	 * are already pinned and accounted.  Accounting should only be done
571 	 * here if there is no iommu capable domain in the container.
572 	 */
573 	do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
574 
575 	for (i = 0; i < npage; i++) {
576 		dma_addr_t iova;
577 		struct vfio_pfn *vpfn;
578 
579 		iova = user_pfn[i] << PAGE_SHIFT;
580 		dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
581 		if (!dma) {
582 			ret = -EINVAL;
583 			goto pin_unwind;
584 		}
585 
586 		if ((dma->prot & prot) != prot) {
587 			ret = -EPERM;
588 			goto pin_unwind;
589 		}
590 
591 		vpfn = vfio_iova_get_vfio_pfn(dma, iova);
592 		if (vpfn) {
593 			phys_pfn[i] = vpfn->pfn;
594 			continue;
595 		}
596 
597 		remote_vaddr = dma->vaddr + iova - dma->iova;
598 		ret = vfio_pin_page_external(dma, remote_vaddr, &phys_pfn[i],
599 					     do_accounting);
600 		if (ret)
601 			goto pin_unwind;
602 
603 		ret = vfio_add_to_pfn_list(dma, iova, phys_pfn[i]);
604 		if (ret) {
605 			vfio_unpin_page_external(dma, iova, do_accounting);
606 			goto pin_unwind;
607 		}
608 	}
609 
610 	ret = i;
611 	goto pin_done;
612 
613 pin_unwind:
614 	phys_pfn[i] = 0;
615 	for (j = 0; j < i; j++) {
616 		dma_addr_t iova;
617 
618 		iova = user_pfn[j] << PAGE_SHIFT;
619 		dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
620 		vfio_unpin_page_external(dma, iova, do_accounting);
621 		phys_pfn[j] = 0;
622 	}
623 pin_done:
624 	mutex_unlock(&iommu->lock);
625 	return ret;
626 }
627 
628 static int vfio_iommu_type1_unpin_pages(void *iommu_data,
629 					unsigned long *user_pfn,
630 					int npage)
631 {
632 	struct vfio_iommu *iommu = iommu_data;
633 	bool do_accounting;
634 	int i;
635 
636 	if (!iommu || !user_pfn)
637 		return -EINVAL;
638 
639 	/* Only supported by the v2 interface */
640 	if (!iommu->v2)
641 		return -EACCES;
642 
643 	mutex_lock(&iommu->lock);
644 
645 	if (!iommu->external_domain) {
646 		mutex_unlock(&iommu->lock);
647 		return -EINVAL;
648 	}
649 
650 	do_accounting = !IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu);
651 	for (i = 0; i < npage; i++) {
652 		struct vfio_dma *dma;
653 		dma_addr_t iova;
654 
655 		iova = user_pfn[i] << PAGE_SHIFT;
656 		dma = vfio_find_dma(iommu, iova, PAGE_SIZE);
657 		if (!dma)
658 			goto unpin_exit;
659 		vfio_unpin_page_external(dma, iova, do_accounting);
660 	}
661 
662 unpin_exit:
663 	mutex_unlock(&iommu->lock);
664 	return i > npage ? npage : (i > 0 ? i : -EINVAL);
665 }
666 
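/*
 * Issue the deferred IOTLB sync for @domain, then unpin and free every
 * region queued on @regions.  Returns the number of pages unlocked.
 */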
667 static long vfio_sync_unpin(struct vfio_dma *dma, struct vfio_domain *domain,
668 				struct list_head *regions)
669 {
670 	long unlocked = 0;
671 	struct vfio_regions *entry, *next;
672 
673 	iommu_tlb_sync(domain->domain);
674 
675 	list_for_each_entry_safe(entry, next, regions, list) {
676 		unlocked += vfio_unpin_pages_remote(dma,
677 						    entry->iova,
678 						    entry->phys >> PAGE_SHIFT,
679 						    entry->len >> PAGE_SHIFT,
680 						    false);
681 		list_del(&entry->list);
682 		kfree(entry);
683 	}
684 
685 	cond_resched();
686 
687 	return unlocked;
688 }
689 
690 /*
691  * Generally, VFIO needs to unpin remote pages after each IOTLB flush.
692  * Therefore, when using the IOTLB flush sync interface, VFIO needs to keep
693  * track of these regions (currently using a list).
694  *
695  * This value specifies the maximum number of regions for each IOTLB flush sync.
696  */
697 #define VFIO_IOMMU_TLB_SYNC_MAX		512
698 
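/*
 * Unmap a physically contiguous chunk with iommu_unmap_fast() and queue it
 * on @unmapped_list so the IOTLB sync and page unpinning can be batched.
 * The batch is flushed early if the region allocation or the unmap fails,
 * or once VFIO_IOMMU_TLB_SYNC_MAX regions have accumulated.  Returns the
 * number of bytes unmapped, 0 on failure.
 */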
699 static size_t unmap_unpin_fast(struct vfio_domain *domain,
700 			       struct vfio_dma *dma, dma_addr_t *iova,
701 			       size_t len, phys_addr_t phys, long *unlocked,
702 			       struct list_head *unmapped_list,
703 			       int *unmapped_cnt)
704 {
705 	size_t unmapped = 0;
706 	struct vfio_regions *entry = kzalloc(sizeof(*entry), GFP_KERNEL);
707 
708 	if (entry) {
709 		unmapped = iommu_unmap_fast(domain->domain, *iova, len);
710 
711 		if (!unmapped) {
712 			kfree(entry);
713 		} else {
714 			iommu_tlb_range_add(domain->domain, *iova, unmapped);
715 			entry->iova = *iova;
716 			entry->phys = phys;
717 			entry->len  = unmapped;
718 			list_add_tail(&entry->list, unmapped_list);
719 
720 			*iova += unmapped;
721 			(*unmapped_cnt)++;
722 		}
723 	}
724 
725 	/*
726 	 * Sync if the number of fast-unmap regions hits the limit
727 	 * or in case of errors.
728 	 */
729 	if (*unmapped_cnt >= VFIO_IOMMU_TLB_SYNC_MAX || !unmapped) {
730 		*unlocked += vfio_sync_unpin(dma, domain,
731 					     unmapped_list);
732 		*unmapped_cnt = 0;
733 	}
734 
735 	return unmapped;
736 }
737 
738 static size_t unmap_unpin_slow(struct vfio_domain *domain,
739 			       struct vfio_dma *dma, dma_addr_t *iova,
740 			       size_t len, phys_addr_t phys,
741 			       long *unlocked)
742 {
743 	size_t unmapped = iommu_unmap(domain->domain, *iova, len);
744 
745 	if (unmapped) {
746 		*unlocked += vfio_unpin_pages_remote(dma, *iova,
747 						     phys >> PAGE_SHIFT,
748 						     unmapped >> PAGE_SHIFT,
749 						     false);
750 		*iova += unmapped;
751 		cond_resched();
752 	}
753 	return unmapped;
754 }
755 
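/*
 * Tear down all IOMMU mappings covering @dma in every domain of the
 * container and unpin the backing pages, using the first domain's page
 * tables to recover the physical addresses.  Returns the number of pages
 * unlocked, or 0 after adjusting the accounting itself when @do_accounting.
 */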
756 static long vfio_unmap_unpin(struct vfio_iommu *iommu, struct vfio_dma *dma,
757 			     bool do_accounting)
758 {
759 	dma_addr_t iova = dma->iova, end = dma->iova + dma->size;
760 	struct vfio_domain *domain, *d;
761 	LIST_HEAD(unmapped_region_list);
762 	int unmapped_region_cnt = 0;
763 	long unlocked = 0;
764 
765 	if (!dma->size)
766 		return 0;
767 
768 	if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
769 		return 0;
770 
771 	/*
772 	 * We use the IOMMU to track the physical addresses; otherwise we'd
773 	 * need a much more complicated tracking system.  Unfortunately that
774 	 * means we need to use one of the iommu domains to figure out the
775 	 * pfns to unpin.  The rest need to be unmapped in advance so we have
776 	 * no iommu translations remaining when the pages are unpinned.
777 	 */
778 	domain = d = list_first_entry(&iommu->domain_list,
779 				      struct vfio_domain, next);
780 
781 	list_for_each_entry_continue(d, &iommu->domain_list, next) {
782 		iommu_unmap(d->domain, dma->iova, dma->size);
783 		cond_resched();
784 	}
785 
786 	while (iova < end) {
787 		size_t unmapped, len;
788 		phys_addr_t phys, next;
789 
790 		phys = iommu_iova_to_phys(domain->domain, iova);
791 		if (WARN_ON(!phys)) {
792 			iova += PAGE_SIZE;
793 			continue;
794 		}
795 
796 		/*
797 		 * To optimize for fewer iommu_unmap() calls, each of which
798 		 * may require hardware cache flushing, try to find the
799 		 * largest contiguous physical memory chunk to unmap.
800 		 */
801 		for (len = PAGE_SIZE;
802 		     !domain->fgsp && iova + len < end; len += PAGE_SIZE) {
803 			next = iommu_iova_to_phys(domain->domain, iova + len);
804 			if (next != phys + len)
805 				break;
806 		}
807 
808 		/*
809 		 * First, try the fast unmap/unpin path.  On failure, fall
810 		 * back to the slow unmap/unpin path.
811 		 */
812 		unmapped = unmap_unpin_fast(domain, dma, &iova, len, phys,
813 					    &unlocked, &unmapped_region_list,
814 					    &unmapped_region_cnt);
815 		if (!unmapped) {
816 			unmapped = unmap_unpin_slow(domain, dma, &iova, len,
817 						    phys, &unlocked);
818 			if (WARN_ON(!unmapped))
819 				break;
820 		}
821 	}
822 
823 	dma->iommu_mapped = false;
824 
825 	if (unmapped_region_cnt)
826 		unlocked += vfio_sync_unpin(dma, domain, &unmapped_region_list);
827 
828 	if (do_accounting) {
829 		vfio_lock_acct(dma, -unlocked, true);
830 		return 0;
831 	}
832 	return unlocked;
833 }
834 
835 static void vfio_remove_dma(struct vfio_iommu *iommu, struct vfio_dma *dma)
836 {
837 	vfio_unmap_unpin(iommu, dma, true);
838 	vfio_unlink_dma(iommu, dma);
839 	put_task_struct(dma->task);
840 	kfree(dma);
841 }
842 
843 static unsigned long vfio_pgsize_bitmap(struct vfio_iommu *iommu)
844 {
845 	struct vfio_domain *domain;
846 	unsigned long bitmap = ULONG_MAX;
847 
848 	mutex_lock(&iommu->lock);
849 	list_for_each_entry(domain, &iommu->domain_list, next)
850 		bitmap &= domain->domain->pgsize_bitmap;
851 	mutex_unlock(&iommu->lock);
852 
853 	/*
854 	 * If the IOMMU supports page sizes smaller than PAGE_SIZE, we
855 	 * pretend PAGE_SIZE is supported and hide sub-PAGE_SIZE sizes.
856 	 * That way the user will be able to map/unmap buffers whose size
857 	 * and start address are aligned to PAGE_SIZE.  The pinning code uses
858 	 * that granularity while the iommu driver can use the sub-PAGE_SIZE
859 	 * size to map the buffer.
860 	 */
861 	if (bitmap & ~PAGE_MASK) {
862 		bitmap &= PAGE_MASK;
863 		bitmap |= PAGE_SIZE;
864 	}
865 
866 	return bitmap;
867 }
868 
869 static int vfio_dma_do_unmap(struct vfio_iommu *iommu,
870 			     struct vfio_iommu_type1_dma_unmap *unmap)
871 {
872 	uint64_t mask;
873 	struct vfio_dma *dma, *dma_last = NULL;
874 	size_t unmapped = 0;
875 	int ret = 0, retries = 0;
876 
877 	mask = ((uint64_t)1 << __ffs(vfio_pgsize_bitmap(iommu))) - 1;
878 
879 	if (unmap->iova & mask)
880 		return -EINVAL;
881 	if (!unmap->size || unmap->size & mask)
882 		return -EINVAL;
883 	if (unmap->iova + unmap->size < unmap->iova ||
884 	    unmap->size > SIZE_MAX)
885 		return -EINVAL;
886 
887 	WARN_ON(mask & PAGE_MASK);
888 again:
889 	mutex_lock(&iommu->lock);
890 
891 	/*
892 	 * vfio-iommu-type1 (v1) - User mappings were coalesced together to
893 	 * avoid tracking individual mappings.  This means that the granularity
894 	 * of the original mapping was lost and the user was allowed to attempt
895 	 * to unmap any range.  Depending on the contiguousness of physical
896 	 * memory and page sizes supported by the IOMMU, arbitrary unmaps may
897 	 * or may not have worked.  We only guaranteed unmap granularity
898 	 * matching the original mapping; even though it was untracked here,
899 	 * the original mappings are reflected in IOMMU mappings.  This
900 	 * resulted in a couple unusual behaviors.  First, if a range is not
901 	 * able to be unmapped, ex. a set of 4k pages that was mapped as a
902 	 * 2M hugepage into the IOMMU, the unmap ioctl returns success but with
903 	 * a zero sized unmap.  Also, if an unmap request overlaps the first
904 	 * address of a hugepage, the IOMMU will unmap the entire hugepage.
905 	 * This also returns success and the returned unmap size reflects the
906 	 * actual size unmapped.
907 	 *
908 	 * We attempt to maintain compatibility with this "v1" interface, but
909 	 * we take control out of the hands of the IOMMU.  Therefore, an unmap
910 	 * request offset from the beginning of the original mapping will
911 	 * return success with zero sized unmap.  And an unmap request covering
912 	 * the first iova of mapping will unmap the entire range.
913 	 *
914 	 * The v2 version of this interface intends to be more deterministic.
915 	 * Unmap requests must fully cover previous mappings.  Multiple
916 	 * mappings may still be unmapped by specifying large ranges, but there
917 	 * must not be any previous mappings bisected by the range.  An error
918 	 * will be returned if these conditions are not met.  The v2 interface
919 	 * will only return success and a size of zero if there were no
920 	 * mappings within the range.
921 	 */
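	/*
	 * Illustrative v2 usage from userspace (a sketch, not part of this
	 * driver; the container fd name is hypothetical).  The unmap must
	 * cover whole original mappings, e.g. by reusing the iova/size that
	 * were passed to VFIO_IOMMU_MAP_DMA:
	 *
	 *	struct vfio_iommu_type1_dma_unmap unmap = {
	 *		.argsz = sizeof(unmap),
	 *		.iova  = map.iova,
	 *		.size  = map.size,
	 *	};
	 *	ioctl(container_fd, VFIO_IOMMU_UNMAP_DMA, &unmap);
	 */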
922 	if (iommu->v2) {
923 		dma = vfio_find_dma(iommu, unmap->iova, 1);
924 		if (dma && dma->iova != unmap->iova) {
925 			ret = -EINVAL;
926 			goto unlock;
927 		}
928 		dma = vfio_find_dma(iommu, unmap->iova + unmap->size - 1, 0);
929 		if (dma && dma->iova + dma->size != unmap->iova + unmap->size) {
930 			ret = -EINVAL;
931 			goto unlock;
932 		}
933 	}
934 
935 	while ((dma = vfio_find_dma(iommu, unmap->iova, unmap->size))) {
936 		if (!iommu->v2 && unmap->iova > dma->iova)
937 			break;
938 		/*
939 		 * Only a task with the same address space as the one that
940 		 * mapped this iova range is allowed to unmap it.
941 		 */
942 		if (dma->task->mm != current->mm)
943 			break;
944 
945 		if (!RB_EMPTY_ROOT(&dma->pfn_list)) {
946 			struct vfio_iommu_type1_dma_unmap nb_unmap;
947 
948 			if (dma_last == dma) {
949 				BUG_ON(++retries > 10);
950 			} else {
951 				dma_last = dma;
952 				retries = 0;
953 			}
954 
955 			nb_unmap.iova = dma->iova;
956 			nb_unmap.size = dma->size;
957 
958 			/*
959 			 * Notify anyone (mdev vendor drivers) to invalidate and
960 			 * unmap iovas within the range we're about to unmap.
961 			 * Vendor drivers MUST unpin pages in response to an
962 			 * invalidation.
963 			 */
964 			mutex_unlock(&iommu->lock);
965 			blocking_notifier_call_chain(&iommu->notifier,
966 						    VFIO_IOMMU_NOTIFY_DMA_UNMAP,
967 						    &nb_unmap);
968 			goto again;
969 		}
970 		unmapped += dma->size;
971 		vfio_remove_dma(iommu, dma);
972 	}
973 
974 unlock:
975 	mutex_unlock(&iommu->lock);
976 
977 	/* Report how much was unmapped */
978 	unmap->size = unmapped;
979 
980 	return ret;
981 }
982 
983 /*
984  * Turns out AMD IOMMU has a page table bug where it won't map large pages
985  * to a region that previously mapped smaller pages.  This should be fixed
986  * soon, so this is just a temporary workaround to break mappings down into
987  * PAGE_SIZE.  Better to map smaller pages than nothing.
988  */
989 static int map_try_harder(struct vfio_domain *domain, dma_addr_t iova,
990 			  unsigned long pfn, long npage, int prot)
991 {
992 	long i;
993 	int ret = 0;
994 
995 	for (i = 0; i < npage; i++, pfn++, iova += PAGE_SIZE) {
996 		ret = iommu_map(domain->domain, iova,
997 				(phys_addr_t)pfn << PAGE_SHIFT,
998 				PAGE_SIZE, prot | domain->prot);
999 		if (ret)
1000 			break;
1001 	}
1002 
1003 	for (; i < npage && i > 0; i--, iova -= PAGE_SIZE)
1004 		iommu_unmap(domain->domain, iova, PAGE_SIZE);
1005 
1006 	return ret;
1007 }
1008 
1009 static int vfio_iommu_map(struct vfio_iommu *iommu, dma_addr_t iova,
1010 			  unsigned long pfn, long npage, int prot)
1011 {
1012 	struct vfio_domain *d;
1013 	int ret;
1014 
1015 	list_for_each_entry(d, &iommu->domain_list, next) {
1016 		ret = iommu_map(d->domain, iova, (phys_addr_t)pfn << PAGE_SHIFT,
1017 				npage << PAGE_SHIFT, prot | d->prot);
1018 		if (ret) {
1019 			if (ret != -EBUSY ||
1020 			    map_try_harder(d, iova, pfn, npage, prot))
1021 				goto unwind;
1022 		}
1023 
1024 		cond_resched();
1025 	}
1026 
1027 	return 0;
1028 
1029 unwind:
1030 	list_for_each_entry_continue_reverse(d, &iommu->domain_list, next)
1031 		iommu_unmap(d->domain, iova, npage << PAGE_SHIFT);
1032 
1033 	return ret;
1034 }
1035 
1036 static int vfio_pin_map_dma(struct vfio_iommu *iommu, struct vfio_dma *dma,
1037 			    size_t map_size)
1038 {
1039 	dma_addr_t iova = dma->iova;
1040 	unsigned long vaddr = dma->vaddr;
1041 	size_t size = map_size;
1042 	long npage;
1043 	unsigned long pfn, limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1044 	int ret = 0;
1045 
1046 	while (size) {
1047 		/* Pin a contiguous chunk of memory */
1048 		npage = vfio_pin_pages_remote(dma, vaddr + dma->size,
1049 					      size >> PAGE_SHIFT, &pfn, limit);
1050 		if (npage <= 0) {
1051 			WARN_ON(!npage);
1052 			ret = (int)npage;
1053 			break;
1054 		}
1055 
1056 		/* Map it! */
1057 		ret = vfio_iommu_map(iommu, iova + dma->size, pfn, npage,
1058 				     dma->prot);
1059 		if (ret) {
1060 			vfio_unpin_pages_remote(dma, iova + dma->size, pfn,
1061 						npage, true);
1062 			break;
1063 		}
1064 
1065 		size -= npage << PAGE_SHIFT;
1066 		dma->size += npage << PAGE_SHIFT;
1067 	}
1068 
1069 	dma->iommu_mapped = true;
1070 
1071 	if (ret)
1072 		vfio_remove_dma(iommu, dma);
1073 
1074 	return ret;
1075 }
1076 
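/*
 * Handler for VFIO_IOMMU_MAP_DMA: create a vfio_dma tracking structure for
 * the requested iova/vaddr/size and pin and map it into every IOMMU capable
 * domain in the container.
 *
 * Illustrative userspace usage (a sketch only; the container fd and buffer
 * names are hypothetical):
 *
 *	struct vfio_iommu_type1_dma_map map = {
 *		.argsz = sizeof(map),
 *		.flags = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE,
 *		.vaddr = (__u64)(uintptr_t)buffer,
 *		.iova  = 0x100000,
 *		.size  = buffer_size,
 *	};
 *	ioctl(container_fd, VFIO_IOMMU_MAP_DMA, &map);
 */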
1077 static int vfio_dma_do_map(struct vfio_iommu *iommu,
1078 			   struct vfio_iommu_type1_dma_map *map)
1079 {
1080 	dma_addr_t iova = map->iova;
1081 	unsigned long vaddr = map->vaddr;
1082 	size_t size = map->size;
1083 	int ret = 0, prot = 0;
1084 	uint64_t mask;
1085 	struct vfio_dma *dma;
1086 
1087 	/* Verify that none of our __u64 fields overflow */
1088 	if (map->size != size || map->vaddr != vaddr || map->iova != iova)
1089 		return -EINVAL;
1090 
1091 	mask = ((uint64_t)1 << __ffs(vfio_pgsize_bitmap(iommu))) - 1;
1092 
1093 	WARN_ON(mask & PAGE_MASK);
1094 
1095 	/* READ/WRITE from device perspective */
1096 	if (map->flags & VFIO_DMA_MAP_FLAG_WRITE)
1097 		prot |= IOMMU_WRITE;
1098 	if (map->flags & VFIO_DMA_MAP_FLAG_READ)
1099 		prot |= IOMMU_READ;
1100 
1101 	if (!prot || !size || (size | iova | vaddr) & mask)
1102 		return -EINVAL;
1103 
1104 	/* Don't allow IOVA or virtual address wrap */
1105 	if (iova + size - 1 < iova || vaddr + size - 1 < vaddr)
1106 		return -EINVAL;
1107 
1108 	mutex_lock(&iommu->lock);
1109 
1110 	if (vfio_find_dma(iommu, iova, size)) {
1111 		ret = -EEXIST;
1112 		goto out_unlock;
1113 	}
1114 
1115 	dma = kzalloc(sizeof(*dma), GFP_KERNEL);
1116 	if (!dma) {
1117 		ret = -ENOMEM;
1118 		goto out_unlock;
1119 	}
1120 
1121 	dma->iova = iova;
1122 	dma->vaddr = vaddr;
1123 	dma->prot = prot;
1124 
1125 	/*
1126 	 * We need to be able to both add to a task's locked memory and test
1127 	 * against the locked memory limit and we need to be able to do both
1128 	 * outside of this call path as pinning can be asynchronous via the
1129 	 * external interfaces for mdev devices.  RLIMIT_MEMLOCK requires a
1130 	 * task_struct while accounting VM locked pages requires an mm_struct;
1131 	 * however, holding an indefinite mm reference is not recommended, so we
1132 	 * only hold a reference to a task.  We could hold a reference to
1133 	 * current, however QEMU uses this call path through vCPU threads,
1134 	 * which can be killed resulting in a NULL mm and failure in the unmap
1135 	 * path when called via a different thread.  Avoid this problem by
1136 	 * using the group_leader as threads within the same group require
1137 	 * both CLONE_THREAD and CLONE_VM and will therefore use the same
1138 	 * mm_struct.
1139 	 *
1140 	 * Previously we also used the task for testing CAP_IPC_LOCK at the
1141 	 * time of pinning and accounting, however has_capability() makes use
1142 	 * of real_cred, a copy-on-write field, so we can't guarantee that it
1143 	 * matches group_leader, or that it won't have changed by the
1144 	 * time it's evaluated.  If a process were to call MAP_DMA with
1145 	 * CAP_IPC_LOCK but later drop it, it doesn't make sense that they
1146 	 * possibly see different results for an iommu_mapped vfio_dma vs
1147 	 * externally mapped.  Therefore track CAP_IPC_LOCK in vfio_dma at the
1148 	 * time of calling MAP_DMA.
1149 	 */
1150 	get_task_struct(current->group_leader);
1151 	dma->task = current->group_leader;
1152 	dma->lock_cap = capable(CAP_IPC_LOCK);
1153 
1154 	dma->pfn_list = RB_ROOT;
1155 
1156 	/* Insert zero-sized and grow as we map chunks of it */
1157 	vfio_link_dma(iommu, dma);
1158 
1159 	/* Don't pin and map if container doesn't contain IOMMU capable domain */
1160 	if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1161 		dma->size = size;
1162 	else
1163 		ret = vfio_pin_map_dma(iommu, dma, size);
1164 
1165 out_unlock:
1166 	mutex_unlock(&iommu->lock);
1167 	return ret;
1168 }
1169 
1170 static int vfio_bus_type(struct device *dev, void *data)
1171 {
1172 	struct bus_type **bus = data;
1173 
1174 	if (*bus && *bus != dev->bus)
1175 		return -EINVAL;
1176 
1177 	*bus = dev->bus;
1178 
1179 	return 0;
1180 }
1181 
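/*
 * Replay every existing mapping in the container into a newly attached
 * domain, reading physical addresses back from an existing domain when the
 * pages are already iommu-mapped, or pinning them now otherwise.
 */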
1182 static int vfio_iommu_replay(struct vfio_iommu *iommu,
1183 			     struct vfio_domain *domain)
1184 {
1185 	struct vfio_domain *d;
1186 	struct rb_node *n;
1187 	unsigned long limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT;
1188 	int ret;
1189 
1190 	/* Arbitrarily pick the first domain in the list for lookups */
1191 	d = list_first_entry(&iommu->domain_list, struct vfio_domain, next);
1192 	n = rb_first(&iommu->dma_list);
1193 
1194 	for (; n; n = rb_next(n)) {
1195 		struct vfio_dma *dma;
1196 		dma_addr_t iova;
1197 
1198 		dma = rb_entry(n, struct vfio_dma, node);
1199 		iova = dma->iova;
1200 
1201 		while (iova < dma->iova + dma->size) {
1202 			phys_addr_t phys;
1203 			size_t size;
1204 
1205 			if (dma->iommu_mapped) {
1206 				phys_addr_t p;
1207 				dma_addr_t i;
1208 
1209 				phys = iommu_iova_to_phys(d->domain, iova);
1210 
1211 				if (WARN_ON(!phys)) {
1212 					iova += PAGE_SIZE;
1213 					continue;
1214 				}
1215 
1216 				size = PAGE_SIZE;
1217 				p = phys + size;
1218 				i = iova + size;
1219 				while (i < dma->iova + dma->size &&
1220 				       p == iommu_iova_to_phys(d->domain, i)) {
1221 					size += PAGE_SIZE;
1222 					p += PAGE_SIZE;
1223 					i += PAGE_SIZE;
1224 				}
1225 			} else {
1226 				unsigned long pfn;
1227 				unsigned long vaddr = dma->vaddr +
1228 						     (iova - dma->iova);
1229 				size_t n = dma->iova + dma->size - iova;
1230 				long npage;
1231 
1232 				npage = vfio_pin_pages_remote(dma, vaddr,
1233 							      n >> PAGE_SHIFT,
1234 							      &pfn, limit);
1235 				if (npage <= 0) {
1236 					WARN_ON(!npage);
1237 					ret = (int)npage;
1238 					return ret;
1239 				}
1240 
1241 				phys = pfn << PAGE_SHIFT;
1242 				size = npage << PAGE_SHIFT;
1243 			}
1244 
1245 			ret = iommu_map(domain->domain, iova, phys,
1246 					size, dma->prot | domain->prot);
1247 			if (ret)
1248 				return ret;
1249 
1250 			iova += size;
1251 		}
1252 		dma->iommu_mapped = true;
1253 	}
1254 	return 0;
1255 }
1256 
1257 /*
1258  * We change our unmap behavior slightly depending on whether the IOMMU
1259  * supports fine-grained superpages.  IOMMUs like AMD-Vi will use a superpage
1260  * for practically any contiguous power-of-two mapping we give it.  This means
1261  * we don't need to look for contiguous chunks ourselves to make unmapping
1262  * more efficient.  On IOMMUs with coarse-grained super pages, like Intel VT-d
1263  * with discrete 2M/1G/512G/1T superpages, identifying contiguous chunks
1264  * significantly boosts non-hugetlbfs mappings and doesn't seem to hurt when
1265  * hugetlbfs is in use.
1266  */
1267 static void vfio_test_domain_fgsp(struct vfio_domain *domain)
1268 {
1269 	struct page *pages;
1270 	int ret, order = get_order(PAGE_SIZE * 2);
1271 
1272 	pages = alloc_pages(GFP_KERNEL | __GFP_ZERO, order);
1273 	if (!pages)
1274 		return;
1275 
1276 	ret = iommu_map(domain->domain, 0, page_to_phys(pages), PAGE_SIZE * 2,
1277 			IOMMU_READ | IOMMU_WRITE | domain->prot);
1278 	if (!ret) {
1279 		size_t unmapped = iommu_unmap(domain->domain, 0, PAGE_SIZE);
1280 
1281 		if (unmapped == PAGE_SIZE)
1282 			iommu_unmap(domain->domain, PAGE_SIZE, PAGE_SIZE);
1283 		else
1284 			domain->fgsp = true;
1285 	}
1286 
1287 	__free_pages(pages, order);
1288 }
1289 
1290 static struct vfio_group *find_iommu_group(struct vfio_domain *domain,
1291 					   struct iommu_group *iommu_group)
1292 {
1293 	struct vfio_group *g;
1294 
1295 	list_for_each_entry(g, &domain->group_list, next) {
1296 		if (g->iommu_group == iommu_group)
1297 			return g;
1298 	}
1299 
1300 	return NULL;
1301 }
1302 
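/*
 * Returns true, with *base set, if the group reserves a software-managed
 * MSI region (IOMMU_RESV_SW_MSI) that needs an MSI cookie on the domain.
 * A hardware MSI region (IOMMU_RESV_MSI) takes precedence and needs none.
 */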
1303 static bool vfio_iommu_has_sw_msi(struct iommu_group *group, phys_addr_t *base)
1304 {
1305 	struct list_head group_resv_regions;
1306 	struct iommu_resv_region *region, *next;
1307 	bool ret = false;
1308 
1309 	INIT_LIST_HEAD(&group_resv_regions);
1310 	iommu_get_group_resv_regions(group, &group_resv_regions);
1311 	list_for_each_entry(region, &group_resv_regions, list) {
1312 		/*
1313 		 * The presence of any 'real' MSI regions should take
1314 		 * precedence over the software-managed one if the
1315 		 * IOMMU driver happens to advertise both types.
1316 		 */
1317 		if (region->type == IOMMU_RESV_MSI) {
1318 			ret = false;
1319 			break;
1320 		}
1321 
1322 		if (region->type == IOMMU_RESV_SW_MSI) {
1323 			*base = region->start;
1324 			ret = true;
1325 		}
1326 	}
1327 	list_for_each_entry_safe(region, next, &group_resv_regions, list)
1328 		kfree(region);
1329 	return ret;
1330 }
1331 
1332 static int vfio_iommu_type1_attach_group(void *iommu_data,
1333 					 struct iommu_group *iommu_group)
1334 {
1335 	struct vfio_iommu *iommu = iommu_data;
1336 	struct vfio_group *group;
1337 	struct vfio_domain *domain, *d;
1338 	struct bus_type *bus = NULL, *mdev_bus;
1339 	int ret;
1340 	bool resv_msi, msi_remap;
1341 	phys_addr_t resv_msi_base;
1342 
1343 	mutex_lock(&iommu->lock);
1344 
1345 	list_for_each_entry(d, &iommu->domain_list, next) {
1346 		if (find_iommu_group(d, iommu_group)) {
1347 			mutex_unlock(&iommu->lock);
1348 			return -EINVAL;
1349 		}
1350 	}
1351 
1352 	if (iommu->external_domain) {
1353 		if (find_iommu_group(iommu->external_domain, iommu_group)) {
1354 			mutex_unlock(&iommu->lock);
1355 			return -EINVAL;
1356 		}
1357 	}
1358 
1359 	group = kzalloc(sizeof(*group), GFP_KERNEL);
1360 	domain = kzalloc(sizeof(*domain), GFP_KERNEL);
1361 	if (!group || !domain) {
1362 		ret = -ENOMEM;
1363 		goto out_free;
1364 	}
1365 
1366 	group->iommu_group = iommu_group;
1367 
1368 	/* Determine bus_type in order to allocate a domain */
1369 	ret = iommu_group_for_each_dev(iommu_group, &bus, vfio_bus_type);
1370 	if (ret)
1371 		goto out_free;
1372 
1373 	mdev_bus = symbol_get(mdev_bus_type);
1374 
1375 	if (mdev_bus) {
1376 		if ((bus == mdev_bus) && !iommu_present(bus)) {
1377 			symbol_put(mdev_bus_type);
1378 			if (!iommu->external_domain) {
1379 				INIT_LIST_HEAD(&domain->group_list);
1380 				iommu->external_domain = domain;
1381 			} else
1382 				kfree(domain);
1383 
1384 			list_add(&group->next,
1385 				 &iommu->external_domain->group_list);
1386 			mutex_unlock(&iommu->lock);
1387 			return 0;
1388 		}
1389 		symbol_put(mdev_bus_type);
1390 	}
1391 
1392 	domain->domain = iommu_domain_alloc(bus);
1393 	if (!domain->domain) {
1394 		ret = -EIO;
1395 		goto out_free;
1396 	}
1397 
1398 	if (iommu->nesting) {
1399 		int attr = 1;
1400 
1401 		ret = iommu_domain_set_attr(domain->domain, DOMAIN_ATTR_NESTING,
1402 					    &attr);
1403 		if (ret)
1404 			goto out_domain;
1405 	}
1406 
1407 	ret = iommu_attach_group(domain->domain, iommu_group);
1408 	if (ret)
1409 		goto out_domain;
1410 
1411 	resv_msi = vfio_iommu_has_sw_msi(iommu_group, &resv_msi_base);
1412 
1413 	INIT_LIST_HEAD(&domain->group_list);
1414 	list_add(&group->next, &domain->group_list);
1415 
1416 	msi_remap = irq_domain_check_msi_remap() ||
1417 		    iommu_capable(bus, IOMMU_CAP_INTR_REMAP);
1418 
1419 	if (!allow_unsafe_interrupts && !msi_remap) {
1420 		pr_warn("%s: No interrupt remapping support.  Use the module param \"allow_unsafe_interrupts\" to enable VFIO IOMMU support on this platform\n",
1421 		       __func__);
1422 		ret = -EPERM;
1423 		goto out_detach;
1424 	}
1425 
1426 	if (iommu_capable(bus, IOMMU_CAP_CACHE_COHERENCY))
1427 		domain->prot |= IOMMU_CACHE;
1428 
1429 	/*
1430 	 * Try to match an existing compatible domain.  We don't want to
1431 	 * preclude an IOMMU driver supporting multiple bus_types and being
1432 	 * able to include different bus_types in the same IOMMU domain, so
1433 	 * we test whether the domains use the same iommu_ops rather than
1434 	 * testing if they're on the same bus_type.
1435 	 */
1436 	list_for_each_entry(d, &iommu->domain_list, next) {
1437 		if (d->domain->ops == domain->domain->ops &&
1438 		    d->prot == domain->prot) {
1439 			iommu_detach_group(domain->domain, iommu_group);
1440 			if (!iommu_attach_group(d->domain, iommu_group)) {
1441 				list_add(&group->next, &d->group_list);
1442 				iommu_domain_free(domain->domain);
1443 				kfree(domain);
1444 				mutex_unlock(&iommu->lock);
1445 				return 0;
1446 			}
1447 
1448 			ret = iommu_attach_group(domain->domain, iommu_group);
1449 			if (ret)
1450 				goto out_domain;
1451 		}
1452 	}
1453 
1454 	vfio_test_domain_fgsp(domain);
1455 
1456 	/* replay mappings on new domains */
1457 	ret = vfio_iommu_replay(iommu, domain);
1458 	if (ret)
1459 		goto out_detach;
1460 
1461 	if (resv_msi) {
1462 		ret = iommu_get_msi_cookie(domain->domain, resv_msi_base);
1463 		if (ret)
1464 			goto out_detach;
1465 	}
1466 
1467 	list_add(&domain->next, &iommu->domain_list);
1468 
1469 	mutex_unlock(&iommu->lock);
1470 
1471 	return 0;
1472 
1473 out_detach:
1474 	iommu_detach_group(domain->domain, iommu_group);
1475 out_domain:
1476 	iommu_domain_free(domain->domain);
1477 out_free:
1478 	kfree(domain);
1479 	kfree(group);
1480 	mutex_unlock(&iommu->lock);
1481 	return ret;
1482 }
1483 
1484 static void vfio_iommu_unmap_unpin_all(struct vfio_iommu *iommu)
1485 {
1486 	struct rb_node *node;
1487 
1488 	while ((node = rb_first(&iommu->dma_list)))
1489 		vfio_remove_dma(iommu, rb_entry(node, struct vfio_dma, node));
1490 }
1491 
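/*
 * Called when the last IOMMU-backed domain is removed while an external
 * (mdev) domain remains: unmap and unpin everything, then re-charge only
 * the pages that are still pinned externally on each pfn_list.
 */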
1492 static void vfio_iommu_unmap_unpin_reaccount(struct vfio_iommu *iommu)
1493 {
1494 	struct rb_node *n, *p;
1495 
1496 	n = rb_first(&iommu->dma_list);
1497 	for (; n; n = rb_next(n)) {
1498 		struct vfio_dma *dma;
1499 		long locked = 0, unlocked = 0;
1500 
1501 		dma = rb_entry(n, struct vfio_dma, node);
1502 		unlocked += vfio_unmap_unpin(iommu, dma, false);
1503 		p = rb_first(&dma->pfn_list);
1504 		for (; p; p = rb_next(p)) {
1505 			struct vfio_pfn *vpfn = rb_entry(p, struct vfio_pfn,
1506 							 node);
1507 
1508 			if (!is_invalid_reserved_pfn(vpfn->pfn))
1509 				locked++;
1510 		}
1511 		vfio_lock_acct(dma, locked - unlocked, true);
1512 	}
1513 }
1514 
1515 static void vfio_sanity_check_pfn_list(struct vfio_iommu *iommu)
1516 {
1517 	struct rb_node *n;
1518 
1519 	n = rb_first(&iommu->dma_list);
1520 	for (; n; n = rb_next(n)) {
1521 		struct vfio_dma *dma;
1522 
1523 		dma = rb_entry(n, struct vfio_dma, node);
1524 
1525 		if (WARN_ON(!RB_EMPTY_ROOT(&dma->pfn_list)))
1526 			break;
1527 	}
1528 	/* mdev vendor driver must unregister notifier */
1529 	WARN_ON(iommu->notifier.head);
1530 }
1531 
1532 static void vfio_iommu_type1_detach_group(void *iommu_data,
1533 					  struct iommu_group *iommu_group)
1534 {
1535 	struct vfio_iommu *iommu = iommu_data;
1536 	struct vfio_domain *domain;
1537 	struct vfio_group *group;
1538 
1539 	mutex_lock(&iommu->lock);
1540 
1541 	if (iommu->external_domain) {
1542 		group = find_iommu_group(iommu->external_domain, iommu_group);
1543 		if (group) {
1544 			list_del(&group->next);
1545 			kfree(group);
1546 
1547 			if (list_empty(&iommu->external_domain->group_list)) {
1548 				vfio_sanity_check_pfn_list(iommu);
1549 
1550 				if (!IS_IOMMU_CAP_DOMAIN_IN_CONTAINER(iommu))
1551 					vfio_iommu_unmap_unpin_all(iommu);
1552 
1553 				kfree(iommu->external_domain);
1554 				iommu->external_domain = NULL;
1555 			}
1556 			goto detach_group_done;
1557 		}
1558 	}
1559 
1560 	list_for_each_entry(domain, &iommu->domain_list, next) {
1561 		group = find_iommu_group(domain, iommu_group);
1562 		if (!group)
1563 			continue;
1564 
1565 		iommu_detach_group(domain->domain, iommu_group);
1566 		list_del(&group->next);
1567 		kfree(group);
1568 		/*
1569 		 * Group ownership provides privilege; if the group list is
1570 		 * empty, the domain goes away.  If it's the last domain with
1571 		 * an iommu and no external domain exists, then all the
1572 		 * mappings go away too.  If it's the last domain with an
1573 		 * iommu and an external domain does exist, update the accounting.
1574 		 */
1575 		if (list_empty(&domain->group_list)) {
1576 			if (list_is_singular(&iommu->domain_list)) {
1577 				if (!iommu->external_domain)
1578 					vfio_iommu_unmap_unpin_all(iommu);
1579 				else
1580 					vfio_iommu_unmap_unpin_reaccount(iommu);
1581 			}
1582 			iommu_domain_free(domain->domain);
1583 			list_del(&domain->next);
1584 			kfree(domain);
1585 		}
1586 		break;
1587 	}
1588 
1589 detach_group_done:
1590 	mutex_unlock(&iommu->lock);
1591 }
1592 
1593 static void *vfio_iommu_type1_open(unsigned long arg)
1594 {
1595 	struct vfio_iommu *iommu;
1596 
1597 	iommu = kzalloc(sizeof(*iommu), GFP_KERNEL);
1598 	if (!iommu)
1599 		return ERR_PTR(-ENOMEM);
1600 
1601 	switch (arg) {
1602 	case VFIO_TYPE1_IOMMU:
1603 		break;
1604 	case VFIO_TYPE1_NESTING_IOMMU:
1605 		iommu->nesting = true;
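		/* fall through - VFIO_TYPE1_NESTING_IOMMU implies v2 */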
1606 	case VFIO_TYPE1v2_IOMMU:
1607 		iommu->v2 = true;
1608 		break;
1609 	default:
1610 		kfree(iommu);
1611 		return ERR_PTR(-EINVAL);
1612 	}
1613 
1614 	INIT_LIST_HEAD(&iommu->domain_list);
1615 	iommu->dma_list = RB_ROOT;
1616 	mutex_init(&iommu->lock);
1617 	BLOCKING_INIT_NOTIFIER_HEAD(&iommu->notifier);
1618 
1619 	return iommu;
1620 }
1621 
1622 static void vfio_release_domain(struct vfio_domain *domain, bool external)
1623 {
1624 	struct vfio_group *group, *group_tmp;
1625 
1626 	list_for_each_entry_safe(group, group_tmp,
1627 				 &domain->group_list, next) {
1628 		if (!external)
1629 			iommu_detach_group(domain->domain, group->iommu_group);
1630 		list_del(&group->next);
1631 		kfree(group);
1632 	}
1633 
1634 	if (!external)
1635 		iommu_domain_free(domain->domain);
1636 }
1637 
1638 static void vfio_iommu_type1_release(void *iommu_data)
1639 {
1640 	struct vfio_iommu *iommu = iommu_data;
1641 	struct vfio_domain *domain, *domain_tmp;
1642 
1643 	if (iommu->external_domain) {
1644 		vfio_release_domain(iommu->external_domain, true);
1645 		vfio_sanity_check_pfn_list(iommu);
1646 		kfree(iommu->external_domain);
1647 	}
1648 
1649 	vfio_iommu_unmap_unpin_all(iommu);
1650 
1651 	list_for_each_entry_safe(domain, domain_tmp,
1652 				 &iommu->domain_list, next) {
1653 		vfio_release_domain(domain, false);
1654 		list_del(&domain->next);
1655 		kfree(domain);
1656 	}
1657 	kfree(iommu);
1658 }
1659 
1660 static int vfio_domains_have_iommu_cache(struct vfio_iommu *iommu)
1661 {
1662 	struct vfio_domain *domain;
1663 	int ret = 1;
1664 
1665 	mutex_lock(&iommu->lock);
1666 	list_for_each_entry(domain, &iommu->domain_list, next) {
1667 		if (!(domain->prot & IOMMU_CACHE)) {
1668 			ret = 0;
1669 			break;
1670 		}
1671 	}
1672 	mutex_unlock(&iommu->lock);
1673 
1674 	return ret;
1675 }
1676 
1677 static long vfio_iommu_type1_ioctl(void *iommu_data,
1678 				   unsigned int cmd, unsigned long arg)
1679 {
1680 	struct vfio_iommu *iommu = iommu_data;
1681 	unsigned long minsz;
1682 
1683 	if (cmd == VFIO_CHECK_EXTENSION) {
1684 		switch (arg) {
1685 		case VFIO_TYPE1_IOMMU:
1686 		case VFIO_TYPE1v2_IOMMU:
1687 		case VFIO_TYPE1_NESTING_IOMMU:
1688 			return 1;
1689 		case VFIO_DMA_CC_IOMMU:
1690 			if (!iommu)
1691 				return 0;
1692 			return vfio_domains_have_iommu_cache(iommu);
1693 		default:
1694 			return 0;
1695 		}
1696 	} else if (cmd == VFIO_IOMMU_GET_INFO) {
1697 		struct vfio_iommu_type1_info info;
1698 
1699 		minsz = offsetofend(struct vfio_iommu_type1_info, iova_pgsizes);
1700 
1701 		if (copy_from_user(&info, (void __user *)arg, minsz))
1702 			return -EFAULT;
1703 
1704 		if (info.argsz < minsz)
1705 			return -EINVAL;
1706 
1707 		info.flags = VFIO_IOMMU_INFO_PGSIZES;
1708 
1709 		info.iova_pgsizes = vfio_pgsize_bitmap(iommu);
1710 
1711 		return copy_to_user((void __user *)arg, &info, minsz) ?
1712 			-EFAULT : 0;
1713 
1714 	} else if (cmd == VFIO_IOMMU_MAP_DMA) {
1715 		struct vfio_iommu_type1_dma_map map;
1716 		uint32_t mask = VFIO_DMA_MAP_FLAG_READ |
1717 				VFIO_DMA_MAP_FLAG_WRITE;
1718 
1719 		minsz = offsetofend(struct vfio_iommu_type1_dma_map, size);
1720 
1721 		if (copy_from_user(&map, (void __user *)arg, minsz))
1722 			return -EFAULT;
1723 
1724 		if (map.argsz < minsz || map.flags & ~mask)
1725 			return -EINVAL;
1726 
1727 		return vfio_dma_do_map(iommu, &map);
1728 
1729 	} else if (cmd == VFIO_IOMMU_UNMAP_DMA) {
1730 		struct vfio_iommu_type1_dma_unmap unmap;
1731 		long ret;
1732 
1733 		minsz = offsetofend(struct vfio_iommu_type1_dma_unmap, size);
1734 
1735 		if (copy_from_user(&unmap, (void __user *)arg, minsz))
1736 			return -EFAULT;
1737 
1738 		if (unmap.argsz < minsz || unmap.flags)
1739 			return -EINVAL;
1740 
1741 		ret = vfio_dma_do_unmap(iommu, &unmap);
1742 		if (ret)
1743 			return ret;
1744 
1745 		return copy_to_user((void __user *)arg, &unmap, minsz) ?
1746 			-EFAULT : 0;
1747 	}
1748 
1749 	return -ENOTTY;
1750 }
1751 
1752 static int vfio_iommu_type1_register_notifier(void *iommu_data,
1753 					      unsigned long *events,
1754 					      struct notifier_block *nb)
1755 {
1756 	struct vfio_iommu *iommu = iommu_data;
1757 
1758 	/* clear known events */
1759 	*events &= ~VFIO_IOMMU_NOTIFY_DMA_UNMAP;
1760 
1761 	/* refuse to register if any unknown events remain */
1762 	if (*events)
1763 		return -EINVAL;
1764 
1765 	return blocking_notifier_chain_register(&iommu->notifier, nb);
1766 }
1767 
1768 static int vfio_iommu_type1_unregister_notifier(void *iommu_data,
1769 						struct notifier_block *nb)
1770 {
1771 	struct vfio_iommu *iommu = iommu_data;
1772 
1773 	return blocking_notifier_chain_unregister(&iommu->notifier, nb);
1774 }
1775 
1776 static const struct vfio_iommu_driver_ops vfio_iommu_driver_ops_type1 = {
1777 	.name			= "vfio-iommu-type1",
1778 	.owner			= THIS_MODULE,
1779 	.open			= vfio_iommu_type1_open,
1780 	.release		= vfio_iommu_type1_release,
1781 	.ioctl			= vfio_iommu_type1_ioctl,
1782 	.attach_group		= vfio_iommu_type1_attach_group,
1783 	.detach_group		= vfio_iommu_type1_detach_group,
1784 	.pin_pages		= vfio_iommu_type1_pin_pages,
1785 	.unpin_pages		= vfio_iommu_type1_unpin_pages,
1786 	.register_notifier	= vfio_iommu_type1_register_notifier,
1787 	.unregister_notifier	= vfio_iommu_type1_unregister_notifier,
1788 };
1789 
1790 static int __init vfio_iommu_type1_init(void)
1791 {
1792 	return vfio_register_iommu_driver(&vfio_iommu_driver_ops_type1);
1793 }
1794 
1795 static void __exit vfio_iommu_type1_cleanup(void)
1796 {
1797 	vfio_unregister_iommu_driver(&vfio_iommu_driver_ops_type1);
1798 }
1799 
1800 module_init(vfio_iommu_type1_init);
1801 module_exit(vfio_iommu_type1_cleanup);
1802 
1803 MODULE_VERSION(DRIVER_VERSION);
1804 MODULE_LICENSE("GPL v2");
1805 MODULE_AUTHOR(DRIVER_AUTHOR);
1806 MODULE_DESCRIPTION(DRIVER_DESC);
1807