/*
 * Copyright (c) 2014 Mellanox Technologies. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include <linux/types.h>
#include <linux/sched.h>
#include <linux/sched/mm.h>
#include <linux/sched/task.h>
#include <linux/pid.h>
#include <linux/slab.h>
#include <linux/export.h>
#include <linux/vmalloc.h>
#include <linux/hugetlb.h>
#include <linux/interval_tree.h>
#include <linux/pagemap.h>

#include <rdma/ib_verbs.h>
#include <rdma/ib_umem.h>
#include <rdma/ib_umem_odp.h>

#include "uverbs.h"

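/*
 * Notifier accounting: notifiers_count tracks how many mmu notifier
 * invalidations are currently running against a umem, and notifiers_seq is
 * bumped every time one completes. Page-fault handlers compare the sequence
 * number to detect a racing invalidation and wait on notifier_completion
 * until all running notifiers have finished.
 */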
static void ib_umem_notifier_start_account(struct ib_umem_odp *umem_odp)
{
	mutex_lock(&umem_odp->umem_mutex);
	if (umem_odp->notifiers_count++ == 0)
		/*
		 * Initialize the completion object for waiting on
		 * notifiers. Since notifiers_count was zero, no one should be
		 * waiting right now.
		 */
		reinit_completion(&umem_odp->notifier_completion);
	mutex_unlock(&umem_odp->umem_mutex);
}

static void ib_umem_notifier_end_account(struct ib_umem_odp *umem_odp)
{
	mutex_lock(&umem_odp->umem_mutex);
	/*
	 * This sequence increase tells the QP page-fault handler that the
	 * page it is about to map into the spte may have been freed.
	 */
	++umem_odp->notifiers_seq;
	if (--umem_odp->notifiers_count == 0)
		complete_all(&umem_odp->notifier_completion);
	mutex_unlock(&umem_odp->umem_mutex);
}

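/*
 * Called when the owning mm is about to be torn down. Every umem in the
 * interval tree gets its notifier count raised (so no further page faults
 * are served) and the driver is asked to invalidate the whole mapping.
 */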
static void ib_umem_notifier_release(struct mmu_notifier *mn,
				     struct mm_struct *mm)
{
	struct ib_ucontext_per_mm *per_mm =
		container_of(mn, struct ib_ucontext_per_mm, mn);
	struct rb_node *node;

	down_read(&per_mm->umem_rwsem);
	if (!per_mm->active)
		goto out;

	for (node = rb_first_cached(&per_mm->umem_tree); node;
	     node = rb_next(node)) {
		struct ib_umem_odp *umem_odp =
			rb_entry(node, struct ib_umem_odp, interval_tree.rb);

		/*
		 * Increase the number of notifiers running, to prevent any
		 * further fault handling on this MR.
		 */
		ib_umem_notifier_start_account(umem_odp);
		complete_all(&umem_odp->notifier_completion);
		umem_odp->umem.context->invalidate_range(
			umem_odp, ib_umem_start(umem_odp),
			ib_umem_end(umem_odp));
	}

out:
	up_read(&per_mm->umem_rwsem);
}

static int invalidate_range_start_trampoline(struct ib_umem_odp *item,
					     u64 start, u64 end, void *cookie)
{
	ib_umem_notifier_start_account(item);
	item->umem.context->invalidate_range(item, start, end);
	return 0;
}

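/*
 * Begin invalidation of every umem overlapping the notifier range: raise
 * its notifier count and ask the driver to tear down the affected
 * mappings. The umem_rwsem taken here for read is normally released in
 * ib_umem_notifier_invalidate_range_end().
 */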
static int ib_umem_notifier_invalidate_range_start(struct mmu_notifier *mn,
				const struct mmu_notifier_range *range)
{
	struct ib_ucontext_per_mm *per_mm =
		container_of(mn, struct ib_ucontext_per_mm, mn);
	int rc;

	if (mmu_notifier_range_blockable(range))
		down_read(&per_mm->umem_rwsem);
	else if (!down_read_trylock(&per_mm->umem_rwsem))
		return -EAGAIN;

	if (!per_mm->active) {
		up_read(&per_mm->umem_rwsem);
		/*
		 * At this point active is permanently cleared and is visible
		 * to this CPU without a lock; range_end relies on that fact
		 * to skip its unlock.
		 */
		return 0;
	}

	rc = rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, range->start,
					   range->end,
					   invalidate_range_start_trampoline,
					   mmu_notifier_range_blockable(range),
					   NULL);
	if (rc)
		up_read(&per_mm->umem_rwsem);
	return rc;
}

static int invalidate_range_end_trampoline(struct ib_umem_odp *item, u64 start,
					   u64 end, void *cookie)
{
	ib_umem_notifier_end_account(item);
	return 0;
}

static void ib_umem_notifier_invalidate_range_end(struct mmu_notifier *mn,
				const struct mmu_notifier_range *range)
{
	struct ib_ucontext_per_mm *per_mm =
		container_of(mn, struct ib_ucontext_per_mm, mn);

	if (unlikely(!per_mm->active))
		return;

	rbt_ib_umem_for_each_in_range(&per_mm->umem_tree, range->start,
				      range->end,
				      invalidate_range_end_trampoline, true, NULL);
	up_read(&per_mm->umem_rwsem);
}

static const struct mmu_notifier_ops ib_umem_notifiers = {
	.release                    = ib_umem_notifier_release,
	.invalidate_range_start     = ib_umem_notifier_invalidate_range_start,
	.invalidate_range_end       = ib_umem_notifier_invalidate_range_end,
};

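/*
 * Take the umem out of the per_mm interval tree so mmu notifiers can no
 * longer reach it, and wake up anyone waiting for a running notifier to
 * finish.
 */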
static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
{
	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;

	down_write(&per_mm->umem_rwsem);
	interval_tree_remove(&umem_odp->interval_tree, &per_mm->umem_tree);
	complete_all(&umem_odp->notifier_completion);
	up_write(&per_mm->umem_rwsem);
}

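/*
 * Allocate a new per-mm tracking structure and register its mmu notifier
 * against the current mm. Called under the ucontext's per_mm_list_lock via
 * get_per_mm().
 */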
static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
					       struct mm_struct *mm)
{
	struct ib_ucontext_per_mm *per_mm;
	int ret;

	per_mm = kzalloc(sizeof(*per_mm), GFP_KERNEL);
	if (!per_mm)
		return ERR_PTR(-ENOMEM);

	per_mm->context = ctx;
	per_mm->mm = mm;
	per_mm->umem_tree = RB_ROOT_CACHED;
	init_rwsem(&per_mm->umem_rwsem);
	per_mm->active = true;

	rcu_read_lock();
	per_mm->tgid = get_task_pid(current->group_leader, PIDTYPE_PID);
	rcu_read_unlock();

	WARN_ON(mm != current->mm);

	per_mm->mn.ops = &ib_umem_notifiers;
	ret = mmu_notifier_register(&per_mm->mn, per_mm->mm);
	if (ret) {
		dev_err(&ctx->device->dev,
			"Failed to register mmu_notifier %d\n", ret);
		goto out_pid;
	}

	list_add(&per_mm->ucontext_list, &ctx->per_mm_list);
	return per_mm;

out_pid:
	put_pid(per_mm->tgid);
	kfree(per_mm);
	return ERR_PTR(ret);
}

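/*
 * Find the per_mm tracking the umem's owning mm, or create one if this is
 * the first ODP umem registered against that mm. The caller must hold the
 * ucontext's per_mm_list_lock.
 */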
static struct ib_ucontext_per_mm *get_per_mm(struct ib_umem_odp *umem_odp)
{
	struct ib_ucontext *ctx = umem_odp->umem.context;
	struct ib_ucontext_per_mm *per_mm;

	lockdep_assert_held(&ctx->per_mm_list_lock);

	/*
	 * Generally speaking we expect only one or two per_mm in this list,
	 * so no reason to optimize this search today.
	 */
	list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
		if (per_mm->mm == umem_odp->umem.owning_mm)
			return per_mm;
	}

	return alloc_per_mm(ctx, umem_odp->umem.owning_mm);
}

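/*
 * Deferred free of the per_mm, run from put_per_mm() via
 * mmu_notifier_call_srcu() once the notifier SRCU grace period has elapsed.
 */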
static void free_per_mm(struct rcu_head *rcu)
{
	kfree(container_of(rcu, struct ib_ucontext_per_mm, rcu));
}

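/*
 * Drop one MR's reference on the per_mm. When the last ODP MR using this mm
 * goes away, deactivate and unregister the mmu notifier and schedule the
 * structure for freeing after the SRCU grace period.
 */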
static void put_per_mm(struct ib_umem_odp *umem_odp)
{
	struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
	struct ib_ucontext *ctx = umem_odp->umem.context;
	bool need_free;

	mutex_lock(&ctx->per_mm_list_lock);
	umem_odp->per_mm = NULL;
	per_mm->odp_mrs_count--;
	need_free = per_mm->odp_mrs_count == 0;
	if (need_free)
		list_del(&per_mm->ucontext_list);
	mutex_unlock(&ctx->per_mm_list_lock);

	if (!need_free)
		return;

	/*
	 * NOTE! mmu_notifier_unregister() can happen between a start/end
	 * callback, resulting in a start without a matching end, and thus an
	 * unbalanced lock. This doesn't really matter to us since we are
	 * about to kfree the memory that holds the lock, however LOCKDEP
	 * doesn't like this.
	 */
	down_write(&per_mm->umem_rwsem);
	per_mm->active = false;
	up_write(&per_mm->umem_rwsem);

	WARN_ON(!RB_EMPTY_ROOT(&per_mm->umem_tree.rb_root));
	mmu_notifier_unregister_no_release(&per_mm->mn, per_mm->mm);
	put_pid(per_mm->tgid);
	mmu_notifier_call_srcu(&per_mm->rcu, free_per_mm);
}

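/*
 * Common initialization for implicit and normal ODP umems: allocate the
 * page and DMA arrays (for non-implicit umems), attach the umem to a per_mm
 * (allocating one when @per_mm is NULL) and insert it into the interval
 * tree.
 */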
static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp,
				   struct ib_ucontext_per_mm *per_mm)
{
	struct ib_ucontext *ctx = umem_odp->umem.context;
	int ret;

	umem_odp->umem.is_odp = 1;
	if (!umem_odp->is_implicit_odp) {
		size_t pages = ib_umem_odp_num_pages(umem_odp);

		if (!pages)
			return -EINVAL;

		/*
		 * Note that the representation of the intervals in the
		 * interval tree considers the ending point as contained in
		 * the interval, while the function ib_umem_end returns the
		 * first address which is not contained in the umem.
		 */
		umem_odp->interval_tree.start = ib_umem_start(umem_odp);
		umem_odp->interval_tree.last = ib_umem_end(umem_odp) - 1;

		umem_odp->page_list = vzalloc(
			array_size(sizeof(*umem_odp->page_list), pages));
		if (!umem_odp->page_list)
			return -ENOMEM;

		umem_odp->dma_list =
			vzalloc(array_size(sizeof(*umem_odp->dma_list), pages));
		if (!umem_odp->dma_list) {
			ret = -ENOMEM;
			goto out_page_list;
		}
	}

	mutex_lock(&ctx->per_mm_list_lock);
	if (!per_mm) {
		per_mm = get_per_mm(umem_odp);
		if (IS_ERR(per_mm)) {
			ret = PTR_ERR(per_mm);
			goto out_unlock;
		}
	}
	umem_odp->per_mm = per_mm;
	per_mm->odp_mrs_count++;
	mutex_unlock(&ctx->per_mm_list_lock);

	mutex_init(&umem_odp->umem_mutex);
	init_completion(&umem_odp->notifier_completion);

	if (!umem_odp->is_implicit_odp) {
		down_write(&per_mm->umem_rwsem);
		interval_tree_insert(&umem_odp->interval_tree,
				     &per_mm->umem_tree);
		up_write(&per_mm->umem_rwsem);
	}
	mmgrab(umem_odp->umem.owning_mm);

	return 0;

out_unlock:
	mutex_unlock(&ctx->per_mm_list_lock);
	vfree(umem_odp->dma_list);
out_page_list:
	vfree(umem_odp->page_list);
	return ret;
}

/**
 * ib_umem_odp_alloc_implicit - Allocate a parent implicit ODP umem
 *
 * Implicit ODP umems do not have a VA range and do not have any page lists.
 * They exist only to hold the per_mm reference to help the driver create
 * children umems.
 *
 * @udata: udata from the syscall being used to create the umem
 * @access: ib_reg_mr access flags
 */
struct ib_umem_odp *ib_umem_odp_alloc_implicit(struct ib_udata *udata,
					       int access)
{
	struct ib_ucontext *context =
		container_of(udata, struct uverbs_attr_bundle, driver_udata)
			->context;
	struct ib_umem *umem;
	struct ib_umem_odp *umem_odp;
	int ret;

	if (access & IB_ACCESS_HUGETLB)
		return ERR_PTR(-EINVAL);

	if (!context)
		return ERR_PTR(-EIO);
	if (WARN_ON_ONCE(!context->invalidate_range))
		return ERR_PTR(-EINVAL);

	umem_odp = kzalloc(sizeof(*umem_odp), GFP_KERNEL);
	if (!umem_odp)
		return ERR_PTR(-ENOMEM);
	umem = &umem_odp->umem;
	umem->context = context;
	umem->writable = ib_access_writable(access);
	umem->owning_mm = current->mm;
	umem_odp->is_implicit_odp = 1;
	umem_odp->page_shift = PAGE_SHIFT;

	ret = ib_init_umem_odp(umem_odp, NULL);
	if (ret) {
		kfree(umem_odp);
		return ERR_PTR(ret);
	}
	return umem_odp;
}
EXPORT_SYMBOL(ib_umem_odp_alloc_implicit);

/**
 * ib_umem_odp_alloc_child - Allocate a child ODP umem under an implicit
 *                           parent ODP umem
 *
 * @root: The parent umem enclosing the child. This must be allocated using
 *        ib_umem_odp_alloc_implicit()
 * @addr: The starting userspace VA
 * @size: The length of the userspace VA
 */
struct ib_umem_odp *ib_umem_odp_alloc_child(struct ib_umem_odp *root,
					    unsigned long addr, size_t size)
{
	/*
	 * Caller must ensure that root cannot be freed during the call to
	 * ib_umem_odp_alloc_child.
	 */
	struct ib_umem_odp *odp_data;
	struct ib_umem *umem;
	int ret;

	if (WARN_ON(!root->is_implicit_odp))
		return ERR_PTR(-EINVAL);

	odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
	if (!odp_data)
		return ERR_PTR(-ENOMEM);
	umem = &odp_data->umem;
	umem->context    = root->umem.context;
	umem->length     = size;
	umem->address    = addr;
	umem->writable   = root->umem.writable;
	umem->owning_mm  = root->umem.owning_mm;
	odp_data->page_shift = PAGE_SHIFT;

	ret = ib_init_umem_odp(odp_data, root->per_mm);
	if (ret) {
		kfree(odp_data);
		return ERR_PTR(ret);
	}
	return odp_data;
}
EXPORT_SYMBOL(ib_umem_odp_alloc_child);

/**
 * ib_umem_odp_get - Create a umem_odp for a userspace va
 *
 * @udata: userspace context to pin memory for
 * @addr: userspace virtual address to start at
 * @size: length of region to pin
 * @access: IB_ACCESS_xxx flags for memory being pinned
 *
 * The driver should use this function when the access flags indicate ODP
 * memory. It avoids pinning; instead it stores the mm for future page fault
 * handling in conjunction with MMU notifiers.
 */
struct ib_umem_odp *ib_umem_odp_get(struct ib_udata *udata, unsigned long addr,
				    size_t size, int access)
{
	struct ib_umem_odp *umem_odp;
	struct ib_ucontext *context;
	struct mm_struct *mm;
	int ret;

	if (!udata)
		return ERR_PTR(-EIO);

	context = container_of(udata, struct uverbs_attr_bundle, driver_udata)
			  ->context;
	if (!context)
		return ERR_PTR(-EIO);

	if (WARN_ON_ONCE(!(access & IB_ACCESS_ON_DEMAND)) ||
	    WARN_ON_ONCE(!context->invalidate_range))
		return ERR_PTR(-EINVAL);

	umem_odp = kzalloc(sizeof(struct ib_umem_odp), GFP_KERNEL);
	if (!umem_odp)
		return ERR_PTR(-ENOMEM);

	umem_odp->umem.context = context;
	umem_odp->umem.length = size;
	umem_odp->umem.address = addr;
	umem_odp->umem.writable = ib_access_writable(access);
	umem_odp->umem.owning_mm = mm = current->mm;

	umem_odp->page_shift = PAGE_SHIFT;
	if (access & IB_ACCESS_HUGETLB) {
		struct vm_area_struct *vma;
		struct hstate *h;

		down_read(&mm->mmap_sem);
		vma = find_vma(mm, ib_umem_start(umem_odp));
		if (!vma || !is_vm_hugetlb_page(vma)) {
			up_read(&mm->mmap_sem);
			ret = -EINVAL;
			goto err_free;
		}
		h = hstate_vma(vma);
		umem_odp->page_shift = huge_page_shift(h);
		up_read(&mm->mmap_sem);
	}

	ret = ib_init_umem_odp(umem_odp, NULL);
	if (ret)
		goto err_free;
	return umem_odp;

err_free:
	kfree(umem_odp);
	return ERR_PTR(ret);
}
EXPORT_SYMBOL(ib_umem_odp_get);

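/**
 * ib_umem_odp_release - Tear down the ODP state of a umem
 *
 * @umem_odp: the ODP umem being released
 *
 * For a non-implicit umem, DMA unmaps any remaining pages, removes the umem
 * from the per_mm interval tree and frees its page and DMA arrays; in all
 * cases the reference on the per_mm is dropped.
 */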
void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
{
	/*
	 * Ensure that no more pages are mapped in the umem.
	 *
	 * It is the driver's responsibility to ensure, before calling us,
	 * that the hardware will not attempt to access the MR any more.
	 */
	if (!umem_odp->is_implicit_odp) {
		ib_umem_odp_unmap_dma_pages(umem_odp, ib_umem_start(umem_odp),
					    ib_umem_end(umem_odp));
		remove_umem_from_per_mm(umem_odp);
		vfree(umem_odp->dma_list);
		vfree(umem_odp->page_list);
	}
	put_per_mm(umem_odp);
}

/*
 * Map for DMA and insert a single page into the on-demand paging page tables.
 *
 * @umem_odp: the umem to insert the page to.
 * @page_index: index in the umem to add the page to.
 * @page: the page struct to map and add.
 * @access_mask: access permissions needed for this page.
 * @current_seq: sequence number for synchronization with invalidations.
 *               The sequence number is taken from
 *               umem_odp->notifiers_seq.
 *
 * The function returns -EFAULT if the DMA mapping operation fails. It returns
 * -EAGAIN if a concurrent invalidation prevents us from updating the page.
 *
 * The page is released via put_user_page even when the operation succeeds;
 * on-demand paging relies on the MMU notifier rather than on holding a page
 * reference, so the get_user_pages reference is always dropped here.
 */
static int ib_umem_odp_map_dma_single_page(
		struct ib_umem_odp *umem_odp,
		int page_index,
		struct page *page,
		u64 access_mask,
		unsigned long current_seq)
{
	struct ib_ucontext *context = umem_odp->umem.context;
	struct ib_device *dev = context->device;
	dma_addr_t dma_addr;
	int remove_existing_mapping = 0;
	int ret = 0;

	/*
	 * Note: we avoid writing if seq is different from the initial seq, to
	 * handle the case of a racing notifier. This check also allows us to
	 * bail early if we have a notifier running in parallel with us.
	 */
	if (ib_umem_mmu_notifier_retry(umem_odp, current_seq)) {
		ret = -EAGAIN;
		goto out;
	}
	if (!(umem_odp->dma_list[page_index])) {
		dma_addr =
			ib_dma_map_page(dev, page, 0, BIT(umem_odp->page_shift),
					DMA_BIDIRECTIONAL);
		if (ib_dma_mapping_error(dev, dma_addr)) {
			ret = -EFAULT;
			goto out;
		}
		umem_odp->dma_list[page_index] = dma_addr | access_mask;
		umem_odp->page_list[page_index] = page;
		umem_odp->npages++;
	} else if (umem_odp->page_list[page_index] == page) {
		umem_odp->dma_list[page_index] |= access_mask;
	} else {
		pr_err("error: got different pages in IB device and from get_user_pages. IB device page: %p, gup page: %p\n",
		       umem_odp->page_list[page_index], page);
		/* Better remove the mapping now, to prevent any further
		 * damage. */
		remove_existing_mapping = 1;
	}

out:
	put_user_page(page);

	if (remove_existing_mapping) {
		ib_umem_notifier_start_account(umem_odp);
		context->invalidate_range(
			umem_odp,
			ib_umem_start(umem_odp) +
				(page_index << umem_odp->page_shift),
			ib_umem_start(umem_odp) +
				((page_index + 1) << umem_odp->page_shift));
		ib_umem_notifier_end_account(umem_odp);
		ret = -EAGAIN;
	}

	return ret;
}

/**
 * ib_umem_odp_map_dma_pages - Pin and DMA map userspace memory in an ODP MR.
 *
 * Pins the range of pages passed in the argument, and maps them to
 * DMA addresses. The DMA addresses of the mapped pages are updated in
 * umem_odp->dma_list.
 *
 * Returns the number of pages mapped on success, or a negative error code
 * on failure.
 * An -EAGAIN error code is returned when a concurrent mmu notifier prevents
 * the function from completing its task.
 * An -ENOENT error code indicates that the userspace process is being
 * terminated and the mm was already destroyed.
 * @umem_odp: the umem to map and pin
 * @user_virt: the address from which we need to map.
 * @bcnt: the minimal number of bytes to pin and map. The mapping might be
 *        bigger due to alignment, and may also be smaller in case of an error
 *        pinning or mapping a page. The actual number of pages mapped is
 *        returned in the return value.
 * @access_mask: bit mask of the requested access permissions for the given
 *               range.
 * @current_seq: the MMU notifier sequence value for synchronization with
 *               invalidations. The sequence number is read from
 *               umem_odp->notifiers_seq before calling this function.
 */
int ib_umem_odp_map_dma_pages(struct ib_umem_odp *umem_odp, u64 user_virt,
			      u64 bcnt, u64 access_mask,
			      unsigned long current_seq)
{
	struct task_struct *owning_process  = NULL;
	struct mm_struct *owning_mm = umem_odp->umem.owning_mm;
	struct page       **local_page_list = NULL;
	u64 page_mask, off;
	int j, k, ret = 0, start_idx, npages = 0;
	unsigned int flags = 0, page_shift;
	phys_addr_t p = 0;

	if (access_mask == 0)
		return -EINVAL;

	if (user_virt < ib_umem_start(umem_odp) ||
	    user_virt + bcnt > ib_umem_end(umem_odp))
		return -EFAULT;

	local_page_list = (struct page **)__get_free_page(GFP_KERNEL);
	if (!local_page_list)
		return -ENOMEM;

	page_shift = umem_odp->page_shift;
	page_mask = ~(BIT(page_shift) - 1);
	off = user_virt & (~page_mask);
	user_virt = user_virt & page_mask;
	bcnt += off; /* Charge for the first page offset as well. */

	/*
	 * owning_process is allowed to be NULL; this means the mm is somehow
	 * still alive beyond the lifetime of the originating process.
	 * Presumably mmget_not_zero will fail in this case.
	 */
	owning_process = get_pid_task(umem_odp->per_mm->tgid, PIDTYPE_PID);
	if (!owning_process || !mmget_not_zero(owning_mm)) {
		ret = -EINVAL;
		goto out_put_task;
	}

	if (access_mask & ODP_WRITE_ALLOWED_BIT)
		flags |= FOLL_WRITE;

	start_idx = (user_virt - ib_umem_start(umem_odp)) >> page_shift;
	k = start_idx;

	while (bcnt > 0) {
		const size_t gup_num_pages = min_t(size_t,
				(bcnt + BIT(page_shift) - 1) >> page_shift,
				PAGE_SIZE / sizeof(struct page *));

		down_read(&owning_mm->mmap_sem);
		/*
		 * Note: this might result in redundant page getting. We can
		 * avoid this by checking that dma_list is 0 before calling
		 * get_user_pages. However, this makes the code much more
		 * complex (and doesn't gain us much performance in most use
		 * cases).
		 */
		npages = get_user_pages_remote(owning_process, owning_mm,
				user_virt, gup_num_pages,
				flags, local_page_list, NULL, NULL);
		up_read(&owning_mm->mmap_sem);

		if (npages < 0) {
			if (npages != -EAGAIN)
				pr_warn("failed to get %zu user pages with error %d\n", gup_num_pages, npages);
			else
				pr_debug("failed to get %zu user pages with error %d\n", gup_num_pages, npages);
			break;
		}

		bcnt -= min_t(size_t, npages << PAGE_SHIFT, bcnt);
		mutex_lock(&umem_odp->umem_mutex);
		for (j = 0; j < npages; j++, user_virt += PAGE_SIZE) {
			if (user_virt & ~page_mask) {
				p += PAGE_SIZE;
				if (page_to_phys(local_page_list[j]) != p) {
					ret = -EFAULT;
					break;
				}
				put_user_page(local_page_list[j]);
				continue;
			}

			ret = ib_umem_odp_map_dma_single_page(
					umem_odp, k, local_page_list[j],
					access_mask, current_seq);
			if (ret < 0) {
				if (ret != -EAGAIN)
					pr_warn("ib_umem_odp_map_dma_single_page failed with error %d\n", ret);
				else
					pr_debug("ib_umem_odp_map_dma_single_page failed with error %d\n", ret);
				break;
			}

			p = page_to_phys(local_page_list[j]);
			k++;
		}
		mutex_unlock(&umem_odp->umem_mutex);

		if (ret < 0) {
			/*
			 * Release pages, remembering that the first page
			 * to hit an error was already released by
			 * ib_umem_odp_map_dma_single_page().
			 */
			if (npages - (j + 1) > 0)
				put_user_pages(&local_page_list[j+1],
					       npages - (j + 1));
			break;
		}
	}

	if (ret >= 0) {
		if (npages < 0 && k == start_idx)
			ret = npages;
		else
			ret = k - start_idx;
	}

	mmput(owning_mm);
out_put_task:
	if (owning_process)
		put_task_struct(owning_process);
	free_page((unsigned long)local_page_list);
	return ret;
}
EXPORT_SYMBOL(ib_umem_odp_map_dma_pages);

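/**
 * ib_umem_odp_unmap_dma_pages - DMA unmap a range of pages in an ODP umem
 *
 * @umem_odp: the umem whose pages should be unmapped
 * @virt: the first virtual address in the range
 * @bound: the first virtual address after the range
 *
 * The range is clipped to the boundaries of the umem. Every currently
 * mapped page in it is DMA unmapped, marked dirty if it was mapped
 * writable, and its page_list/dma_list entries are cleared.
 */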
void ib_umem_odp_unmap_dma_pages(struct ib_umem_odp *umem_odp, u64 virt,
				 u64 bound)
{
	int idx;
	u64 addr;
	struct ib_device *dev = umem_odp->umem.context->device;

	virt = max_t(u64, virt, ib_umem_start(umem_odp));
	bound = min_t(u64, bound, ib_umem_end(umem_odp));
	/* Note that during the run of this function, the
	 * notifiers_count of the MR is > 0, preventing any racing
	 * faults from completing. We might be racing with other
	 * invalidations, so we must make sure we free each page only
	 * once. */
	mutex_lock(&umem_odp->umem_mutex);
	for (addr = virt; addr < bound; addr += BIT(umem_odp->page_shift)) {
		idx = (addr - ib_umem_start(umem_odp)) >> umem_odp->page_shift;
		if (umem_odp->page_list[idx]) {
			struct page *page = umem_odp->page_list[idx];
			dma_addr_t dma = umem_odp->dma_list[idx];
			dma_addr_t dma_addr = dma & ODP_DMA_ADDR_MASK;

			WARN_ON(!dma_addr);

			ib_dma_unmap_page(dev, dma_addr,
					  BIT(umem_odp->page_shift),
					  DMA_BIDIRECTIONAL);
			if (dma & ODP_WRITE_ALLOWED_BIT) {
				struct page *head_page = compound_head(page);
				/*
				 * set_page_dirty prefers being called with
				 * the page lock. However, MMU notifiers are
				 * called sometimes with and sometimes without
				 * the lock. We rely on the umem_mutex instead
				 * to prevent other mmu notifiers from
				 * continuing and allowing the page mapping to
				 * be removed.
				 */
				set_page_dirty(head_page);
			}
			umem_odp->page_list[idx] = NULL;
			umem_odp->dma_list[idx] = 0;
			umem_odp->npages--;
		}
	}
	mutex_unlock(&umem_odp->umem_mutex);
}
EXPORT_SYMBOL(ib_umem_odp_unmap_dma_pages);

/* @last is not a part of the interval: the interval tree stores closed
 * intervals, so the iteration below uses last - 1 as the inclusive end.
 */
int rbt_ib_umem_for_each_in_range(struct rb_root_cached *root,
				  u64 start, u64 last,
				  umem_call_back cb,
				  bool blockable,
				  void *cookie)
{
	int ret_val = 0;
	struct interval_tree_node *node, *next;
	struct ib_umem_odp *umem;

	if (unlikely(start == last))
		return ret_val;

	for (node = interval_tree_iter_first(root, start, last - 1);
			node; node = next) {
		/* TODO move the blockable decision up to the callback */
		if (!blockable)
			return -EAGAIN;
		next = interval_tree_iter_next(node, start, last - 1);
		umem = container_of(node, struct ib_umem_odp, interval_tree);
		ret_val = cb(umem, start, last, cookie) || ret_val;
	}

	return ret_val;
}