xref: /openbmc/linux/drivers/iommu/iommu-sva.c (revision 2a954832)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Helpers for IOMMU drivers implementing SVA
4  */
5 #include <linux/mmu_context.h>
6 #include <linux/mutex.h>
7 #include <linux/sched/mm.h>
8 #include <linux/iommu.h>
9 
10 #include "iommu-sva.h"
11 
/* Global pool from which every mm's PASID is allocated (see mm->pasid). */
static DEFINE_IDA(iommu_global_pasid_ida);
/*
 * Serializes mm->pasid allocation as well as SVA domain lookup, attach,
 * detach and the domain->users refcount.
 */
static DEFINE_MUTEX(iommu_sva_lock);
14 
15 /* Allocate a PASID for the mm within range (inclusive) */
16 static int iommu_sva_alloc_pasid(struct mm_struct *mm, ioasid_t min, ioasid_t max)
17 {
18 	int ret = 0;
19 
20 	if (min == IOMMU_PASID_INVALID ||
21 	    max == IOMMU_PASID_INVALID ||
22 	    min == 0 || max < min)
23 		return -EINVAL;
24 
25 	if (!arch_pgtable_dma_compat(mm))
26 		return -EBUSY;
27 
28 	mutex_lock(&iommu_sva_lock);
29 	/* Is a PASID already associated with this mm? */
30 	if (mm_valid_pasid(mm)) {
31 		if (mm->pasid < min || mm->pasid > max)
32 			ret = -EOVERFLOW;
33 		goto out;
34 	}
35 
36 	ret = ida_alloc_range(&iommu_global_pasid_ida, min, max, GFP_KERNEL);
37 	if (ret < min)
38 		goto out;
39 	mm->pasid = ret;
40 	ret = 0;
41 out:
42 	mutex_unlock(&iommu_sva_lock);
43 	return ret;
44 }
45 
46 /**
47  * iommu_sva_bind_device() - Bind a process address space to a device
48  * @dev: the device
49  * @mm: the mm to bind, caller must hold a reference to mm_users
50  *
51  * Create a bond between device and address space, allowing the device to
52  * access the mm using the PASID returned by iommu_sva_get_pasid(). If a
53  * bond already exists between @device and @mm, an additional internal
54  * reference is taken. Caller must call iommu_sva_unbind_device()
55  * to release each reference.
56  *
57  * iommu_dev_enable_feature(dev, IOMMU_DEV_FEAT_SVA) must be called first, to
58  * initialize the required SVA features.
59  *
60  * On error, returns an ERR_PTR value.
61  */
62 struct iommu_sva *iommu_sva_bind_device(struct device *dev, struct mm_struct *mm)
63 {
64 	struct iommu_domain *domain;
65 	struct iommu_sva *handle;
66 	ioasid_t max_pasids;
67 	int ret;
68 
69 	max_pasids = dev->iommu->max_pasids;
70 	if (!max_pasids)
71 		return ERR_PTR(-EOPNOTSUPP);
72 
73 	/* Allocate mm->pasid if necessary. */
74 	ret = iommu_sva_alloc_pasid(mm, 1, max_pasids - 1);
75 	if (ret)
76 		return ERR_PTR(ret);
77 
78 	handle = kzalloc(sizeof(*handle), GFP_KERNEL);
79 	if (!handle)
80 		return ERR_PTR(-ENOMEM);
81 
82 	mutex_lock(&iommu_sva_lock);
83 	/* Search for an existing domain. */
84 	domain = iommu_get_domain_for_dev_pasid(dev, mm->pasid,
85 						IOMMU_DOMAIN_SVA);
86 	if (IS_ERR(domain)) {
87 		ret = PTR_ERR(domain);
88 		goto out_unlock;
89 	}
90 
91 	if (domain) {
92 		domain->users++;
93 		goto out;
94 	}
95 
96 	/* Allocate a new domain and set it on device pasid. */
97 	domain = iommu_sva_domain_alloc(dev, mm);
98 	if (!domain) {
99 		ret = -ENOMEM;
100 		goto out_unlock;
101 	}
102 
103 	ret = iommu_attach_device_pasid(domain, dev, mm->pasid);
104 	if (ret)
105 		goto out_free_domain;
106 	domain->users = 1;
107 out:
108 	mutex_unlock(&iommu_sva_lock);
109 	handle->dev = dev;
110 	handle->domain = domain;
111 
112 	return handle;
113 
114 out_free_domain:
115 	iommu_domain_free(domain);
116 out_unlock:
117 	mutex_unlock(&iommu_sva_lock);
118 	kfree(handle);
119 
120 	return ERR_PTR(ret);
121 }
122 EXPORT_SYMBOL_GPL(iommu_sva_bind_device);
123 
124 /**
125  * iommu_sva_unbind_device() - Remove a bond created with iommu_sva_bind_device
126  * @handle: the handle returned by iommu_sva_bind_device()
127  *
128  * Put reference to a bond between device and address space. The device should
129  * not be issuing any more transaction for this PASID. All outstanding page
130  * requests for this PASID must have been flushed to the IOMMU.
131  */
132 void iommu_sva_unbind_device(struct iommu_sva *handle)
133 {
134 	struct iommu_domain *domain = handle->domain;
135 	ioasid_t pasid = domain->mm->pasid;
136 	struct device *dev = handle->dev;
137 
138 	mutex_lock(&iommu_sva_lock);
139 	if (--domain->users == 0) {
140 		iommu_detach_device_pasid(domain, dev, pasid);
141 		iommu_domain_free(domain);
142 	}
143 	mutex_unlock(&iommu_sva_lock);
144 	kfree(handle);
145 }
146 EXPORT_SYMBOL_GPL(iommu_sva_unbind_device);
147 
148 u32 iommu_sva_get_pasid(struct iommu_sva *handle)
149 {
150 	struct iommu_domain *domain = handle->domain;
151 
152 	return domain->mm->pasid;
153 }
154 EXPORT_SYMBOL_GPL(iommu_sva_get_pasid);
155 
156 /*
157  * I/O page fault handler for SVA
158  */
159 enum iommu_page_response_code
160 iommu_sva_handle_iopf(struct iommu_fault *fault, void *data)
161 {
162 	vm_fault_t ret;
163 	struct vm_area_struct *vma;
164 	struct mm_struct *mm = data;
165 	unsigned int access_flags = 0;
166 	unsigned int fault_flags = FAULT_FLAG_REMOTE;
167 	struct iommu_fault_page_request *prm = &fault->prm;
168 	enum iommu_page_response_code status = IOMMU_PAGE_RESP_INVALID;
169 
170 	if (!(prm->flags & IOMMU_FAULT_PAGE_REQUEST_PASID_VALID))
171 		return status;
172 
173 	if (!mmget_not_zero(mm))
174 		return status;
175 
176 	mmap_read_lock(mm);
177 
178 	vma = find_extend_vma(mm, prm->addr);
179 	if (!vma)
180 		/* Unmapped area */
181 		goto out_put_mm;
182 
183 	if (prm->perm & IOMMU_FAULT_PERM_READ)
184 		access_flags |= VM_READ;
185 
186 	if (prm->perm & IOMMU_FAULT_PERM_WRITE) {
187 		access_flags |= VM_WRITE;
188 		fault_flags |= FAULT_FLAG_WRITE;
189 	}
190 
191 	if (prm->perm & IOMMU_FAULT_PERM_EXEC) {
192 		access_flags |= VM_EXEC;
193 		fault_flags |= FAULT_FLAG_INSTRUCTION;
194 	}
195 
196 	if (!(prm->perm & IOMMU_FAULT_PERM_PRIV))
197 		fault_flags |= FAULT_FLAG_USER;
198 
199 	if (access_flags & ~vma->vm_flags)
200 		/* Access fault */
201 		goto out_put_mm;
202 
203 	ret = handle_mm_fault(vma, prm->addr, fault_flags, NULL);
204 	status = ret & VM_FAULT_ERROR ? IOMMU_PAGE_RESP_INVALID :
205 		IOMMU_PAGE_RESP_SUCCESS;
206 
207 out_put_mm:
208 	mmap_read_unlock(mm);
209 	mmput(mm);
210 
211 	return status;
212 }
213 
214 void mm_pasid_drop(struct mm_struct *mm)
215 {
216 	if (likely(!mm_valid_pasid(mm)))
217 		return;
218 
219 	ida_free(&iommu_global_pasid_ida, mm->pasid);
220 }
221