1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES.
3  *
4  * The iopt_pages is the center of the storage and motion of PFNs. Each
5  * iopt_pages represents a logical linear array of full PFNs. The array is 0
6  * based and has npages in it. Accessors use 'index' to refer to the entry in
7  * this logical array, regardless of its storage location.
8  *
9  * PFNs are stored in a tiered scheme:
10  *  1) iopt_pages::pinned_pfns xarray
11  *  2) An iommu_domain
12  *  3) The origin of the PFNs, i.e. the userspace pointer
13  *
14  * PFN have to be copied between all combinations of tiers, depending on the
15  * configuration.
16  *
17  * When a PFN is taken out of the userspace pointer it is pinned exactly once.
18  * The storage locations of the PFN's index are tracked in the two interval
19  * trees. If no interval includes the index then it is not pinned.
20  *
21  * If access_itree includes the PFN's index then an in-kernel access has
22  * requested the page. The PFN is stored in the xarray so other requestors can
23  * continue to find it.
24  *
25  * If the domains_itree includes the PFN's index then an iommu_domain is storing
26  * the PFN and it can be read back using iommu_iova_to_phys(). To avoid
27  * duplicating storage the xarray is not used if only iommu_domains are using
28  * the PFN's index.
29  *
30  * As a general principle this is designed so that destroy never fails. This
31  * means removing an iommu_domain or releasing an in-kernel access will not fail
32  * due to insufficient memory. In practice this means some cases have to hold
33  * PFNs in the xarray even though they are also being stored in an iommu_domain.
34  *
35  * While the iopt_pages can use an iommu_domain as storage, it does not have an
36  * IOVA itself. Instead the iopt_area represents a range of IOVA and uses the
37  * iopt_pages as the PFN provider. Multiple iopt_areas can share the iopt_pages
38  * and reference their own slice of the PFN array, with sub page granularity.
39  *
40  * In this file the term 'last' indicates an inclusive and closed interval, eg
41  * [0,0] refers to a single PFN. 'end' means an open range, eg [0,0) refers to
42  * no PFNs.
43  *
44  * Be cautious of overflow. An IOVA can go all the way up to U64_MAX, so
45  * last_iova + 1 can overflow. An iopt_pages index will always be much less than
46  * ULONG_MAX so last_index + 1 cannot overflow.
47  */
48 #include <linux/overflow.h>
49 #include <linux/slab.h>
50 #include <linux/iommu.h>
51 #include <linux/sched/mm.h>
52 #include <linux/highmem.h>
53 #include <linux/kthread.h>
54 #include <linux/iommufd.h>
55 
56 #include "io_pagetable.h"
57 #include "double_span.h"
58 
59 #ifndef CONFIG_IOMMUFD_TEST
60 #define TEMP_MEMORY_LIMIT 65536
61 #else
62 #define TEMP_MEMORY_LIMIT iommufd_test_memory_limit
63 #endif
64 #define BATCH_BACKUP_SIZE 32
65 
66 /*
67  * More memory makes pin_user_pages() and the batching more efficient, but as
68  * this is only a performance optimization don't try too hard to get it. A 64k
69  * allocation can hold about 26M of 4k pages and 13G of 2M pages in a
70  * pfn_batch. Various destroy paths cannot fail and provide a small amount of
71  * stack memory as a backup contingency. If backup_len is given this cannot
72  * fail.
73  */
74 static void *temp_kmalloc(size_t *size, void *backup, size_t backup_len)
75 {
76 	void *res;
77 
78 	if (WARN_ON(*size == 0))
79 		return NULL;
80 
81 	if (*size < backup_len)
82 		return backup;
83 	*size = min_t(size_t, *size, TEMP_MEMORY_LIMIT);
84 	res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
85 	if (res)
86 		return res;
87 	*size = PAGE_SIZE;
88 	if (backup_len) {
89 		res = kmalloc(*size, GFP_KERNEL | __GFP_NOWARN | __GFP_NORETRY);
90 		if (res)
91 			return res;
92 		*size = backup_len;
93 		return backup;
94 	}
95 	return kmalloc(*size, GFP_KERNEL);
96 }
97 
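/*
 * Recompute the iterator output from the two underlying span iterators:
 * is_used is -1 once either tree's iteration is done, 0 when the next span is
 * a hole in both trees, 1 when itrees[0] covers it, and 2 when only itrees[1]
 * covers it. The reported span is clipped so coverage is uniform across both
 * trees.
 */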
98 void interval_tree_double_span_iter_update(
99 	struct interval_tree_double_span_iter *iter)
100 {
101 	unsigned long last_hole = ULONG_MAX;
102 	unsigned int i;
103 
104 	for (i = 0; i != ARRAY_SIZE(iter->spans); i++) {
105 		if (interval_tree_span_iter_done(&iter->spans[i])) {
106 			iter->is_used = -1;
107 			return;
108 		}
109 
110 		if (iter->spans[i].is_hole) {
111 			last_hole = min(last_hole, iter->spans[i].last_hole);
112 			continue;
113 		}
114 
115 		iter->is_used = i + 1;
116 		iter->start_used = iter->spans[i].start_used;
117 		iter->last_used = min(iter->spans[i].last_used, last_hole);
118 		return;
119 	}
120 
121 	iter->is_used = 0;
122 	iter->start_hole = iter->spans[0].start_hole;
123 	iter->last_hole =
124 		min(iter->spans[0].last_hole, iter->spans[1].last_hole);
125 }
126 
127 void interval_tree_double_span_iter_first(
128 	struct interval_tree_double_span_iter *iter,
129 	struct rb_root_cached *itree1, struct rb_root_cached *itree2,
130 	unsigned long first_index, unsigned long last_index)
131 {
132 	unsigned int i;
133 
134 	iter->itrees[0] = itree1;
135 	iter->itrees[1] = itree2;
136 	for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
137 		interval_tree_span_iter_first(&iter->spans[i], iter->itrees[i],
138 					      first_index, last_index);
139 	interval_tree_double_span_iter_update(iter);
140 }
141 
142 void interval_tree_double_span_iter_next(
143 	struct interval_tree_double_span_iter *iter)
144 {
145 	unsigned int i;
146 
147 	if (iter->is_used == -1 ||
148 	    iter->last_hole == iter->spans[0].last_index) {
149 		iter->is_used = -1;
150 		return;
151 	}
152 
153 	for (i = 0; i != ARRAY_SIZE(iter->spans); i++)
154 		interval_tree_span_iter_advance(
155 			&iter->spans[i], iter->itrees[i], iter->last_hole + 1);
156 	interval_tree_double_span_iter_update(iter);
157 }
158 
159 static void iopt_pages_add_npinned(struct iopt_pages *pages, size_t npages)
160 {
161 	pages->npinned += npages;
162 }
163 
164 static void iopt_pages_sub_npinned(struct iopt_pages *pages, size_t npages)
165 {
166 	pages->npinned -= npages;
167 }
168 
169 static void iopt_pages_err_unpin(struct iopt_pages *pages,
170 				 unsigned long start_index,
171 				 unsigned long last_index,
172 				 struct page **page_list)
173 {
174 	unsigned long npages = last_index - start_index + 1;
175 
176 	unpin_user_pages(page_list, npages);
177 	iopt_pages_sub_npinned(pages, npages);
178 }
179 
180 /*
181  * index is the number of PAGE_SIZE units from the start of the area's
182  * iopt_pages. If the iova is sub page-size then the area has an iova that
183  * covers a portion of the first and last pages in the range.
184  */
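/*
 * For example, with iopt_area_iova() == 0x5000, page_offset == 0x800 and
 * iopt_area_index() == 3, index 3 maps to iova 0x5000 and index 4 maps to
 * iova 0x5800, the start of the next full page.
 */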
185 static unsigned long iopt_area_index_to_iova(struct iopt_area *area,
186 					     unsigned long index)
187 {
188 	index -= iopt_area_index(area);
189 	if (index == 0)
190 		return iopt_area_iova(area);
191 	return iopt_area_iova(area) - area->page_offset + index * PAGE_SIZE;
192 }
193 
194 static unsigned long iopt_area_index_to_iova_last(struct iopt_area *area,
195 						  unsigned long index)
196 {
197 	if (index == iopt_area_last_index(area))
198 		return iopt_area_last_iova(area);
199 	return iopt_area_iova(area) - area->page_offset +
200 	       (index - iopt_area_index(area) + 1) * PAGE_SIZE - 1;
201 }
202 
203 static void iommu_unmap_nofail(struct iommu_domain *domain, unsigned long iova,
204 			       size_t size)
205 {
206 	size_t ret;
207 
208 	ret = iommu_unmap(domain, iova, size);
209 	/*
210 	 * It is a logic error in this code or a driver bug if the IOMMU unmaps
211 	 * something other than exactly as requested. This implies that the
212 	 * iommu driver may not fail unmap for reasons beyond bad arguments.
213 	 * Particularly, the iommu driver may not do a memory allocation on the
214 	 * unmap path.
215 	 */
216 	WARN_ON(ret != size);
217 }
218 
219 static void iopt_area_unmap_domain_range(struct iopt_area *area,
220 					 struct iommu_domain *domain,
221 					 unsigned long start_index,
222 					 unsigned long last_index)
223 {
224 	unsigned long start_iova = iopt_area_index_to_iova(area, start_index);
225 
226 	iommu_unmap_nofail(domain, start_iova,
227 			   iopt_area_index_to_iova_last(area, last_index) -
228 				   start_iova + 1);
229 }
230 
231 static struct iopt_area *iopt_pages_find_domain_area(struct iopt_pages *pages,
232 						     unsigned long index)
233 {
234 	struct interval_tree_node *node;
235 
236 	node = interval_tree_iter_first(&pages->domains_itree, index, index);
237 	if (!node)
238 		return NULL;
239 	return container_of(node, struct iopt_area, pages_node);
240 }
241 
242 /*
243  * A simple data structure to hold a vector of PFNs, optimized for contiguous
244  * PFNs. This is used as a temporary holding memory for shuttling pfns from one
245  * place to another. Generally everything is made more efficient if operations
246  * work on the largest possible grouping of pfns. eg fewer lock/unlock cycles,
247  * better cache locality, etc
248  */
249 struct pfn_batch {
250 	unsigned long *pfns;
251 	u32 *npfns;
252 	unsigned int array_size;
253 	unsigned int end;
254 	unsigned int total_pfns;
255 };
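/*
 * For example, adding PFNs 100, 101, 102 and then 200 via batch_add_pfn()
 * results in pfns = {100, 200}, npfns = {3, 1}, end = 2, total_pfns = 4. Each
 * slot describes a contiguous run of npfns[i] PFNs starting at pfns[i].
 */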
256 
257 static void batch_clear(struct pfn_batch *batch)
258 {
259 	batch->total_pfns = 0;
260 	batch->end = 0;
261 	batch->pfns[0] = 0;
262 	batch->npfns[0] = 0;
263 }
264 
265 /*
266  * Carry means we carry a portion of the final hugepage over to the front of the
267  * batch
268  */
269 static void batch_clear_carry(struct pfn_batch *batch, unsigned int keep_pfns)
270 {
271 	if (!keep_pfns)
272 		return batch_clear(batch);
273 
274 	batch->total_pfns = keep_pfns;
275 	batch->npfns[0] = keep_pfns;
276 	batch->pfns[0] = batch->pfns[batch->end - 1] +
277 			 (batch->npfns[batch->end - 1] - keep_pfns);
278 	batch->end = 0;
279 }
280 
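/*
 * Skip over the leading skip_pfns of a run carried to the front of the batch
 * by batch_clear_carry()
 */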
281 static void batch_skip_carry(struct pfn_batch *batch, unsigned int skip_pfns)
282 {
283 	if (!batch->total_pfns)
284 		return;
285 	skip_pfns = min(batch->total_pfns, skip_pfns);
286 	batch->pfns[0] += skip_pfns;
287 	batch->npfns[0] -= skip_pfns;
288 	batch->total_pfns -= skip_pfns;
289 }
290 
291 static int __batch_init(struct pfn_batch *batch, size_t max_pages, void *backup,
292 			size_t backup_len)
293 {
294 	const size_t elmsz = sizeof(*batch->pfns) + sizeof(*batch->npfns);
295 	size_t size = max_pages * elmsz;
296 
297 	batch->pfns = temp_kmalloc(&size, backup, backup_len);
298 	if (!batch->pfns)
299 		return -ENOMEM;
300 	batch->array_size = size / elmsz;
301 	batch->npfns = (u32 *)(batch->pfns + batch->array_size);
302 	batch_clear(batch);
303 	return 0;
304 }
305 
306 static int batch_init(struct pfn_batch *batch, size_t max_pages)
307 {
308 	return __batch_init(batch, max_pages, NULL, 0);
309 }
310 
311 static void batch_init_backup(struct pfn_batch *batch, size_t max_pages,
312 			      void *backup, size_t backup_len)
313 {
314 	__batch_init(batch, max_pages, backup, backup_len);
315 }
316 
317 static void batch_destroy(struct pfn_batch *batch, void *backup)
318 {
319 	if (batch->pfns != backup)
320 		kfree(batch->pfns);
321 }
322 
323 /* true if the pfn could be added, false otherwise */
324 static bool batch_add_pfn(struct pfn_batch *batch, unsigned long pfn)
325 {
326 	const unsigned int MAX_NPFNS = type_max(typeof(*batch->npfns));
327 
328 	if (batch->end &&
329 	    pfn == batch->pfns[batch->end - 1] + batch->npfns[batch->end - 1] &&
330 	    batch->npfns[batch->end - 1] != MAX_NPFNS) {
331 		batch->npfns[batch->end - 1]++;
332 		batch->total_pfns++;
333 		return true;
334 	}
335 	if (batch->end == batch->array_size)
336 		return false;
337 	batch->total_pfns++;
338 	batch->pfns[batch->end] = pfn;
339 	batch->npfns[batch->end] = 1;
340 	batch->end++;
341 	return true;
342 }
343 
344 /*
345  * Fill the batch with pfns from the domain. When the batch is full, or it
346  * reaches last_index, the function will return. The caller should use
347  * batch->total_pfns to determine the starting point for the next iteration.
348  */
349 static void batch_from_domain(struct pfn_batch *batch,
350 			      struct iommu_domain *domain,
351 			      struct iopt_area *area, unsigned long start_index,
352 			      unsigned long last_index)
353 {
354 	unsigned int page_offset = 0;
355 	unsigned long iova;
356 	phys_addr_t phys;
357 
358 	iova = iopt_area_index_to_iova(area, start_index);
359 	if (start_index == iopt_area_index(area))
360 		page_offset = area->page_offset;
361 	while (start_index <= last_index) {
362 		/*
363 		 * This is pretty slow, it would be nice to get the page size
364 		 * back from the driver, or have the driver directly fill the
365 		 * batch.
366 		 */
367 		phys = iommu_iova_to_phys(domain, iova) - page_offset;
368 		if (!batch_add_pfn(batch, PHYS_PFN(phys)))
369 			return;
370 		iova += PAGE_SIZE - page_offset;
371 		page_offset = 0;
372 		start_index++;
373 	}
374 }
375 
376 static struct page **raw_pages_from_domain(struct iommu_domain *domain,
377 					   struct iopt_area *area,
378 					   unsigned long start_index,
379 					   unsigned long last_index,
380 					   struct page **out_pages)
381 {
382 	unsigned int page_offset = 0;
383 	unsigned long iova;
384 	phys_addr_t phys;
385 
386 	iova = iopt_area_index_to_iova(area, start_index);
387 	if (start_index == iopt_area_index(area))
388 		page_offset = area->page_offset;
389 	while (start_index <= last_index) {
390 		phys = iommu_iova_to_phys(domain, iova) - page_offset;
391 		*(out_pages++) = pfn_to_page(PHYS_PFN(phys));
392 		iova += PAGE_SIZE - page_offset;
393 		page_offset = 0;
394 		start_index++;
395 	}
396 	return out_pages;
397 }
398 
399 /* Continues reading a domain until we reach a discontiguity in the pfns. */
400 static void batch_from_domain_continue(struct pfn_batch *batch,
401 				       struct iommu_domain *domain,
402 				       struct iopt_area *area,
403 				       unsigned long start_index,
404 				       unsigned long last_index)
405 {
406 	unsigned int array_size = batch->array_size;
407 
408 	batch->array_size = batch->end;
409 	batch_from_domain(batch, domain, area, start_index, last_index);
410 	batch->array_size = array_size;
411 }
412 
413 /*
414  * This is part of the VFIO compatibility support for VFIO_TYPE1_IOMMU. That
415  * mode permits splitting a mapped area up, and then one of the splits is
416  * unmapped. Doing this normally would cause us to violate our invariant of
417  * pairing map/unmap. Thus, to support old VFIO compatibility, disable support
418  * for batching consecutive PFNs. All PFNs mapped into the iommu are done in
419  * PAGE_SIZE units, not larger or smaller.
420  */
421 static int batch_iommu_map_small(struct iommu_domain *domain,
422 				 unsigned long iova, phys_addr_t paddr,
423 				 size_t size, int prot)
424 {
425 	unsigned long start_iova = iova;
426 	int rc;
427 
428 	while (size) {
429 		rc = iommu_map(domain, iova, paddr, PAGE_SIZE, prot);
430 		if (rc)
431 			goto err_unmap;
432 		iova += PAGE_SIZE;
433 		paddr += PAGE_SIZE;
434 		size -= PAGE_SIZE;
435 	}
436 	return 0;
437 
438 err_unmap:
439 	if (start_iova != iova)
440 		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
441 	return rc;
442 }
443 
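/*
 * Map everything in the batch into the domain, starting at the IOVA
 * corresponding to start_index. The first page may be mapped with a partial
 * page offset. On failure everything mapped by this call is unmapped again.
 */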
444 static int batch_to_domain(struct pfn_batch *batch, struct iommu_domain *domain,
445 			   struct iopt_area *area, unsigned long start_index)
446 {
447 	bool disable_large_pages = area->iopt->disable_large_pages;
448 	unsigned long last_iova = iopt_area_last_iova(area);
449 	unsigned int page_offset = 0;
450 	unsigned long start_iova;
451 	unsigned long next_iova;
452 	unsigned int cur = 0;
453 	unsigned long iova;
454 	int rc;
455 
456 	/* The first index might be a partial page */
457 	if (start_index == iopt_area_index(area))
458 		page_offset = area->page_offset;
459 	next_iova = iova = start_iova =
460 		iopt_area_index_to_iova(area, start_index);
461 	while (cur < batch->end) {
462 		next_iova = min(last_iova + 1,
463 				next_iova + batch->npfns[cur] * PAGE_SIZE -
464 					page_offset);
465 		if (disable_large_pages)
466 			rc = batch_iommu_map_small(
467 				domain, iova,
468 				PFN_PHYS(batch->pfns[cur]) + page_offset,
469 				next_iova - iova, area->iommu_prot);
470 		else
471 			rc = iommu_map(domain, iova,
472 				       PFN_PHYS(batch->pfns[cur]) + page_offset,
473 				       next_iova - iova, area->iommu_prot);
474 		if (rc)
475 			goto err_unmap;
476 		iova = next_iova;
477 		page_offset = 0;
478 		cur++;
479 	}
480 	return 0;
481 err_unmap:
482 	if (start_iova != iova)
483 		iommu_unmap_nofail(domain, start_iova, iova - start_iova);
484 	return rc;
485 }
486 
487 static void batch_from_xarray(struct pfn_batch *batch, struct xarray *xa,
488 			      unsigned long start_index,
489 			      unsigned long last_index)
490 {
491 	XA_STATE(xas, xa, start_index);
492 	void *entry;
493 
494 	rcu_read_lock();
495 	while (true) {
496 		entry = xas_next(&xas);
497 		if (xas_retry(&xas, entry))
498 			continue;
499 		WARN_ON(!xa_is_value(entry));
500 		if (!batch_add_pfn(batch, xa_to_value(entry)) ||
501 		    start_index == last_index)
502 			break;
503 		start_index++;
504 	}
505 	rcu_read_unlock();
506 }
507 
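/*
 * Like batch_from_xarray(), but also erase each entry as it is copied into
 * the batch. Entries that do not fit in the batch are left in the xarray.
 */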
508 static void batch_from_xarray_clear(struct pfn_batch *batch, struct xarray *xa,
509 				    unsigned long start_index,
510 				    unsigned long last_index)
511 {
512 	XA_STATE(xas, xa, start_index);
513 	void *entry;
514 
515 	xas_lock(&xas);
516 	while (true) {
517 		entry = xas_next(&xas);
518 		if (xas_retry(&xas, entry))
519 			continue;
520 		WARN_ON(!xa_is_value(entry));
521 		if (!batch_add_pfn(batch, xa_to_value(entry)))
522 			break;
523 		xas_store(&xas, NULL);
524 		if (start_index == last_index)
525 			break;
526 		start_index++;
527 	}
528 	xas_unlock(&xas);
529 }
530 
531 static void clear_xarray(struct xarray *xa, unsigned long start_index,
532 			 unsigned long last_index)
533 {
534 	XA_STATE(xas, xa, start_index);
535 	void *entry;
536 
537 	xas_lock(&xas);
538 	xas_for_each(&xas, entry, last_index)
539 		xas_store(&xas, NULL);
540 	xas_unlock(&xas);
541 }
542 
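/*
 * Store the PFN of each page as an xa_value covering [start_index,
 * last_index]. On failure any entries created by this call are erased again
 * so the xarray is left unchanged.
 */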
543 static int pages_to_xarray(struct xarray *xa, unsigned long start_index,
544 			   unsigned long last_index, struct page **pages)
545 {
546 	struct page **end_pages = pages + (last_index - start_index) + 1;
547 	XA_STATE(xas, xa, start_index);
548 
549 	do {
550 		void *old;
551 
552 		xas_lock(&xas);
553 		while (pages != end_pages) {
554 			old = xas_store(&xas, xa_mk_value(page_to_pfn(*pages)));
555 			if (xas_error(&xas))
556 				break;
557 			WARN_ON(old);
558 			pages++;
559 			xas_next(&xas);
560 		}
561 		xas_unlock(&xas);
562 	} while (xas_nomem(&xas, GFP_KERNEL));
563 
564 	if (xas_error(&xas)) {
565 		if (xas.xa_index != start_index)
566 			clear_xarray(xa, start_index, xas.xa_index - 1);
567 		return xas_error(&xas);
568 	}
569 	return 0;
570 }
571 
572 static void batch_from_pages(struct pfn_batch *batch, struct page **pages,
573 			     size_t npages)
574 {
575 	struct page **end = pages + npages;
576 
577 	for (; pages != end; pages++)
578 		if (!batch_add_pfn(batch, page_to_pfn(*pages)))
579 			break;
580 }
581 
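/*
 * Unpin npages of the batch, starting first_page_off pages into it, walking
 * the contiguous runs and decrementing pages->npinned as we go.
 */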
582 static void batch_unpin(struct pfn_batch *batch, struct iopt_pages *pages,
583 			unsigned int first_page_off, size_t npages)
584 {
585 	unsigned int cur = 0;
586 
587 	while (first_page_off) {
588 		if (batch->npfns[cur] > first_page_off)
589 			break;
590 		first_page_off -= batch->npfns[cur];
591 		cur++;
592 	}
593 
594 	while (npages) {
595 		size_t to_unpin = min_t(size_t, npages,
596 					batch->npfns[cur] - first_page_off);
597 
598 		unpin_user_page_range_dirty_lock(
599 			pfn_to_page(batch->pfns[cur] + first_page_off),
600 			to_unpin, pages->writable);
601 		iopt_pages_sub_npinned(pages, to_unpin);
602 		cur++;
603 		first_page_off = 0;
604 		npages -= to_unpin;
605 	}
606 }
607 
608 static void copy_data_page(struct page *page, void *data, unsigned long offset,
609 			   size_t length, unsigned int flags)
610 {
611 	void *mem;
612 
613 	mem = kmap_local_page(page);
614 	if (flags & IOMMUFD_ACCESS_RW_WRITE) {
615 		memcpy(mem + offset, data, length);
616 		set_page_dirty_lock(page);
617 	} else {
618 		memcpy(data, mem + offset, length);
619 	}
620 	kunmap_local(mem);
621 }
622 
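/*
 * Copy to/from the pages held in the batch, beginning at byte 'offset' into
 * the first page. Returns the number of bytes copied, which is limited by the
 * batch contents and by length.
 */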
623 static unsigned long batch_rw(struct pfn_batch *batch, void *data,
624 			      unsigned long offset, unsigned long length,
625 			      unsigned int flags)
626 {
627 	unsigned long copied = 0;
628 	unsigned int npage = 0;
629 	unsigned int cur = 0;
630 
631 	while (cur < batch->end) {
632 		unsigned long bytes = min(length, PAGE_SIZE - offset);
633 
634 		copy_data_page(pfn_to_page(batch->pfns[cur] + npage), data,
635 			       offset, bytes, flags);
636 		offset = 0;
637 		length -= bytes;
638 		data += bytes;
639 		copied += bytes;
640 		npage++;
641 		if (npage == batch->npfns[cur]) {
642 			npage = 0;
643 			cur++;
644 		}
645 		if (!length)
646 			break;
647 	}
648 	return copied;
649 }
650 
651 /* pfn_reader_user is just the pin_user_pages() path */
652 struct pfn_reader_user {
653 	struct page **upages;
654 	size_t upages_len;
655 	unsigned long upages_start;
656 	unsigned long upages_end;
657 	unsigned int gup_flags;
658 	/*
659 	 * 1 means mmget() and mmap_read_lock(), 0 means only mmget(), -1 is
660 	 * neither
661 	 */
662 	int locked;
663 };
664 
665 static void pfn_reader_user_init(struct pfn_reader_user *user,
666 				 struct iopt_pages *pages)
667 {
668 	user->upages = NULL;
669 	user->upages_start = 0;
670 	user->upages_end = 0;
671 	user->locked = -1;
672 
673 	if (pages->writable) {
674 		user->gup_flags = FOLL_LONGTERM | FOLL_WRITE;
675 	} else {
676 		/* Still need to break COWs on read */
677 		user->gup_flags = FOLL_LONGTERM | FOLL_FORCE | FOLL_WRITE;
678 	}
679 }
680 
681 static void pfn_reader_user_destroy(struct pfn_reader_user *user,
682 				    struct iopt_pages *pages)
683 {
684 	if (user->locked != -1) {
685 		if (user->locked)
686 			mmap_read_unlock(pages->source_mm);
687 		if (pages->source_mm != current->mm)
688 			mmput(pages->source_mm);
689 		user->locked = 0;
690 	}
691 
692 	kfree(user->upages);
693 	user->upages = NULL;
694 }
695 
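/*
 * Pin a run of pages starting at start_index. The range actually pinned is
 * recorded in upages_start/upages_end and added to pages->npinned.
 */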
696 static int pfn_reader_user_pin(struct pfn_reader_user *user,
697 			       struct iopt_pages *pages,
698 			       unsigned long start_index,
699 			       unsigned long last_index)
700 {
701 	bool remote_mm = pages->source_mm != current->mm;
702 	unsigned long npages;
703 	uintptr_t uptr;
704 	long rc;
705 
706 	if (!user->upages) {
707 		/* All undone in pfn_reader_destroy() */
708 		user->upages_len =
709 			(last_index - start_index + 1) * sizeof(*user->upages);
710 		user->upages = temp_kmalloc(&user->upages_len, NULL, 0);
711 		if (!user->upages)
712 			return -ENOMEM;
713 	}
714 
715 	if (user->locked == -1) {
716 		/*
717 		 * The majority of usages will run the map task within the mm
718 		 * providing the pages, so we can optimize into
719 		 * get_user_pages_fast()
720 		 */
721 		if (remote_mm) {
722 			if (!mmget_not_zero(pages->source_mm))
723 				return -EFAULT;
724 		}
725 		user->locked = 0;
726 	}
727 
728 	npages = min_t(unsigned long, last_index - start_index + 1,
729 		       user->upages_len / sizeof(*user->upages));
730 
731 	uptr = (uintptr_t)(pages->uptr + start_index * PAGE_SIZE);
732 	if (!remote_mm)
733 		rc = pin_user_pages_fast(uptr, npages, user->gup_flags,
734 					 user->upages);
735 	else {
736 		if (!user->locked) {
737 			mmap_read_lock(pages->source_mm);
738 			user->locked = 1;
739 		}
740 		/*
741 		 * FIXME: last NULL can be &pfns->locked once the GUP patch
742 		 * is merged.
743 		 */
744 		rc = pin_user_pages_remote(pages->source_mm, uptr, npages,
745 					   user->gup_flags, user->upages, NULL,
746 					   NULL);
747 	}
748 	if (rc <= 0) {
749 		if (WARN_ON(!rc))
750 			return -EFAULT;
751 		return rc;
752 	}
753 	iopt_pages_add_npinned(pages, rc);
754 	user->upages_start = start_index;
755 	user->upages_end = start_index + rc;
756 	return 0;
757 }
758 
759 /* This is the "modern" and faster accounting method used by io_uring */
760 static int incr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
761 {
762 	unsigned long lock_limit;
763 	unsigned long cur_pages;
764 	unsigned long new_pages;
765 
766 	lock_limit = task_rlimit(pages->source_task, RLIMIT_MEMLOCK) >>
767 		     PAGE_SHIFT;
768 	npages = pages->npinned - pages->last_npinned;
769 	do {
770 		cur_pages = atomic_long_read(&pages->source_user->locked_vm);
771 		new_pages = cur_pages + npages;
772 		if (new_pages > lock_limit)
773 			return -ENOMEM;
774 	} while (atomic_long_cmpxchg(&pages->source_user->locked_vm, cur_pages,
775 				     new_pages) != cur_pages);
776 	return 0;
777 }
778 
779 static void decr_user_locked_vm(struct iopt_pages *pages, unsigned long npages)
780 {
781 	if (WARN_ON(atomic_long_read(&pages->source_user->locked_vm) < npages))
782 		return;
783 	atomic_long_sub(npages, &pages->source_user->locked_vm);
784 }
785 
786 /* This is the accounting method used for compatibility with VFIO */
787 static int update_mm_locked_vm(struct iopt_pages *pages, unsigned long npages,
788 			       bool inc, struct pfn_reader_user *user)
789 {
790 	bool do_put = false;
791 	int rc;
792 
793 	if (user && user->locked) {
794 		mmap_read_unlock(pages->source_mm);
795 		user->locked = 0;
796 		/* If we had the lock then we also have a get */
797 	} else if ((!user || !user->upages) &&
798 		   pages->source_mm != current->mm) {
799 		if (!mmget_not_zero(pages->source_mm))
800 			return -EINVAL;
801 		do_put = true;
802 	}
803 
804 	mmap_write_lock(pages->source_mm);
805 	rc = __account_locked_vm(pages->source_mm, npages, inc,
806 				 pages->source_task, false);
807 	mmap_write_unlock(pages->source_mm);
808 
809 	if (do_put)
810 		mmput(pages->source_mm);
811 	return rc;
812 }
813 
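/*
 * Charge or uncharge npages according to pages->account_mode and keep
 * source_mm->pinned_vm in sync. last_npinned is set to npinned so the same
 * change is not applied twice.
 */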
814 static int do_update_pinned(struct iopt_pages *pages, unsigned long npages,
815 			    bool inc, struct pfn_reader_user *user)
816 {
817 	int rc = 0;
818 
819 	switch (pages->account_mode) {
820 	case IOPT_PAGES_ACCOUNT_NONE:
821 		break;
822 	case IOPT_PAGES_ACCOUNT_USER:
823 		if (inc)
824 			rc = incr_user_locked_vm(pages, npages);
825 		else
826 			decr_user_locked_vm(pages, npages);
827 		break;
828 	case IOPT_PAGES_ACCOUNT_MM:
829 		rc = update_mm_locked_vm(pages, npages, inc, user);
830 		break;
831 	}
832 	if (rc)
833 		return rc;
834 
835 	pages->last_npinned = pages->npinned;
836 	if (inc)
837 		atomic64_add(npages, &pages->source_mm->pinned_vm);
838 	else
839 		atomic64_sub(npages, &pages->source_mm->pinned_vm);
840 	return 0;
841 }
842 
843 static void update_unpinned(struct iopt_pages *pages)
844 {
845 	if (WARN_ON(pages->npinned > pages->last_npinned))
846 		return;
847 	if (pages->npinned == pages->last_npinned)
848 		return;
849 	do_update_pinned(pages, pages->last_npinned - pages->npinned, false,
850 			 NULL);
851 }
852 
853 /*
854  * Changes in the number of pinned pages are accounted after the pages have
855  * been read and processed. If the rlimit is exceeded then the error unwind
856  * unpins everything that was just pinned, because it is expensive to calculate
857  * how many pages we have already pinned within a range to generate an accurate
858  * prediction in advance of doing the work to actually pin them.
859  */
860 static int pfn_reader_user_update_pinned(struct pfn_reader_user *user,
861 					 struct iopt_pages *pages)
862 {
863 	unsigned long npages;
864 	bool inc;
865 
866 	lockdep_assert_held(&pages->mutex);
867 
868 	if (pages->npinned == pages->last_npinned)
869 		return 0;
870 
871 	if (pages->npinned < pages->last_npinned) {
872 		npages = pages->last_npinned - pages->npinned;
873 		inc = false;
874 	} else {
875 		npages = pages->npinned - pages->last_npinned;
876 		inc = true;
877 	}
878 	return do_update_pinned(pages, npages, inc, user);
879 }
880 
881 /*
882  * PFNs are stored in three places, in order of preference:
883  * - The iopt_pages xarray. This is only populated if there is an
884  *   iopt_pages_access
885  * - The iommu_domain under an area
886  * - The original PFN source, ie pages->source_mm
887  *
888  * This iterator reads the pfns optimizing to load according to the
889  * above order.
890  */
891 struct pfn_reader {
892 	struct iopt_pages *pages;
893 	struct interval_tree_double_span_iter span;
894 	struct pfn_batch batch;
895 	unsigned long batch_start_index;
896 	unsigned long batch_end_index;
897 	unsigned long last_index;
898 
899 	struct pfn_reader_user user;
900 };
901 
902 static int pfn_reader_update_pinned(struct pfn_reader *pfns)
903 {
904 	return pfn_reader_user_update_pinned(&pfns->user, pfns->pages);
905 }
906 
907 /*
908  * The batch can contain a mixture of pages that are still in use and pages that
909  * need to be unpinned. Unpin only pages that are not held anywhere else.
910  */
911 static void pfn_reader_unpin(struct pfn_reader *pfns)
912 {
913 	unsigned long last = pfns->batch_end_index - 1;
914 	unsigned long start = pfns->batch_start_index;
915 	struct interval_tree_double_span_iter span;
916 	struct iopt_pages *pages = pfns->pages;
917 
918 	lockdep_assert_held(&pages->mutex);
919 
920 	interval_tree_for_each_double_span(&span, &pages->access_itree,
921 					   &pages->domains_itree, start, last) {
922 		if (span.is_used)
923 			continue;
924 
925 		batch_unpin(&pfns->batch, pages, span.start_hole - start,
926 			    span.last_hole - span.start_hole + 1);
927 	}
928 }
929 
930 /* Process a single span to load it from the proper storage */
931 static int pfn_reader_fill_span(struct pfn_reader *pfns)
932 {
933 	struct interval_tree_double_span_iter *span = &pfns->span;
934 	unsigned long start_index = pfns->batch_end_index;
935 	struct iopt_area *area;
936 	int rc;
937 
938 	if (span->is_used == 1) {
939 		batch_from_xarray(&pfns->batch, &pfns->pages->pinned_pfns,
940 				  start_index, span->last_used);
941 		return 0;
942 	}
943 
944 	if (span->is_used == 2) {
945 		/*
946 		 * Pull as many pages from the first domain we find in the
947 		 * target span. If it is too small then we will be called again
948 		 * and we'll find another area.
949 		 */
950 		area = iopt_pages_find_domain_area(pfns->pages, start_index);
951 		if (WARN_ON(!area))
952 			return -EINVAL;
953 
954 		/* The storage_domain cannot change without the pages mutex */
955 		batch_from_domain(
956 			&pfns->batch, area->storage_domain, area, start_index,
957 			min(iopt_area_last_index(area), span->last_used));
958 		return 0;
959 	}
960 
961 	if (start_index >= pfns->user.upages_end) {
962 		rc = pfn_reader_user_pin(&pfns->user, pfns->pages, start_index,
963 					 span->last_hole);
964 		if (rc)
965 			return rc;
966 	}
967 
968 	batch_from_pages(&pfns->batch,
969 			 pfns->user.upages +
970 				 (start_index - pfns->user.upages_start),
971 			 pfns->user.upages_end - start_index);
972 	return 0;
973 }
974 
975 static bool pfn_reader_done(struct pfn_reader *pfns)
976 {
977 	return pfns->batch_start_index == pfns->last_index + 1;
978 }
979 
980 static int pfn_reader_next(struct pfn_reader *pfns)
981 {
982 	int rc;
983 
984 	batch_clear(&pfns->batch);
985 	pfns->batch_start_index = pfns->batch_end_index;
986 
987 	while (pfns->batch_end_index != pfns->last_index + 1) {
988 		unsigned int npfns = pfns->batch.total_pfns;
989 
990 		rc = pfn_reader_fill_span(pfns);
991 		if (rc)
992 			return rc;
993 
994 		if (WARN_ON(!pfns->batch.total_pfns))
995 			return -EINVAL;
996 
997 		pfns->batch_end_index =
998 			pfns->batch_start_index + pfns->batch.total_pfns;
999 		if (pfns->batch_end_index == pfns->span.last_used + 1)
1000 			interval_tree_double_span_iter_next(&pfns->span);
1001 
1002 		/* Batch is full */
1003 		if (npfns == pfns->batch.total_pfns)
1004 			return 0;
1005 	}
1006 	return 0;
1007 }
1008 
1009 static int pfn_reader_init(struct pfn_reader *pfns, struct iopt_pages *pages,
1010 			   unsigned long start_index, unsigned long last_index)
1011 {
1012 	int rc;
1013 
1014 	lockdep_assert_held(&pages->mutex);
1015 
1016 	pfns->pages = pages;
1017 	pfns->batch_start_index = start_index;
1018 	pfns->batch_end_index = start_index;
1019 	pfns->last_index = last_index;
1020 	pfn_reader_user_init(&pfns->user, pages);
1021 	rc = batch_init(&pfns->batch, last_index - start_index + 1);
1022 	if (rc)
1023 		return rc;
1024 	interval_tree_double_span_iter_first(&pfns->span, &pages->access_itree,
1025 					     &pages->domains_itree, start_index,
1026 					     last_index);
1027 	return 0;
1028 }
1029 
1030 /*
1031  * There are many assertions regarding the state of pages->npinned vs
1032  * pages->last_npinned, for instance something like unmapping a domain must only
1033  * decrement the npinned, and pfn_reader_destroy() must be called only after all
1034  * the pins are updated. This is fine for success flows, but error flows
1035  * sometimes need to release the pins held inside the pfn_reader before going on
1036  * to complete unmapping and releasing pins held in domains.
1037  */
1038 static void pfn_reader_release_pins(struct pfn_reader *pfns)
1039 {
1040 	struct iopt_pages *pages = pfns->pages;
1041 
1042 	if (pfns->user.upages_end > pfns->batch_end_index) {
1043 		size_t npages = pfns->user.upages_end - pfns->batch_end_index;
1044 
1045 		/* Any pages not transferred to the batch are just unpinned */
1046 		unpin_user_pages(pfns->user.upages + (pfns->batch_end_index -
1047 						      pfns->user.upages_start),
1048 				 npages);
1049 		iopt_pages_sub_npinned(pages, npages);
1050 		pfns->user.upages_end = pfns->batch_end_index;
1051 	}
1052 	if (pfns->batch_start_index != pfns->batch_end_index) {
1053 		pfn_reader_unpin(pfns);
1054 		pfns->batch_start_index = pfns->batch_end_index;
1055 	}
1056 }
1057 
1058 static void pfn_reader_destroy(struct pfn_reader *pfns)
1059 {
1060 	struct iopt_pages *pages = pfns->pages;
1061 
1062 	pfn_reader_release_pins(pfns);
1063 	pfn_reader_user_destroy(&pfns->user, pfns->pages);
1064 	batch_destroy(&pfns->batch, NULL);
1065 	WARN_ON(pages->last_npinned != pages->npinned);
1066 }
1067 
1068 static int pfn_reader_first(struct pfn_reader *pfns, struct iopt_pages *pages,
1069 			    unsigned long start_index, unsigned long last_index)
1070 {
1071 	int rc;
1072 
1073 	rc = pfn_reader_init(pfns, pages, start_index, last_index);
1074 	if (rc)
1075 		return rc;
1076 	rc = pfn_reader_next(pfns);
1077 	if (rc) {
1078 		pfn_reader_destroy(pfns);
1079 		return rc;
1080 	}
1081 	return 0;
1082 }
1083 
1084 struct iopt_pages *iopt_alloc_pages(void __user *uptr, unsigned long length,
1085 				    bool writable)
1086 {
1087 	struct iopt_pages *pages;
1088 
1089 	/*
1090 	 * The iommu API uses size_t as the length; this check also protects the
1091 	 * DIV_ROUND_UP below from overflowing
1092 	 */
1093 	if (length > SIZE_MAX - PAGE_SIZE || length == 0)
1094 		return ERR_PTR(-EINVAL);
1095 
1096 	pages = kzalloc(sizeof(*pages), GFP_KERNEL_ACCOUNT);
1097 	if (!pages)
1098 		return ERR_PTR(-ENOMEM);
1099 
1100 	kref_init(&pages->kref);
1101 	xa_init_flags(&pages->pinned_pfns, XA_FLAGS_ACCOUNT);
1102 	mutex_init(&pages->mutex);
1103 	pages->source_mm = current->mm;
1104 	mmgrab(pages->source_mm);
1105 	pages->uptr = (void __user *)ALIGN_DOWN((uintptr_t)uptr, PAGE_SIZE);
1106 	pages->npages = DIV_ROUND_UP(length + (uptr - pages->uptr), PAGE_SIZE);
1107 	pages->access_itree = RB_ROOT_CACHED;
1108 	pages->domains_itree = RB_ROOT_CACHED;
1109 	pages->writable = writable;
1110 	if (capable(CAP_IPC_LOCK))
1111 		pages->account_mode = IOPT_PAGES_ACCOUNT_NONE;
1112 	else
1113 		pages->account_mode = IOPT_PAGES_ACCOUNT_USER;
1114 	pages->source_task = current->group_leader;
1115 	get_task_struct(current->group_leader);
1116 	pages->source_user = get_uid(current_user());
1117 	return pages;
1118 }
1119 
1120 void iopt_release_pages(struct kref *kref)
1121 {
1122 	struct iopt_pages *pages = container_of(kref, struct iopt_pages, kref);
1123 
1124 	WARN_ON(!RB_EMPTY_ROOT(&pages->access_itree.rb_root));
1125 	WARN_ON(!RB_EMPTY_ROOT(&pages->domains_itree.rb_root));
1126 	WARN_ON(pages->npinned);
1127 	WARN_ON(!xa_empty(&pages->pinned_pfns));
1128 	mmdrop(pages->source_mm);
1129 	mutex_destroy(&pages->mutex);
1130 	put_task_struct(pages->source_task);
1131 	free_uid(pages->source_user);
1132 	kfree(pages);
1133 }
1134 
1135 static void
1136 iopt_area_unpin_domain(struct pfn_batch *batch, struct iopt_area *area,
1137 		       struct iopt_pages *pages, struct iommu_domain *domain,
1138 		       unsigned long start_index, unsigned long last_index,
1139 		       unsigned long *unmapped_end_index,
1140 		       unsigned long real_last_index)
1141 {
1142 	while (start_index <= last_index) {
1143 		unsigned long batch_last_index;
1144 
1145 		if (*unmapped_end_index <= last_index) {
1146 			unsigned long start =
1147 				max(start_index, *unmapped_end_index);
1148 
1149 			batch_from_domain(batch, domain, area, start,
1150 					  last_index);
1151 			batch_last_index = start + batch->total_pfns - 1;
1152 		} else {
1153 			batch_last_index = last_index;
1154 		}
1155 
1156 		/*
1157 		 * unmaps must always 'cut' at a place where the pfns are not
1158 		 * contiguous to pair with the maps that always install
1159 		 * contiguous pages. Thus, if we have to stop unpinning in the
1160 		 * middle of the domains we need to keep reading pfns until we
1161 		 * find a cut point to do the unmap. The pfns we read are
1162 		 * carried over and either skipped or integrated into the next
1163 		 * batch.
1164 		 */
1165 		if (batch_last_index == last_index &&
1166 		    last_index != real_last_index)
1167 			batch_from_domain_continue(batch, domain, area,
1168 						   last_index + 1,
1169 						   real_last_index);
1170 
1171 		if (*unmapped_end_index <= batch_last_index) {
1172 			iopt_area_unmap_domain_range(
1173 				area, domain, *unmapped_end_index,
1174 				start_index + batch->total_pfns - 1);
1175 			*unmapped_end_index = start_index + batch->total_pfns;
1176 		}
1177 
1178 		/* unpin must follow unmap */
1179 		batch_unpin(batch, pages, 0,
1180 			    batch_last_index - start_index + 1);
1181 		start_index = batch_last_index + 1;
1182 
1183 		batch_clear_carry(batch,
1184 				  *unmapped_end_index - batch_last_index - 1);
1185 	}
1186 }
1187 
1188 static void __iopt_area_unfill_domain(struct iopt_area *area,
1189 				      struct iopt_pages *pages,
1190 				      struct iommu_domain *domain,
1191 				      unsigned long last_index)
1192 {
1193 	struct interval_tree_double_span_iter span;
1194 	unsigned long start_index = iopt_area_index(area);
1195 	unsigned long unmapped_end_index = start_index;
1196 	u64 backup[BATCH_BACKUP_SIZE];
1197 	struct pfn_batch batch;
1198 
1199 	lockdep_assert_held(&pages->mutex);
1200 
1201 	/*
1202 	 * For security we must not unpin something that is still DMA mapped,
1203 	 * so this must unmap any IOVA before we go ahead and unpin the pages.
1204 	 * This creates a complexity where we need to skip over unpinning pages
1205 	 * held in the xarray, but continue to unmap from the domain.
1206 	 *
1207 	 * The domain unmap cannot stop in the middle of a contiguous range of
1208 	 * PFNs. To solve this problem the unpinning step will read ahead to the
1209 	 * end of any contiguous span, unmap that whole span, and then only
1210 	 * unpin the leading part that does not have any accesses. The residual
1211 	 * PFNs that were unmapped but not unpinned are called a "carry" in the
1212 	 * batch as they are moved to the front of the PFN list and continue on
1213 	 * to the next iteration(s).
1214 	 */
1215 	batch_init_backup(&batch, last_index + 1, backup, sizeof(backup));
1216 	interval_tree_for_each_double_span(&span, &pages->domains_itree,
1217 					   &pages->access_itree, start_index,
1218 					   last_index) {
1219 		if (span.is_used) {
1220 			batch_skip_carry(&batch,
1221 					 span.last_used - span.start_used + 1);
1222 			continue;
1223 		}
1224 		iopt_area_unpin_domain(&batch, area, pages, domain,
1225 				       span.start_hole, span.last_hole,
1226 				       &unmapped_end_index, last_index);
1227 	}
1228 	/*
1229 	 * If the range ends in an access then we do the residual unmap without
1230 	 * any unpins.
1231 	 */
1232 	if (unmapped_end_index != last_index + 1)
1233 		iopt_area_unmap_domain_range(area, domain, unmapped_end_index,
1234 					     last_index);
1235 	WARN_ON(batch.total_pfns);
1236 	batch_destroy(&batch, backup);
1237 	update_unpinned(pages);
1238 }
1239 
1240 static void iopt_area_unfill_partial_domain(struct iopt_area *area,
1241 					    struct iopt_pages *pages,
1242 					    struct iommu_domain *domain,
1243 					    unsigned long end_index)
1244 {
1245 	if (end_index != iopt_area_index(area))
1246 		__iopt_area_unfill_domain(area, pages, domain, end_index - 1);
1247 }
1248 
1249 /**
1250  * iopt_area_unmap_domain() - Unmap without unpinning PFNs in a domain
1251  * @area: The IOVA range to unmap
1252  * @domain: The domain to unmap
1253  *
1254  * The caller must know that unpinning is not required, usually because there
1255  * are other domains in the iopt.
1256  */
1257 void iopt_area_unmap_domain(struct iopt_area *area, struct iommu_domain *domain)
1258 {
1259 	iommu_unmap_nofail(domain, iopt_area_iova(area),
1260 			   iopt_area_length(area));
1261 }
1262 
1263 /**
1264  * iopt_area_unfill_domain() - Unmap and unpin PFNs in a domain
1265  * @area: IOVA area to use
1266  * @pages: page supplier for the area (area->pages is NULL)
1267  * @domain: Domain to unmap from
1268  *
1269  * The domain should be removed from the domains_itree before calling. The
1270  * domain will always be unmapped, but the PFNs may not be unpinned if there are
1271  * still accesses.
1272  */
1273 void iopt_area_unfill_domain(struct iopt_area *area, struct iopt_pages *pages,
1274 			     struct iommu_domain *domain)
1275 {
1276 	__iopt_area_unfill_domain(area, pages, domain,
1277 				  iopt_area_last_index(area));
1278 }
1279 
1280 /**
1281  * iopt_area_fill_domain() - Map PFNs from the area into a domain
1282  * @area: IOVA area to use
1283  * @domain: Domain to load PFNs into
1284  *
1285  * Read the pfns from the area's underlying iopt_pages and map them into the
1286  * given domain. Called when attaching a new domain to an io_pagetable.
1287  */
1288 int iopt_area_fill_domain(struct iopt_area *area, struct iommu_domain *domain)
1289 {
1290 	unsigned long done_end_index;
1291 	struct pfn_reader pfns;
1292 	int rc;
1293 
1294 	lockdep_assert_held(&area->pages->mutex);
1295 
1296 	rc = pfn_reader_first(&pfns, area->pages, iopt_area_index(area),
1297 			      iopt_area_last_index(area));
1298 	if (rc)
1299 		return rc;
1300 
1301 	while (!pfn_reader_done(&pfns)) {
1302 		done_end_index = pfns.batch_start_index;
1303 		rc = batch_to_domain(&pfns.batch, domain, area,
1304 				     pfns.batch_start_index);
1305 		if (rc)
1306 			goto out_unmap;
1307 		done_end_index = pfns.batch_end_index;
1308 
1309 		rc = pfn_reader_next(&pfns);
1310 		if (rc)
1311 			goto out_unmap;
1312 	}
1313 
1314 	rc = pfn_reader_update_pinned(&pfns);
1315 	if (rc)
1316 		goto out_unmap;
1317 	goto out_destroy;
1318 
1319 out_unmap:
1320 	pfn_reader_release_pins(&pfns);
1321 	iopt_area_unfill_partial_domain(area, area->pages, domain,
1322 					done_end_index);
1323 out_destroy:
1324 	pfn_reader_destroy(&pfns);
1325 	return rc;
1326 }
1327 
1328 /**
1329  * iopt_area_fill_domains() - Install PFNs into the area's domains
1330  * @area: The area to act on
1331  * @pages: The pages associated with the area (area->pages is NULL)
1332  *
1333  * Called during area creation. The area is freshly created and not inserted in
1334  * the domains_itree yet. PFNs are read and loaded into every domain held in the
1335  * area's io_pagetable and the area is installed in the domains_itree.
1336  *
1337  * On failure all domains are left unchanged.
1338  */
1339 int iopt_area_fill_domains(struct iopt_area *area, struct iopt_pages *pages)
1340 {
1341 	unsigned long done_first_end_index;
1342 	unsigned long done_all_end_index;
1343 	struct iommu_domain *domain;
1344 	unsigned long unmap_index;
1345 	struct pfn_reader pfns;
1346 	unsigned long index;
1347 	int rc;
1348 
1349 	lockdep_assert_held(&area->iopt->domains_rwsem);
1350 
1351 	if (xa_empty(&area->iopt->domains))
1352 		return 0;
1353 
1354 	mutex_lock(&pages->mutex);
1355 	rc = pfn_reader_first(&pfns, pages, iopt_area_index(area),
1356 			      iopt_area_last_index(area));
1357 	if (rc)
1358 		goto out_unlock;
1359 
1360 	while (!pfn_reader_done(&pfns)) {
1361 		done_first_end_index = pfns.batch_end_index;
1362 		done_all_end_index = pfns.batch_start_index;
1363 		xa_for_each(&area->iopt->domains, index, domain) {
1364 			rc = batch_to_domain(&pfns.batch, domain, area,
1365 					     pfns.batch_start_index);
1366 			if (rc)
1367 				goto out_unmap;
1368 		}
1369 		done_all_end_index = done_first_end_index;
1370 
1371 		rc = pfn_reader_next(&pfns);
1372 		if (rc)
1373 			goto out_unmap;
1374 	}
1375 	rc = pfn_reader_update_pinned(&pfns);
1376 	if (rc)
1377 		goto out_unmap;
1378 
1379 	area->storage_domain = xa_load(&area->iopt->domains, 0);
1380 	interval_tree_insert(&area->pages_node, &pages->domains_itree);
1381 	goto out_destroy;
1382 
1383 out_unmap:
1384 	pfn_reader_release_pins(&pfns);
1385 	xa_for_each(&area->iopt->domains, unmap_index, domain) {
1386 		unsigned long end_index;
1387 
1388 		if (unmap_index < index)
1389 			end_index = done_first_end_index;
1390 		else
1391 			end_index = done_all_end_index;
1392 
1393 		/*
1394 		 * The area is not yet part of the domains_itree so we have to
1395 		 * manage the unpinning specially. The last domain does the
1396 		 * unpin, every other domain is just unmapped.
1397 		 */
1398 		if (unmap_index != area->iopt->next_domain_id - 1) {
1399 			if (end_index != iopt_area_index(area))
1400 				iopt_area_unmap_domain_range(
1401 					area, domain, iopt_area_index(area),
1402 					end_index - 1);
1403 		} else {
1404 			iopt_area_unfill_partial_domain(area, pages, domain,
1405 							end_index);
1406 		}
1407 	}
1408 out_destroy:
1409 	pfn_reader_destroy(&pfns);
1410 out_unlock:
1411 	mutex_unlock(&pages->mutex);
1412 	return rc;
1413 }
1414 
1415 /**
1416  * iopt_area_unfill_domains() - unmap PFNs from the area's domains
1417  * @area: The area to act on
1418  * @pages: The pages associated with the area (area->pages is NULL)
1419  *
1420  * Called during area destruction. This unmaps the iova's covered by all the
1421  * area's domains and releases the PFNs.
1422  */
1423 void iopt_area_unfill_domains(struct iopt_area *area, struct iopt_pages *pages)
1424 {
1425 	struct io_pagetable *iopt = area->iopt;
1426 	struct iommu_domain *domain;
1427 	unsigned long index;
1428 
1429 	lockdep_assert_held(&iopt->domains_rwsem);
1430 
1431 	mutex_lock(&pages->mutex);
1432 	if (!area->storage_domain)
1433 		goto out_unlock;
1434 
1435 	xa_for_each(&iopt->domains, index, domain)
1436 		if (domain != area->storage_domain)
1437 			iopt_area_unmap_domain_range(
1438 				area, domain, iopt_area_index(area),
1439 				iopt_area_last_index(area));
1440 
1441 	interval_tree_remove(&area->pages_node, &pages->domains_itree);
1442 	iopt_area_unfill_domain(area, pages, area->storage_domain);
1443 	area->storage_domain = NULL;
1444 out_unlock:
1445 	mutex_unlock(&pages->mutex);
1446 }
1447 
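/*
 * Unpin all PFNs stored in the xarray between start_index and end_index,
 * clearing the xarray entries as they are consumed.
 */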
1448 static void iopt_pages_unpin_xarray(struct pfn_batch *batch,
1449 				    struct iopt_pages *pages,
1450 				    unsigned long start_index,
1451 				    unsigned long end_index)
1452 {
1453 	while (start_index <= end_index) {
1454 		batch_from_xarray_clear(batch, &pages->pinned_pfns, start_index,
1455 					end_index);
1456 		batch_unpin(batch, pages, 0, batch->total_pfns);
1457 		start_index += batch->total_pfns;
1458 		batch_clear(batch);
1459 	}
1460 }
1461 
1462 /**
1463  * iopt_pages_unfill_xarray() - Update the xarray after removing an access
1464  * @pages: The pages to act on
1465  * @start_index: Starting PFN index
1466  * @last_index: Last PFN index
1467  *
1468  * Called when an iopt_pages_access is removed; unpins and clears PFNs that
1469  * are no longer needed. The access should already be out of the access_itree.
1470  */
1471 void iopt_pages_unfill_xarray(struct iopt_pages *pages,
1472 			      unsigned long start_index,
1473 			      unsigned long last_index)
1474 {
1475 	struct interval_tree_double_span_iter span;
1476 	u64 backup[BATCH_BACKUP_SIZE];
1477 	struct pfn_batch batch;
1478 	bool batch_inited = false;
1479 
1480 	lockdep_assert_held(&pages->mutex);
1481 
1482 	interval_tree_for_each_double_span(&span, &pages->access_itree,
1483 					   &pages->domains_itree, start_index,
1484 					   last_index) {
1485 		if (!span.is_used) {
1486 			if (!batch_inited) {
1487 				batch_init_backup(&batch,
1488 						  last_index - start_index + 1,
1489 						  backup, sizeof(backup));
1490 				batch_inited = true;
1491 			}
1492 			iopt_pages_unpin_xarray(&batch, pages, span.start_hole,
1493 						span.last_hole);
1494 		} else if (span.is_used == 2) {
1495 			/* Covered by a domain */
1496 			clear_xarray(&pages->pinned_pfns, span.start_used,
1497 				     span.last_used);
1498 		}
1499 		/* Otherwise covered by an existing access */
1500 	}
1501 	if (batch_inited)
1502 		batch_destroy(&batch, backup);
1503 	update_unpinned(pages);
1504 }
1505 
1506 /**
1507  * iopt_pages_fill_from_xarray() - Fast path for reading PFNs
1508  * @pages: The pages to act on
1509  * @start_index: The first page index in the range
1510  * @last_index: The last page index in the range
1511  * @out_pages: The output array to return the pages
1512  *
1513  * This can be called if the caller is holding a refcount on an
1514  * iopt_pages_access that is known to have already been filled. It quickly reads
1515  * the pages directly from the xarray.
1516  *
1517  * This is part of the SW iommu interface to read pages for in-kernel use.
1518  */
1519 void iopt_pages_fill_from_xarray(struct iopt_pages *pages,
1520 				 unsigned long start_index,
1521 				 unsigned long last_index,
1522 				 struct page **out_pages)
1523 {
1524 	XA_STATE(xas, &pages->pinned_pfns, start_index);
1525 	void *entry;
1526 
1527 	rcu_read_lock();
1528 	while (start_index <= last_index) {
1529 		entry = xas_next(&xas);
1530 		if (xas_retry(&xas, entry))
1531 			continue;
1532 		WARN_ON(!xa_is_value(entry));
1533 		*(out_pages++) = pfn_to_page(xa_to_value(entry));
1534 		start_index++;
1535 	}
1536 	rcu_read_unlock();
1537 }
1538 
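/*
 * Fill out_pages by reading PFNs back from the storage domains covering the
 * index range. Each area found in the domains_itree supplies its slice.
 */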
1539 static int iopt_pages_fill_from_domain(struct iopt_pages *pages,
1540 				       unsigned long start_index,
1541 				       unsigned long last_index,
1542 				       struct page **out_pages)
1543 {
1544 	while (start_index != last_index + 1) {
1545 		unsigned long domain_last;
1546 		struct iopt_area *area;
1547 
1548 		area = iopt_pages_find_domain_area(pages, start_index);
1549 		if (WARN_ON(!area))
1550 			return -EINVAL;
1551 
1552 		domain_last = min(iopt_area_last_index(area), last_index);
1553 		out_pages = raw_pages_from_domain(area->storage_domain, area,
1554 						  start_index, domain_last,
1555 						  out_pages);
1556 		start_index = domain_last + 1;
1557 	}
1558 	return 0;
1559 }
1560 
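/*
 * Fill out_pages by pinning directly from the source mm. On failure anything
 * pinned by this call is unpinned again.
 */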
1561 static int iopt_pages_fill_from_mm(struct iopt_pages *pages,
1562 				   struct pfn_reader_user *user,
1563 				   unsigned long start_index,
1564 				   unsigned long last_index,
1565 				   struct page **out_pages)
1566 {
1567 	unsigned long cur_index = start_index;
1568 	int rc;
1569 
1570 	while (cur_index != last_index + 1) {
1571 		user->upages = out_pages + (cur_index - start_index);
1572 		rc = pfn_reader_user_pin(user, pages, cur_index, last_index);
1573 		if (rc)
1574 			goto out_unpin;
1575 		cur_index = user->upages_end;
1576 	}
1577 	return 0;
1578 
1579 out_unpin:
1580 	if (start_index != cur_index)
1581 		iopt_pages_err_unpin(pages, start_index, cur_index - 1,
1582 				     out_pages);
1583 	return rc;
1584 }
1585 
1586 /**
1587  * iopt_pages_fill_xarray() - Read PFNs
1588  * @pages: The pages to act on
1589  * @start_index: The first page index in the range
1590  * @last_index: The last page index in the range
1591  * @out_pages: The output array to return the pages, may be NULL
1592  *
1593  * This populates the xarray and returns the pages in out_pages. As the slow
1594  * path this is able to copy pages from other storage tiers into the xarray.
1595  *
1596  * On failure the xarray is left unchanged.
1597  *
1598  * This is part of the SW iommu interface to read pages for in-kernel use.
1599  */
1600 int iopt_pages_fill_xarray(struct iopt_pages *pages, unsigned long start_index,
1601 			   unsigned long last_index, struct page **out_pages)
1602 {
1603 	struct interval_tree_double_span_iter span;
1604 	unsigned long xa_end = start_index;
1605 	struct pfn_reader_user user;
1606 	int rc;
1607 
1608 	lockdep_assert_held(&pages->mutex);
1609 
1610 	pfn_reader_user_init(&user, pages);
1611 	user.upages_len = (last_index - start_index + 1) * sizeof(*out_pages);
1612 	interval_tree_for_each_double_span(&span, &pages->access_itree,
1613 					   &pages->domains_itree, start_index,
1614 					   last_index) {
1615 		struct page **cur_pages;
1616 
1617 		if (span.is_used == 1) {
1618 			cur_pages = out_pages + (span.start_used - start_index);
1619 			iopt_pages_fill_from_xarray(pages, span.start_used,
1620 						    span.last_used, cur_pages);
1621 			continue;
1622 		}
1623 
1624 		if (span.is_used == 2) {
1625 			cur_pages = out_pages + (span.start_used - start_index);
1626 			iopt_pages_fill_from_domain(pages, span.start_used,
1627 						    span.last_used, cur_pages);
1628 			rc = pages_to_xarray(&pages->pinned_pfns,
1629 					     span.start_used, span.last_used,
1630 					     cur_pages);
1631 			if (rc)
1632 				goto out_clean_xa;
1633 			xa_end = span.last_used + 1;
1634 			continue;
1635 		}
1636 
1637 		/* hole */
1638 		cur_pages = out_pages + (span.start_hole - start_index);
1639 		rc = iopt_pages_fill_from_mm(pages, &user, span.start_hole,
1640 					     span.last_hole, cur_pages);
1641 		if (rc)
1642 			goto out_clean_xa;
1643 		rc = pages_to_xarray(&pages->pinned_pfns, span.start_hole,
1644 				     span.last_hole, cur_pages);
1645 		if (rc) {
1646 			iopt_pages_err_unpin(pages, span.start_hole,
1647 					     span.last_hole, cur_pages);
1648 			goto out_clean_xa;
1649 		}
1650 		xa_end = span.last_hole + 1;
1651 	}
1652 	rc = pfn_reader_user_update_pinned(&user, pages);
1653 	if (rc)
1654 		goto out_clean_xa;
1655 	user.upages = NULL;
1656 	pfn_reader_user_destroy(&user, pages);
1657 	return 0;
1658 
1659 out_clean_xa:
1660 	if (start_index != xa_end)
1661 		iopt_pages_unfill_xarray(pages, start_index, xa_end - 1);
1662 	user.upages = NULL;
1663 	pfn_reader_user_destroy(&user, pages);
1664 	return rc;
1665 }
1666 
1667 /*
1668  * This uses the pfn_reader instead of taking a shortcut by using the mm. It can
1669  * do every scenario and is fully consistent with what an iommu_domain would
1670  * see.
1671  */
1672 static int iopt_pages_rw_slow(struct iopt_pages *pages,
1673 			      unsigned long start_index,
1674 			      unsigned long last_index, unsigned long offset,
1675 			      void *data, unsigned long length,
1676 			      unsigned int flags)
1677 {
1678 	struct pfn_reader pfns;
1679 	int rc;
1680 
1681 	mutex_lock(&pages->mutex);
1682 
1683 	rc = pfn_reader_first(&pfns, pages, start_index, last_index);
1684 	if (rc)
1685 		goto out_unlock;
1686 
1687 	while (!pfn_reader_done(&pfns)) {
1688 		unsigned long done;
1689 
1690 		done = batch_rw(&pfns.batch, data, offset, length, flags);
1691 		data += done;
1692 		length -= done;
1693 		offset = 0;
1694 		pfn_reader_unpin(&pfns);
1695 
1696 		rc = pfn_reader_next(&pfns);
1697 		if (rc)
1698 			goto out_destroy;
1699 	}
1700 	if (WARN_ON(length != 0))
1701 		rc = -EINVAL;
1702 out_destroy:
1703 	pfn_reader_destroy(&pfns);
1704 out_unlock:
1705 	mutex_unlock(&pages->mutex);
1706 	return rc;
1707 }
1708 
1709 /*
1710  * A medium speed path that still allows DMA inconsistencies, but doesn't do any
1711  * memory allocations or interval tree searches.
1712  */
1713 static int iopt_pages_rw_page(struct iopt_pages *pages, unsigned long index,
1714 			      unsigned long offset, void *data,
1715 			      unsigned long length, unsigned int flags)
1716 {
1717 	struct page *page = NULL;
1718 	int rc;
1719 
1720 	if (!mmget_not_zero(pages->source_mm))
1721 		return iopt_pages_rw_slow(pages, index, index, offset, data,
1722 					  length, flags);
1723 
1724 	mmap_read_lock(pages->source_mm);
1725 	rc = pin_user_pages_remote(
1726 		pages->source_mm, (uintptr_t)(pages->uptr + index * PAGE_SIZE),
1727 		1, (flags & IOMMUFD_ACCESS_RW_WRITE) ? FOLL_WRITE : 0, &page,
1728 		NULL, NULL);
1729 	mmap_read_unlock(pages->source_mm);
1730 	if (rc != 1) {
1731 		if (WARN_ON(rc >= 0))
1732 			rc = -EINVAL;
1733 		goto out_mmput;
1734 	}
1735 	copy_data_page(page, data, offset, length, flags);
1736 	unpin_user_page(page);
1737 	rc = 0;
1738 
1739 out_mmput:
1740 	mmput(pages->source_mm);
1741 	return rc;
1742 }
1743 
1744 /**
1745  * iopt_pages_rw_access() - Copy to/from a linear slice of the pages
1746  * @pages: pages to act on
1747  * @start_byte: First byte of pages to copy to/from
1748  * @data: Kernel buffer to get/put the data
1749  * @length: Number of bytes to copy
1750  * @flags: IOMMUFD_ACCESS_RW_* flags
1751  *
1752  * This will find each page in the range, kmap it and then memcpy to/from
1753  * the given kernel buffer.
1754  */
1755 int iopt_pages_rw_access(struct iopt_pages *pages, unsigned long start_byte,
1756 			 void *data, unsigned long length, unsigned int flags)
1757 {
1758 	unsigned long start_index = start_byte / PAGE_SIZE;
1759 	unsigned long last_index = (start_byte + length - 1) / PAGE_SIZE;
1760 	bool change_mm = current->mm != pages->source_mm;
1761 	int rc = 0;
1762 
1763 	if (IS_ENABLED(CONFIG_IOMMUFD_TEST) &&
1764 	    (flags & __IOMMUFD_ACCESS_RW_SLOW_PATH))
1765 		change_mm = true;
1766 
1767 	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
1768 		return -EPERM;
1769 
1770 	if (!(flags & IOMMUFD_ACCESS_RW_KTHREAD) && change_mm) {
1771 		if (start_index == last_index)
1772 			return iopt_pages_rw_page(pages, start_index,
1773 						  start_byte % PAGE_SIZE, data,
1774 						  length, flags);
1775 		return iopt_pages_rw_slow(pages, start_index, last_index,
1776 					  start_byte % PAGE_SIZE, data, length,
1777 					  flags);
1778 	}
1779 
1780 	/*
1781 	 * Try to copy using copy_to_user(). We do this as a fast path and
1782 	 * ignore any pinning inconsistencies, unlike a real DMA path.
1783 	 */
1784 	if (change_mm) {
1785 		if (!mmget_not_zero(pages->source_mm))
1786 			return iopt_pages_rw_slow(pages, start_index,
1787 						  last_index,
1788 						  start_byte % PAGE_SIZE, data,
1789 						  length, flags);
1790 		kthread_use_mm(pages->source_mm);
1791 	}
1792 
1793 	if (flags & IOMMUFD_ACCESS_RW_WRITE) {
1794 		if (copy_to_user(pages->uptr + start_byte, data, length))
1795 			rc = -EFAULT;
1796 	} else {
1797 		if (copy_from_user(data, pages->uptr + start_byte, length))
1798 			rc = -EFAULT;
1799 	}
1800 
1801 	if (change_mm) {
1802 		kthread_unuse_mm(pages->source_mm);
1803 		mmput(pages->source_mm);
1804 	}
1805 
1806 	return rc;
1807 }
1808 
1809 static struct iopt_pages_access *
1810 iopt_pages_get_exact_access(struct iopt_pages *pages, unsigned long index,
1811 			    unsigned long last)
1812 {
1813 	struct interval_tree_node *node;
1814 
1815 	lockdep_assert_held(&pages->mutex);
1816 
1817 	/* There can be overlapping ranges in this interval tree */
1818 	for (node = interval_tree_iter_first(&pages->access_itree, index, last);
1819 	     node; node = interval_tree_iter_next(node, index, last))
1820 		if (node->start == index && node->last == last)
1821 			return container_of(node, struct iopt_pages_access,
1822 					    node);
1823 	return NULL;
1824 }
1825 
1826 /**
1827  * iopt_area_add_access() - Record an in-kernel access for PFNs
1828  * @area: The source of PFNs
1829  * @start_index: First page index
1830  * @last_index: Inclusive last page index
1831  * @out_pages: Output list of struct page's representing the PFNs
1832  * @flags: IOMMUFD_ACCESS_RW_* flags
1833  *
1834  * Record that an in-kernel access will be accessing the pages, ensure they are
1835  * pinned, and return the PFNs as a simple list of 'struct page *'.
1836  *
1837  * This should be undone through a matching call to iopt_area_remove_access()
1838  */
1839 int iopt_area_add_access(struct iopt_area *area, unsigned long start_index,
1840 			  unsigned long last_index, struct page **out_pages,
1841 			  unsigned int flags)
1842 {
1843 	struct iopt_pages *pages = area->pages;
1844 	struct iopt_pages_access *access;
1845 	int rc;
1846 
1847 	if ((flags & IOMMUFD_ACCESS_RW_WRITE) && !pages->writable)
1848 		return -EPERM;
1849 
1850 	mutex_lock(&pages->mutex);
1851 	access = iopt_pages_get_exact_access(pages, start_index, last_index);
1852 	if (access) {
1853 		area->num_accesses++;
1854 		access->users++;
1855 		iopt_pages_fill_from_xarray(pages, start_index, last_index,
1856 					    out_pages);
1857 		mutex_unlock(&pages->mutex);
1858 		return 0;
1859 	}
1860 
1861 	access = kzalloc(sizeof(*access), GFP_KERNEL_ACCOUNT);
1862 	if (!access) {
1863 		rc = -ENOMEM;
1864 		goto err_unlock;
1865 	}
1866 
1867 	rc = iopt_pages_fill_xarray(pages, start_index, last_index, out_pages);
1868 	if (rc)
1869 		goto err_free;
1870 
1871 	access->node.start = start_index;
1872 	access->node.last = last_index;
1873 	access->users = 1;
1874 	area->num_accesses++;
1875 	interval_tree_insert(&access->node, &pages->access_itree);
1876 	mutex_unlock(&pages->mutex);
1877 	return 0;
1878 
1879 err_free:
1880 	kfree(access);
1881 err_unlock:
1882 	mutex_unlock(&pages->mutex);
1883 	return rc;
1884 }
1885 
1886 /**
1887  * iopt_area_remove_access() - Release an in-kernel access for PFNs
1888  * @area: The source of PFNs
1889  * @start_index: First page index
1890  * @last_index: Inclusive last page index
1891  *
1892  * Undo iopt_area_add_access() and unpin the pages if necessary. The caller
1893  * must stop using the PFNs before calling this.
1894  */
1895 void iopt_area_remove_access(struct iopt_area *area, unsigned long start_index,
1896 			     unsigned long last_index)
1897 {
1898 	struct iopt_pages *pages = area->pages;
1899 	struct iopt_pages_access *access;
1900 
1901 	mutex_lock(&pages->mutex);
1902 	access = iopt_pages_get_exact_access(pages, start_index, last_index);
1903 	if (WARN_ON(!access))
1904 		goto out_unlock;
1905 
1906 	WARN_ON(area->num_accesses == 0 || access->users == 0);
1907 	area->num_accesses--;
1908 	access->users--;
1909 	if (access->users)
1910 		goto out_unlock;
1911 
1912 	interval_tree_remove(&access->node, &pages->access_itree);
1913 	iopt_pages_unfill_xarray(pages, start_index, last_index);
1914 	kfree(access);
1915 out_unlock:
1916 	mutex_unlock(&pages->mutex);
1917 }
1918