xref: /openbmc/linux/drivers/iommu/iommu-sva.c (revision 37002bc6)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Helpers for IOMMU drivers implementing SVA
4  */
5 #include <linux/mmu_context.h>
6 #include <linux/mutex.h>
7 #include <linux/sched/mm.h>
8 #include <linux/iommu.h>
9 
10 #include "iommu-sva.h"
11 
/*
 * Held while an mm's PASID is checked/assigned and while SVA domains are
 * looked up, attached, detached and have their user count adjusted.
 */
12 static DEFINE_MUTEX(iommu_sva_lock);
/* Global ID pool that mm PASIDs are allocated from and freed back to. */
13 static DEFINE_IDA(iommu_global_pasid_ida);
14 
15 /* Allocate a PASID for the mm within range (inclusive) */
16 static int iommu_sva_alloc_pasid(struct mm_struct *mm, ioasid_t min, ioasid_t max)
17 {
18 	int ret = 0;
19 
20 	if (min == IOMMU_PASID_INVALID ||
21 	    max == IOMMU_PASID_INVALID ||
22 	    min == 0 || max < min)
23 		return -EINVAL;
24 
25 	if (!arch_pgtable_dma_compat(mm))
26 		return -EBUSY;
27 
28 	mutex_lock(&iommu_sva_lock);
29 	/* Is a PASID already associated with this mm? */
30 	if (mm_valid_pasid(mm)) {
31 		if (mm->pasid < min || mm->pasid > max)
32 			ret = -EOVERFLOW;
33 		goto out;
34 	}
35 
36 	ret = ida_alloc_range(&iommu_global_pasid_ida, min, max, GFP_KERNEL);
37 	if (ret < 0)
38 		goto out;
39 
40 	mm->pasid = ret;
41 	ret = 0;
42 out:
43 	mutex_unlock(&iommu_sva_lock);
44 	return ret;
45 }
46 
47 /**
48  * iommu_sva_bind_device() - Bind a process address space to a device
49  * @dev: the device
50  * @mm: the mm to bind, caller must hold a reference to mm_users
51  *
52  * Create a bond between device and address space, allowing the device to
53  * access the mm using the PASID returned by iommu_sva_get_pasid(). If a
54  * bond already exists between @device and @mm, an additional internal
55  * reference is taken. Caller must call iommu_sva_unbind_device()
56  * to release each reference.
57  *
58  * iommu_dev_enable_feature(dev, IOMMU_DEV_FEAT_SVA) must be called first, to
59  * initialize the required SVA features.
60  *
61  * On error, returns an ERR_PTR value.
62  */
63 struct iommu_sva *iommu_sva_bind_device(struct device *dev, struct mm_struct *mm)
64 {
65 	struct iommu_domain *domain;
66 	struct iommu_sva *handle;
67 	ioasid_t max_pasids;
68 	int ret;
69 
70 	max_pasids = dev->iommu->max_pasids;
71 	if (!max_pasids)
72 		return ERR_PTR(-EOPNOTSUPP);
73 
74 	/* Allocate mm->pasid if necessary. */
75 	ret = iommu_sva_alloc_pasid(mm, 1, max_pasids - 1);
76 	if (ret)
77 		return ERR_PTR(ret);
78 
79 	handle = kzalloc(sizeof(*handle), GFP_KERNEL);
80 	if (!handle)
81 		return ERR_PTR(-ENOMEM);
82 
83 	mutex_lock(&iommu_sva_lock);
84 	/* Search for an existing domain. */
85 	domain = iommu_get_domain_for_dev_pasid(dev, mm->pasid,
86 						IOMMU_DOMAIN_SVA);
87 	if (IS_ERR(domain)) {
88 		ret = PTR_ERR(domain);
89 		goto out_unlock;
90 	}
91 
92 	if (domain) {
93 		domain->users++;
94 		goto out;
95 	}
96 
97 	/* Allocate a new domain and set it on device pasid. */
98 	domain = iommu_sva_domain_alloc(dev, mm);
99 	if (!domain) {
100 		ret = -ENOMEM;
101 		goto out_unlock;
102 	}
103 
104 	ret = iommu_attach_device_pasid(domain, dev, mm->pasid);
105 	if (ret)
106 		goto out_free_domain;
107 	domain->users = 1;
108 out:
109 	mutex_unlock(&iommu_sva_lock);
110 	handle->dev = dev;
111 	handle->domain = domain;
112 
113 	return handle;
114 
115 out_free_domain:
116 	iommu_domain_free(domain);
117 out_unlock:
118 	mutex_unlock(&iommu_sva_lock);
119 	kfree(handle);
120 
121 	return ERR_PTR(ret);
122 }
123 EXPORT_SYMBOL_GPL(iommu_sva_bind_device);
124 
125 /**
126  * iommu_sva_unbind_device() - Remove a bond created with iommu_sva_bind_device
127  * @handle: the handle returned by iommu_sva_bind_device()
128  *
129  * Put reference to a bond between device and address space. The device should
130  * not be issuing any more transaction for this PASID. All outstanding page
131  * requests for this PASID must have been flushed to the IOMMU.
132  */
133 void iommu_sva_unbind_device(struct iommu_sva *handle)
134 {
135 	struct iommu_domain *domain = handle->domain;
136 	ioasid_t pasid = domain->mm->pasid;
137 	struct device *dev = handle->dev;
138 
139 	mutex_lock(&iommu_sva_lock);
140 	if (--domain->users == 0) {
141 		iommu_detach_device_pasid(domain, dev, pasid);
142 		iommu_domain_free(domain);
143 	}
144 	mutex_unlock(&iommu_sva_lock);
145 	kfree(handle);
146 }
147 EXPORT_SYMBOL_GPL(iommu_sva_unbind_device);
148 
149 u32 iommu_sva_get_pasid(struct iommu_sva *handle)
150 {
151 	struct iommu_domain *domain = handle->domain;
152 
153 	return domain->mm->pasid;
154 }
155 EXPORT_SYMBOL_GPL(iommu_sva_get_pasid);
156 
157 /*
158  * I/O page fault handler for SVA
159  */
160 enum iommu_page_response_code
161 iommu_sva_handle_iopf(struct iommu_fault *fault, void *data)
162 {
163 	vm_fault_t ret;
164 	struct vm_area_struct *vma;
165 	struct mm_struct *mm = data;
166 	unsigned int access_flags = 0;
167 	unsigned int fault_flags = FAULT_FLAG_REMOTE;
168 	struct iommu_fault_page_request *prm = &fault->prm;
169 	enum iommu_page_response_code status = IOMMU_PAGE_RESP_INVALID;
170 
171 	if (!(prm->flags & IOMMU_FAULT_PAGE_REQUEST_PASID_VALID))
172 		return status;
173 
174 	if (!mmget_not_zero(mm))
175 		return status;
176 
177 	mmap_read_lock(mm);
178 
179 	vma = vma_lookup(mm, prm->addr);
180 	if (!vma)
181 		/* Unmapped area */
182 		goto out_put_mm;
183 
184 	if (prm->perm & IOMMU_FAULT_PERM_READ)
185 		access_flags |= VM_READ;
186 
187 	if (prm->perm & IOMMU_FAULT_PERM_WRITE) {
188 		access_flags |= VM_WRITE;
189 		fault_flags |= FAULT_FLAG_WRITE;
190 	}
191 
192 	if (prm->perm & IOMMU_FAULT_PERM_EXEC) {
193 		access_flags |= VM_EXEC;
194 		fault_flags |= FAULT_FLAG_INSTRUCTION;
195 	}
196 
197 	if (!(prm->perm & IOMMU_FAULT_PERM_PRIV))
198 		fault_flags |= FAULT_FLAG_USER;
199 
200 	if (access_flags & ~vma->vm_flags)
201 		/* Access fault */
202 		goto out_put_mm;
203 
204 	ret = handle_mm_fault(vma, prm->addr, fault_flags, NULL);
205 	status = ret & VM_FAULT_ERROR ? IOMMU_PAGE_RESP_INVALID :
206 		IOMMU_PAGE_RESP_SUCCESS;
207 
208 out_put_mm:
209 	mmap_read_unlock(mm);
210 	mmput(mm);
211 
212 	return status;
213 }
214 
215 void mm_pasid_drop(struct mm_struct *mm)
216 {
217 	if (likely(!mm_valid_pasid(mm)))
218 		return;
219 
220 	ida_free(&iommu_global_pasid_ida, mm->pasid);
221 }
222