xref: /openbmc/linux/drivers/iommu/iommu-sva.c (revision 2dcebc7d)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Helpers for IOMMU drivers implementing SVA
4  */
5 #include <linux/mmu_context.h>
6 #include <linux/mutex.h>
7 #include <linux/sched/mm.h>
8 #include <linux/iommu.h>
9 
10 #include "iommu-sva.h"
11 
12 static DEFINE_MUTEX(iommu_sva_lock);
13 
14 /* Allocate a PASID for the mm within range (inclusive) */
iommu_sva_alloc_pasid(struct mm_struct * mm,struct device * dev)15 static int iommu_sva_alloc_pasid(struct mm_struct *mm, struct device *dev)
16 {
17 	ioasid_t pasid;
18 	int ret = 0;
19 
20 	if (!arch_pgtable_dma_compat(mm))
21 		return -EBUSY;
22 
23 	mutex_lock(&iommu_sva_lock);
24 	/* Is a PASID already associated with this mm? */
25 	if (mm_valid_pasid(mm)) {
26 		if (mm->pasid >= dev->iommu->max_pasids)
27 			ret = -EOVERFLOW;
28 		goto out;
29 	}
30 
31 	pasid = iommu_alloc_global_pasid(dev);
32 	if (pasid == IOMMU_PASID_INVALID) {
33 		ret = -ENOSPC;
34 		goto out;
35 	}
36 	mm->pasid = pasid;
37 	ret = 0;
38 out:
39 	mutex_unlock(&iommu_sva_lock);
40 	return ret;
41 }
42 
43 /**
44  * iommu_sva_bind_device() - Bind a process address space to a device
45  * @dev: the device
46  * @mm: the mm to bind, caller must hold a reference to mm_users
47  *
48  * Create a bond between device and address space, allowing the device to
49  * access the mm using the PASID returned by iommu_sva_get_pasid(). If a
50  * bond already exists between @device and @mm, an additional internal
51  * reference is taken. Caller must call iommu_sva_unbind_device()
52  * to release each reference.
53  *
54  * iommu_dev_enable_feature(dev, IOMMU_DEV_FEAT_SVA) must be called first, to
55  * initialize the required SVA features.
56  *
57  * On error, returns an ERR_PTR value.
58  */
iommu_sva_bind_device(struct device * dev,struct mm_struct * mm)59 struct iommu_sva *iommu_sva_bind_device(struct device *dev, struct mm_struct *mm)
60 {
61 	struct iommu_domain *domain;
62 	struct iommu_sva *handle;
63 	int ret;
64 
65 	/* Allocate mm->pasid if necessary. */
66 	ret = iommu_sva_alloc_pasid(mm, dev);
67 	if (ret)
68 		return ERR_PTR(ret);
69 
70 	handle = kzalloc(sizeof(*handle), GFP_KERNEL);
71 	if (!handle)
72 		return ERR_PTR(-ENOMEM);
73 
74 	mutex_lock(&iommu_sva_lock);
75 	/* Search for an existing domain. */
76 	domain = iommu_get_domain_for_dev_pasid(dev, mm->pasid,
77 						IOMMU_DOMAIN_SVA);
78 	if (IS_ERR(domain)) {
79 		ret = PTR_ERR(domain);
80 		goto out_unlock;
81 	}
82 
83 	if (domain) {
84 		domain->users++;
85 		goto out;
86 	}
87 
88 	/* Allocate a new domain and set it on device pasid. */
89 	domain = iommu_sva_domain_alloc(dev, mm);
90 	if (!domain) {
91 		ret = -ENOMEM;
92 		goto out_unlock;
93 	}
94 
95 	ret = iommu_attach_device_pasid(domain, dev, mm->pasid);
96 	if (ret)
97 		goto out_free_domain;
98 	domain->users = 1;
99 out:
100 	mutex_unlock(&iommu_sva_lock);
101 	handle->dev = dev;
102 	handle->domain = domain;
103 
104 	return handle;
105 
106 out_free_domain:
107 	iommu_domain_free(domain);
108 out_unlock:
109 	mutex_unlock(&iommu_sva_lock);
110 	kfree(handle);
111 
112 	return ERR_PTR(ret);
113 }
114 EXPORT_SYMBOL_GPL(iommu_sva_bind_device);
115 
116 /**
117  * iommu_sva_unbind_device() - Remove a bond created with iommu_sva_bind_device
118  * @handle: the handle returned by iommu_sva_bind_device()
119  *
120  * Put reference to a bond between device and address space. The device should
121  * not be issuing any more transaction for this PASID. All outstanding page
122  * requests for this PASID must have been flushed to the IOMMU.
123  */
iommu_sva_unbind_device(struct iommu_sva * handle)124 void iommu_sva_unbind_device(struct iommu_sva *handle)
125 {
126 	struct iommu_domain *domain = handle->domain;
127 	ioasid_t pasid = domain->mm->pasid;
128 	struct device *dev = handle->dev;
129 
130 	mutex_lock(&iommu_sva_lock);
131 	if (--domain->users == 0) {
132 		iommu_detach_device_pasid(domain, dev, pasid);
133 		iommu_domain_free(domain);
134 	}
135 	mutex_unlock(&iommu_sva_lock);
136 	kfree(handle);
137 }
138 EXPORT_SYMBOL_GPL(iommu_sva_unbind_device);
139 
iommu_sva_get_pasid(struct iommu_sva * handle)140 u32 iommu_sva_get_pasid(struct iommu_sva *handle)
141 {
142 	struct iommu_domain *domain = handle->domain;
143 
144 	return domain->mm->pasid;
145 }
146 EXPORT_SYMBOL_GPL(iommu_sva_get_pasid);
147 
148 /*
149  * I/O page fault handler for SVA
150  */
151 enum iommu_page_response_code
iommu_sva_handle_iopf(struct iommu_fault * fault,void * data)152 iommu_sva_handle_iopf(struct iommu_fault *fault, void *data)
153 {
154 	vm_fault_t ret;
155 	struct vm_area_struct *vma;
156 	struct mm_struct *mm = data;
157 	unsigned int access_flags = 0;
158 	unsigned int fault_flags = FAULT_FLAG_REMOTE;
159 	struct iommu_fault_page_request *prm = &fault->prm;
160 	enum iommu_page_response_code status = IOMMU_PAGE_RESP_INVALID;
161 
162 	if (!(prm->flags & IOMMU_FAULT_PAGE_REQUEST_PASID_VALID))
163 		return status;
164 
165 	if (!mmget_not_zero(mm))
166 		return status;
167 
168 	mmap_read_lock(mm);
169 
170 	vma = vma_lookup(mm, prm->addr);
171 	if (!vma)
172 		/* Unmapped area */
173 		goto out_put_mm;
174 
175 	if (prm->perm & IOMMU_FAULT_PERM_READ)
176 		access_flags |= VM_READ;
177 
178 	if (prm->perm & IOMMU_FAULT_PERM_WRITE) {
179 		access_flags |= VM_WRITE;
180 		fault_flags |= FAULT_FLAG_WRITE;
181 	}
182 
183 	if (prm->perm & IOMMU_FAULT_PERM_EXEC) {
184 		access_flags |= VM_EXEC;
185 		fault_flags |= FAULT_FLAG_INSTRUCTION;
186 	}
187 
188 	if (!(prm->perm & IOMMU_FAULT_PERM_PRIV))
189 		fault_flags |= FAULT_FLAG_USER;
190 
191 	if (access_flags & ~vma->vm_flags)
192 		/* Access fault */
193 		goto out_put_mm;
194 
195 	ret = handle_mm_fault(vma, prm->addr, fault_flags, NULL);
196 	status = ret & VM_FAULT_ERROR ? IOMMU_PAGE_RESP_INVALID :
197 		IOMMU_PAGE_RESP_SUCCESS;
198 
199 out_put_mm:
200 	mmap_read_unlock(mm);
201 	mmput(mm);
202 
203 	return status;
204 }
205 
mm_pasid_drop(struct mm_struct * mm)206 void mm_pasid_drop(struct mm_struct *mm)
207 {
208 	if (likely(!mm_valid_pasid(mm)))
209 		return;
210 
211 	iommu_free_global_pasid(mm->pasid);
212 }
213