xref: /openbmc/linux/drivers/vhost/vdpa.c (revision 657c45b3)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2018-2020 Intel Corporation.
4  * Copyright (C) 2020 Red Hat, Inc.
5  *
6  * Author: Tiwei Bie <tiwei.bie@intel.com>
7  *         Jason Wang <jasowang@redhat.com>
8  *
9  * Thanks Michael S. Tsirkin for the valuable comments and
10  * suggestions.  And thanks to Cunming Liang and Zhihong Wang for all
11  * their supports.
12  */
13 
14 #include <linux/kernel.h>
15 #include <linux/module.h>
16 #include <linux/cdev.h>
17 #include <linux/device.h>
18 #include <linux/mm.h>
19 #include <linux/slab.h>
20 #include <linux/iommu.h>
21 #include <linux/uuid.h>
22 #include <linux/vdpa.h>
23 #include <linux/nospec.h>
24 #include <linux/vhost.h>
25 
26 #include "vhost.h"
27 
/*
 * Backend features that vhost-vdpa always advertises.  SUSPEND and RESUME
 * are not part of this set: they are offered and accepted only when the
 * parent driver implements the matching config op (see the handling of
 * VHOST_GET_BACKEND_FEATURES / VHOST_SET_BACKEND_FEATURES).
 */
enum {
	VHOST_VDPA_BACKEND_FEATURES =
	(1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2) |
	(1ULL << VHOST_BACKEND_F_IOTLB_BATCH) |
	(1ULL << VHOST_BACKEND_F_IOTLB_ASID),
};

/* Upper bound on device minors: 1 << MINORBITS (its users are outside this view). */
#define VHOST_VDPA_DEV_MAX (1U << MINORBITS)

/* Number of hash buckets used to look up address spaces by ASID. */
#define VHOST_VDPA_IOTLB_BUCKETS 16
38 
/* One device address space: an IOTLB identified by its ASID. */
struct vhost_vdpa_as {
	struct hlist_node hash_link;	/* link in the vhost_vdpa::as[] bucket */
	struct vhost_iotlb iotlb;	/* IOVA mappings of this address space */
	u32 id;				/* ASID; bucket is id % VHOST_VDPA_IOTLB_BUCKETS */
};

/* Per-device state of a vhost-vdpa character device. */
struct vhost_vdpa {
	struct vhost_dev vdev;
	/* Platform IOMMU domain, used when the parent has neither set_map nor dma_map. */
	struct iommu_domain *domain;
	struct vhost_virtqueue *vqs;
	struct completion completion;
	struct vdpa_device *vdpa;
	/* Address spaces, hashed by ASID. */
	struct hlist_head as[VHOST_VDPA_IOTLB_BUCKETS];
	struct device dev;
	struct cdev cdev;
	atomic_t opened;
	u32 nvqs;
	int virtio_id;
	int minor;
	/* eventfd signalled by the config-change callback; may be NULL. */
	struct eventfd_ctx *config_ctx;
	/* Non-zero between IOTLB BATCH_BEGIN and BATCH_END; defers set_map(). */
	int in_batch;
	/* IOVA window reported to userspace and enforced on IOTLB updates. */
	struct vdpa_iova_range range;
	/* ASID that the currently open IOTLB batch applies to. */
	u32 batch_asid;
};
63 
/* ID allocator; presumably hands out device minors (its users are outside this view). */
static DEFINE_IDA(vhost_vdpa_ida);

/* Presumably the base dev_t of the char device region (users outside this view). */
static dev_t vhost_vdpa_major;

/* Forward declaration: vhost_vdpa_remove_as() calls this before its definition. */
static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v,
				   struct vhost_iotlb *iotlb, u64 start,
				   u64 last, u32 asid);
71 
72 static inline u32 iotlb_to_asid(struct vhost_iotlb *iotlb)
73 {
74 	struct vhost_vdpa_as *as = container_of(iotlb, struct
75 						vhost_vdpa_as, iotlb);
76 	return as->id;
77 }
78 
79 static struct vhost_vdpa_as *asid_to_as(struct vhost_vdpa *v, u32 asid)
80 {
81 	struct hlist_head *head = &v->as[asid % VHOST_VDPA_IOTLB_BUCKETS];
82 	struct vhost_vdpa_as *as;
83 
84 	hlist_for_each_entry(as, head, hash_link)
85 		if (as->id == asid)
86 			return as;
87 
88 	return NULL;
89 }
90 
91 static struct vhost_iotlb *asid_to_iotlb(struct vhost_vdpa *v, u32 asid)
92 {
93 	struct vhost_vdpa_as *as = asid_to_as(v, asid);
94 
95 	if (!as)
96 		return NULL;
97 
98 	return &as->iotlb;
99 }
100 
101 static struct vhost_vdpa_as *vhost_vdpa_alloc_as(struct vhost_vdpa *v, u32 asid)
102 {
103 	struct hlist_head *head = &v->as[asid % VHOST_VDPA_IOTLB_BUCKETS];
104 	struct vhost_vdpa_as *as;
105 
106 	if (asid_to_as(v, asid))
107 		return NULL;
108 
109 	if (asid >= v->vdpa->nas)
110 		return NULL;
111 
112 	as = kmalloc(sizeof(*as), GFP_KERNEL);
113 	if (!as)
114 		return NULL;
115 
116 	vhost_iotlb_init(&as->iotlb, 0, 0);
117 	as->id = asid;
118 	hlist_add_head(&as->hash_link, head);
119 
120 	return as;
121 }
122 
123 static struct vhost_vdpa_as *vhost_vdpa_find_alloc_as(struct vhost_vdpa *v,
124 						      u32 asid)
125 {
126 	struct vhost_vdpa_as *as = asid_to_as(v, asid);
127 
128 	if (as)
129 		return as;
130 
131 	return vhost_vdpa_alloc_as(v, asid);
132 }
133 
134 static int vhost_vdpa_remove_as(struct vhost_vdpa *v, u32 asid)
135 {
136 	struct vhost_vdpa_as *as = asid_to_as(v, asid);
137 
138 	if (!as)
139 		return -EINVAL;
140 
141 	hlist_del(&as->hash_link);
142 	vhost_vdpa_iotlb_unmap(v, &as->iotlb, 0ULL, 0ULL - 1, asid);
143 	kfree(as);
144 
145 	return 0;
146 }
147 
148 static void handle_vq_kick(struct vhost_work *work)
149 {
150 	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
151 						  poll.work);
152 	struct vhost_vdpa *v = container_of(vq->dev, struct vhost_vdpa, vdev);
153 	const struct vdpa_config_ops *ops = v->vdpa->config;
154 
155 	ops->kick_vq(v->vdpa, vq - v->vqs);
156 }
157 
158 static irqreturn_t vhost_vdpa_virtqueue_cb(void *private)
159 {
160 	struct vhost_virtqueue *vq = private;
161 	struct eventfd_ctx *call_ctx = vq->call_ctx.ctx;
162 
163 	if (call_ctx)
164 		eventfd_signal(call_ctx, 1);
165 
166 	return IRQ_HANDLED;
167 }
168 
169 static irqreturn_t vhost_vdpa_config_cb(void *private)
170 {
171 	struct vhost_vdpa *v = private;
172 	struct eventfd_ctx *config_ctx = v->config_ctx;
173 
174 	if (config_ctx)
175 		eventfd_signal(config_ctx, 1);
176 
177 	return IRQ_HANDLED;
178 }
179 
180 static void vhost_vdpa_setup_vq_irq(struct vhost_vdpa *v, u16 qid)
181 {
182 	struct vhost_virtqueue *vq = &v->vqs[qid];
183 	const struct vdpa_config_ops *ops = v->vdpa->config;
184 	struct vdpa_device *vdpa = v->vdpa;
185 	int ret, irq;
186 
187 	if (!ops->get_vq_irq)
188 		return;
189 
190 	irq = ops->get_vq_irq(vdpa, qid);
191 	if (irq < 0)
192 		return;
193 
194 	irq_bypass_unregister_producer(&vq->call_ctx.producer);
195 	if (!vq->call_ctx.ctx)
196 		return;
197 
198 	vq->call_ctx.producer.token = vq->call_ctx.ctx;
199 	vq->call_ctx.producer.irq = irq;
200 	ret = irq_bypass_register_producer(&vq->call_ctx.producer);
201 	if (unlikely(ret))
202 		dev_info(&v->dev, "vq %u, irq bypass producer (token %p) registration fails, ret =  %d\n",
203 			 qid, vq->call_ctx.producer.token, ret);
204 }
205 
206 static void vhost_vdpa_unsetup_vq_irq(struct vhost_vdpa *v, u16 qid)
207 {
208 	struct vhost_virtqueue *vq = &v->vqs[qid];
209 
210 	irq_bypass_unregister_producer(&vq->call_ctx.producer);
211 }
212 
213 static int vhost_vdpa_reset(struct vhost_vdpa *v)
214 {
215 	struct vdpa_device *vdpa = v->vdpa;
216 
217 	v->in_batch = 0;
218 
219 	return vdpa_reset(vdpa);
220 }
221 
222 static long vhost_vdpa_bind_mm(struct vhost_vdpa *v)
223 {
224 	struct vdpa_device *vdpa = v->vdpa;
225 	const struct vdpa_config_ops *ops = vdpa->config;
226 
227 	if (!vdpa->use_va || !ops->bind_mm)
228 		return 0;
229 
230 	return ops->bind_mm(vdpa, v->vdev.mm);
231 }
232 
233 static void vhost_vdpa_unbind_mm(struct vhost_vdpa *v)
234 {
235 	struct vdpa_device *vdpa = v->vdpa;
236 	const struct vdpa_config_ops *ops = vdpa->config;
237 
238 	if (!vdpa->use_va || !ops->unbind_mm)
239 		return;
240 
241 	ops->unbind_mm(vdpa);
242 }
243 
244 static long vhost_vdpa_get_device_id(struct vhost_vdpa *v, u8 __user *argp)
245 {
246 	struct vdpa_device *vdpa = v->vdpa;
247 	const struct vdpa_config_ops *ops = vdpa->config;
248 	u32 device_id;
249 
250 	device_id = ops->get_device_id(vdpa);
251 
252 	if (copy_to_user(argp, &device_id, sizeof(device_id)))
253 		return -EFAULT;
254 
255 	return 0;
256 }
257 
258 static long vhost_vdpa_get_status(struct vhost_vdpa *v, u8 __user *statusp)
259 {
260 	struct vdpa_device *vdpa = v->vdpa;
261 	const struct vdpa_config_ops *ops = vdpa->config;
262 	u8 status;
263 
264 	status = ops->get_status(vdpa);
265 
266 	if (copy_to_user(statusp, &status, sizeof(status)))
267 		return -EFAULT;
268 
269 	return 0;
270 }
271 
272 static long vhost_vdpa_set_status(struct vhost_vdpa *v, u8 __user *statusp)
273 {
274 	struct vdpa_device *vdpa = v->vdpa;
275 	const struct vdpa_config_ops *ops = vdpa->config;
276 	u8 status, status_old;
277 	u32 nvqs = v->nvqs;
278 	int ret;
279 	u16 i;
280 
281 	if (copy_from_user(&status, statusp, sizeof(status)))
282 		return -EFAULT;
283 
284 	status_old = ops->get_status(vdpa);
285 
286 	/*
287 	 * Userspace shouldn't remove status bits unless reset the
288 	 * status to 0.
289 	 */
290 	if (status != 0 && (status_old & ~status) != 0)
291 		return -EINVAL;
292 
293 	if ((status_old & VIRTIO_CONFIG_S_DRIVER_OK) && !(status & VIRTIO_CONFIG_S_DRIVER_OK))
294 		for (i = 0; i < nvqs; i++)
295 			vhost_vdpa_unsetup_vq_irq(v, i);
296 
297 	if (status == 0) {
298 		ret = vdpa_reset(vdpa);
299 		if (ret)
300 			return ret;
301 	} else
302 		vdpa_set_status(vdpa, status);
303 
304 	if ((status & VIRTIO_CONFIG_S_DRIVER_OK) && !(status_old & VIRTIO_CONFIG_S_DRIVER_OK))
305 		for (i = 0; i < nvqs; i++)
306 			vhost_vdpa_setup_vq_irq(v, i);
307 
308 	return 0;
309 }
310 
311 static int vhost_vdpa_config_validate(struct vhost_vdpa *v,
312 				      struct vhost_vdpa_config *c)
313 {
314 	struct vdpa_device *vdpa = v->vdpa;
315 	size_t size = vdpa->config->get_config_size(vdpa);
316 
317 	if (c->len == 0 || c->off > size)
318 		return -EINVAL;
319 
320 	if (c->len > size - c->off)
321 		return -E2BIG;
322 
323 	return 0;
324 }
325 
326 static long vhost_vdpa_get_config(struct vhost_vdpa *v,
327 				  struct vhost_vdpa_config __user *c)
328 {
329 	struct vdpa_device *vdpa = v->vdpa;
330 	struct vhost_vdpa_config config;
331 	unsigned long size = offsetof(struct vhost_vdpa_config, buf);
332 	u8 *buf;
333 
334 	if (copy_from_user(&config, c, size))
335 		return -EFAULT;
336 	if (vhost_vdpa_config_validate(v, &config))
337 		return -EINVAL;
338 	buf = kvzalloc(config.len, GFP_KERNEL);
339 	if (!buf)
340 		return -ENOMEM;
341 
342 	vdpa_get_config(vdpa, config.off, buf, config.len);
343 
344 	if (copy_to_user(c->buf, buf, config.len)) {
345 		kvfree(buf);
346 		return -EFAULT;
347 	}
348 
349 	kvfree(buf);
350 	return 0;
351 }
352 
353 static long vhost_vdpa_set_config(struct vhost_vdpa *v,
354 				  struct vhost_vdpa_config __user *c)
355 {
356 	struct vdpa_device *vdpa = v->vdpa;
357 	struct vhost_vdpa_config config;
358 	unsigned long size = offsetof(struct vhost_vdpa_config, buf);
359 	u8 *buf;
360 
361 	if (copy_from_user(&config, c, size))
362 		return -EFAULT;
363 	if (vhost_vdpa_config_validate(v, &config))
364 		return -EINVAL;
365 
366 	buf = vmemdup_user(c->buf, config.len);
367 	if (IS_ERR(buf))
368 		return PTR_ERR(buf);
369 
370 	vdpa_set_config(vdpa, config.off, buf, config.len);
371 
372 	kvfree(buf);
373 	return 0;
374 }
375 
376 static bool vhost_vdpa_can_suspend(const struct vhost_vdpa *v)
377 {
378 	struct vdpa_device *vdpa = v->vdpa;
379 	const struct vdpa_config_ops *ops = vdpa->config;
380 
381 	return ops->suspend;
382 }
383 
384 static bool vhost_vdpa_can_resume(const struct vhost_vdpa *v)
385 {
386 	struct vdpa_device *vdpa = v->vdpa;
387 	const struct vdpa_config_ops *ops = vdpa->config;
388 
389 	return ops->resume;
390 }
391 
392 static long vhost_vdpa_get_features(struct vhost_vdpa *v, u64 __user *featurep)
393 {
394 	struct vdpa_device *vdpa = v->vdpa;
395 	const struct vdpa_config_ops *ops = vdpa->config;
396 	u64 features;
397 
398 	features = ops->get_device_features(vdpa);
399 
400 	if (copy_to_user(featurep, &features, sizeof(features)))
401 		return -EFAULT;
402 
403 	return 0;
404 }
405 
406 static long vhost_vdpa_set_features(struct vhost_vdpa *v, u64 __user *featurep)
407 {
408 	struct vdpa_device *vdpa = v->vdpa;
409 	const struct vdpa_config_ops *ops = vdpa->config;
410 	struct vhost_dev *d = &v->vdev;
411 	u64 actual_features;
412 	u64 features;
413 	int i;
414 
415 	/*
416 	 * It's not allowed to change the features after they have
417 	 * been negotiated.
418 	 */
419 	if (ops->get_status(vdpa) & VIRTIO_CONFIG_S_FEATURES_OK)
420 		return -EBUSY;
421 
422 	if (copy_from_user(&features, featurep, sizeof(features)))
423 		return -EFAULT;
424 
425 	if (vdpa_set_features(vdpa, features))
426 		return -EINVAL;
427 
428 	/* let the vqs know what has been configured */
429 	actual_features = ops->get_driver_features(vdpa);
430 	for (i = 0; i < d->nvqs; ++i) {
431 		struct vhost_virtqueue *vq = d->vqs[i];
432 
433 		mutex_lock(&vq->mutex);
434 		vq->acked_features = actual_features;
435 		mutex_unlock(&vq->mutex);
436 	}
437 
438 	return 0;
439 }
440 
441 static long vhost_vdpa_get_vring_num(struct vhost_vdpa *v, u16 __user *argp)
442 {
443 	struct vdpa_device *vdpa = v->vdpa;
444 	const struct vdpa_config_ops *ops = vdpa->config;
445 	u16 num;
446 
447 	num = ops->get_vq_num_max(vdpa);
448 
449 	if (copy_to_user(argp, &num, sizeof(num)))
450 		return -EFAULT;
451 
452 	return 0;
453 }
454 
455 static void vhost_vdpa_config_put(struct vhost_vdpa *v)
456 {
457 	if (v->config_ctx) {
458 		eventfd_ctx_put(v->config_ctx);
459 		v->config_ctx = NULL;
460 	}
461 }
462 
463 static long vhost_vdpa_set_config_call(struct vhost_vdpa *v, u32 __user *argp)
464 {
465 	struct vdpa_callback cb;
466 	int fd;
467 	struct eventfd_ctx *ctx;
468 
469 	cb.callback = vhost_vdpa_config_cb;
470 	cb.private = v;
471 	if (copy_from_user(&fd, argp, sizeof(fd)))
472 		return  -EFAULT;
473 
474 	ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd);
475 	swap(ctx, v->config_ctx);
476 
477 	if (!IS_ERR_OR_NULL(ctx))
478 		eventfd_ctx_put(ctx);
479 
480 	if (IS_ERR(v->config_ctx)) {
481 		long ret = PTR_ERR(v->config_ctx);
482 
483 		v->config_ctx = NULL;
484 		return ret;
485 	}
486 
487 	v->vdpa->config->set_config_cb(v->vdpa, &cb);
488 
489 	return 0;
490 }
491 
492 static long vhost_vdpa_get_iova_range(struct vhost_vdpa *v, u32 __user *argp)
493 {
494 	struct vhost_vdpa_iova_range range = {
495 		.first = v->range.first,
496 		.last = v->range.last,
497 	};
498 
499 	if (copy_to_user(argp, &range, sizeof(range)))
500 		return -EFAULT;
501 	return 0;
502 }
503 
504 static long vhost_vdpa_get_config_size(struct vhost_vdpa *v, u32 __user *argp)
505 {
506 	struct vdpa_device *vdpa = v->vdpa;
507 	const struct vdpa_config_ops *ops = vdpa->config;
508 	u32 size;
509 
510 	size = ops->get_config_size(vdpa);
511 
512 	if (copy_to_user(argp, &size, sizeof(size)))
513 		return -EFAULT;
514 
515 	return 0;
516 }
517 
518 static long vhost_vdpa_get_vqs_count(struct vhost_vdpa *v, u32 __user *argp)
519 {
520 	struct vdpa_device *vdpa = v->vdpa;
521 
522 	if (copy_to_user(argp, &vdpa->nvqs, sizeof(vdpa->nvqs)))
523 		return -EFAULT;
524 
525 	return 0;
526 }
527 
528 /* After a successful return of ioctl the device must not process more
529  * virtqueue descriptors. The device can answer to read or writes of config
530  * fields as if it were not suspended. In particular, writing to "queue_enable"
531  * with a value of 1 will not make the device start processing buffers.
532  */
533 static long vhost_vdpa_suspend(struct vhost_vdpa *v)
534 {
535 	struct vdpa_device *vdpa = v->vdpa;
536 	const struct vdpa_config_ops *ops = vdpa->config;
537 
538 	if (!ops->suspend)
539 		return -EOPNOTSUPP;
540 
541 	return ops->suspend(vdpa);
542 }
543 
544 /* After a successful return of this ioctl the device resumes processing
545  * virtqueue descriptors. The device becomes fully operational the same way it
546  * was before it was suspended.
547  */
548 static long vhost_vdpa_resume(struct vhost_vdpa *v)
549 {
550 	struct vdpa_device *vdpa = v->vdpa;
551 	const struct vdpa_config_ops *ops = vdpa->config;
552 
553 	if (!ops->resume)
554 		return -EOPNOTSUPP;
555 
556 	return ops->resume(vdpa);
557 }
558 
/*
 * Per-virtqueue ioctls.  The first u32 of every argument is the queue index.
 * It is bounds-checked (-ENOBUFS) and clamped with array_index_nospec().
 *
 * SET_VRING_ENABLE, GET_VRING_GROUP and SET_GROUP_ASID are handled entirely
 * here.  For VHOST_GET_VRING_BASE the device's ring state is first copied
 * into the vhost virtqueue, presumably so that vhost_vring_ioctl() reports
 * the device's indices.  Every other command goes to vhost_vring_ioctl().
 * On success, SET_VRING_ADDR/BASE/CALL/NUM are then forwarded to the parent.
 */
static long vhost_vdpa_vring_ioctl(struct vhost_vdpa *v, unsigned int cmd,
				   void __user *argp)
{
	struct vdpa_device *vdpa = v->vdpa;
	const struct vdpa_config_ops *ops = vdpa->config;
	struct vdpa_vq_state vq_state;
	struct vdpa_callback cb;
	struct vhost_virtqueue *vq;
	struct vhost_vring_state s;
	u32 idx;
	long r;

	r = get_user(idx, (u32 __user *)argp);
	if (r < 0)
		return r;

	if (idx >= v->nvqs)
		return -ENOBUFS;

	/* Keep the user-controlled index in bounds under speculation too. */
	idx = array_index_nospec(idx, v->nvqs);
	vq = &v->vqs[idx];

	/* Commands handled (or prepared) before the generic vhost handler. */
	switch (cmd) {
	case VHOST_VDPA_SET_VRING_ENABLE:
		if (copy_from_user(&s, argp, sizeof(s)))
			return -EFAULT;
		ops->set_vq_ready(vdpa, idx, s.num);
		return 0;
	case VHOST_VDPA_GET_VRING_GROUP:
		if (!ops->get_vq_group)
			return -EOPNOTSUPP;
		s.index = idx;
		s.num = ops->get_vq_group(vdpa, idx);
		/* A group id out of range is reported as an I/O error. */
		if (s.num >= vdpa->ngroups)
			return -EIO;
		else if (copy_to_user(argp, &s, sizeof(s)))
			return -EFAULT;
		return 0;
	case VHOST_VDPA_SET_GROUP_ASID:
		if (copy_from_user(&s, argp, sizeof(s)))
			return -EFAULT;
		if (s.num >= vdpa->nas)
			return -EINVAL;
		if (!ops->set_group_asid)
			return -EOPNOTSUPP;
		return ops->set_group_asid(vdpa, idx, s.num);
	case VHOST_GET_VRING_BASE:
		r = ops->get_vq_state(v->vdpa, idx, &vq_state);
		if (r)
			return r;

		/*
		 * Packed ring: the vhost index carries the wrap counter in
		 * bit 15 and the ring index in bits 0-14.  Split ring: only
		 * the avail index is taken from the device.
		 */
		if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) {
			vq->last_avail_idx = vq_state.packed.last_avail_idx |
					     (vq_state.packed.last_avail_counter << 15);
			vq->last_used_idx = vq_state.packed.last_used_idx |
					    (vq_state.packed.last_used_counter << 15);
		} else {
			vq->last_avail_idx = vq_state.split.avail_index;
		}
		break;
	}

	r = vhost_vring_ioctl(&v->vdev, cmd, argp);
	if (r)
		return r;

	/* Propagate the state accepted by the generic handler to the parent. */
	switch (cmd) {
	case VHOST_SET_VRING_ADDR:
		if (ops->set_vq_address(vdpa, idx,
					(u64)(uintptr_t)vq->desc,
					(u64)(uintptr_t)vq->avail,
					(u64)(uintptr_t)vq->used))
			r = -EINVAL;
		break;

	case VHOST_SET_VRING_BASE:
		/* Inverse of the GET_VRING_BASE encoding above. */
		if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) {
			vq_state.packed.last_avail_idx = vq->last_avail_idx & 0x7fff;
			vq_state.packed.last_avail_counter = !!(vq->last_avail_idx & 0x8000);
			vq_state.packed.last_used_idx = vq->last_used_idx & 0x7fff;
			vq_state.packed.last_used_counter = !!(vq->last_used_idx & 0x8000);
		} else {
			vq_state.split.avail_index = vq->last_avail_idx;
		}
		r = ops->set_vq_state(vdpa, idx, &vq_state);
		break;

	case VHOST_SET_VRING_CALL:
		/* Route queue notifications to the call eventfd, or detach. */
		if (vq->call_ctx.ctx) {
			cb.callback = vhost_vdpa_virtqueue_cb;
			cb.private = vq;
			cb.trigger = vq->call_ctx.ctx;
		} else {
			cb.callback = NULL;
			cb.private = NULL;
			cb.trigger = NULL;
		}
		ops->set_vq_cb(vdpa, idx, &cb);
		/* Re-register the irq bypass producer for the new eventfd. */
		vhost_vdpa_setup_vq_irq(v, idx);
		break;

	case VHOST_SET_VRING_NUM:
		ops->set_vq_num(vdpa, idx, vq->num);
		break;
	}

	return r;
}
667 
/*
 * ioctl entry point of the vhost-vdpa character device.
 *
 * VHOST_SET_BACKEND_FEATURES is handled without taking d->mutex.  SUSPEND
 * and RESUME are accepted only if the parent implements them.  All other
 * commands run under d->mutex.  vDPA-specific ones are dispatched here.
 * The rest go to vhost_dev_ioctl() and, if that returns -ENOIOCTLCMD, to
 * vhost_vdpa_vring_ioctl().  After a successful VHOST_SET_OWNER the owner's
 * mm is bound to the parent; if binding fails, ownership is reset.
 */
static long vhost_vdpa_unlocked_ioctl(struct file *filep,
				      unsigned int cmd, unsigned long arg)
{
	struct vhost_vdpa *v = filep->private_data;
	struct vhost_dev *d = &v->vdev;
	void __user *argp = (void __user *)arg;
	u64 __user *featurep = argp;
	u64 features;
	long r = 0;

	if (cmd == VHOST_SET_BACKEND_FEATURES) {
		if (copy_from_user(&features, featurep, sizeof(features)))
			return -EFAULT;
		/* Reject bits outside what this driver can ever offer. */
		if (features & ~(VHOST_VDPA_BACKEND_FEATURES |
				 BIT_ULL(VHOST_BACKEND_F_SUSPEND) |
				 BIT_ULL(VHOST_BACKEND_F_RESUME)))
			return -EOPNOTSUPP;
		/* SUSPEND/RESUME additionally need parent driver support. */
		if ((features & BIT_ULL(VHOST_BACKEND_F_SUSPEND)) &&
		     !vhost_vdpa_can_suspend(v))
			return -EOPNOTSUPP;
		if ((features & BIT_ULL(VHOST_BACKEND_F_RESUME)) &&
		     !vhost_vdpa_can_resume(v))
			return -EOPNOTSUPP;
		vhost_set_backend_features(&v->vdev, features);
		return 0;
	}

	mutex_lock(&d->mutex);

	switch (cmd) {
	case VHOST_VDPA_GET_DEVICE_ID:
		r = vhost_vdpa_get_device_id(v, argp);
		break;
	case VHOST_VDPA_GET_STATUS:
		r = vhost_vdpa_get_status(v, argp);
		break;
	case VHOST_VDPA_SET_STATUS:
		r = vhost_vdpa_set_status(v, argp);
		break;
	case VHOST_VDPA_GET_CONFIG:
		r = vhost_vdpa_get_config(v, argp);
		break;
	case VHOST_VDPA_SET_CONFIG:
		r = vhost_vdpa_set_config(v, argp);
		break;
	case VHOST_GET_FEATURES:
		r = vhost_vdpa_get_features(v, argp);
		break;
	case VHOST_SET_FEATURES:
		r = vhost_vdpa_set_features(v, argp);
		break;
	case VHOST_VDPA_GET_VRING_NUM:
		r = vhost_vdpa_get_vring_num(v, argp);
		break;
	case VHOST_VDPA_GET_GROUP_NUM:
		if (copy_to_user(argp, &v->vdpa->ngroups,
				 sizeof(v->vdpa->ngroups)))
			r = -EFAULT;
		break;
	case VHOST_VDPA_GET_AS_NUM:
		if (copy_to_user(argp, &v->vdpa->nas, sizeof(v->vdpa->nas)))
			r = -EFAULT;
		break;
	case VHOST_SET_LOG_BASE:
	case VHOST_SET_LOG_FD:
		/* Rejected here (presumably: dirty logging is unsupported). */
		r = -ENOIOCTLCMD;
		break;
	case VHOST_VDPA_SET_CONFIG_CALL:
		r = vhost_vdpa_set_config_call(v, argp);
		break;
	case VHOST_GET_BACKEND_FEATURES:
		/* Advertise SUSPEND/RESUME only when the parent supports them. */
		features = VHOST_VDPA_BACKEND_FEATURES;
		if (vhost_vdpa_can_suspend(v))
			features |= BIT_ULL(VHOST_BACKEND_F_SUSPEND);
		if (vhost_vdpa_can_resume(v))
			features |= BIT_ULL(VHOST_BACKEND_F_RESUME);
		if (copy_to_user(featurep, &features, sizeof(features)))
			r = -EFAULT;
		break;
	case VHOST_VDPA_GET_IOVA_RANGE:
		r = vhost_vdpa_get_iova_range(v, argp);
		break;
	case VHOST_VDPA_GET_CONFIG_SIZE:
		r = vhost_vdpa_get_config_size(v, argp);
		break;
	case VHOST_VDPA_GET_VQS_COUNT:
		r = vhost_vdpa_get_vqs_count(v, argp);
		break;
	case VHOST_VDPA_SUSPEND:
		r = vhost_vdpa_suspend(v);
		break;
	case VHOST_VDPA_RESUME:
		r = vhost_vdpa_resume(v);
		break;
	default:
		/* Device-level vhost ioctls first, then per-virtqueue ones. */
		r = vhost_dev_ioctl(&v->vdev, cmd, argp);
		if (r == -ENOIOCTLCMD)
			r = vhost_vdpa_vring_ioctl(v, cmd, argp);
		break;
	}

	if (r)
		goto out;

	/* Follow-up work that only applies when the command succeeded. */
	switch (cmd) {
	case VHOST_SET_OWNER:
		r = vhost_vdpa_bind_mm(v);
		if (r)
			vhost_dev_reset_owner(d, NULL);
		break;
	}
out:
	mutex_unlock(&d->mutex);
	return r;
}
782 }
783 static void vhost_vdpa_general_unmap(struct vhost_vdpa *v,
784 				     struct vhost_iotlb_map *map, u32 asid)
785 {
786 	struct vdpa_device *vdpa = v->vdpa;
787 	const struct vdpa_config_ops *ops = vdpa->config;
788 	if (ops->dma_map) {
789 		ops->dma_unmap(vdpa, asid, map->start, map->size);
790 	} else if (ops->set_map == NULL) {
791 		iommu_unmap(v->domain, map->start, map->size);
792 	}
793 }
794 
795 static void vhost_vdpa_pa_unmap(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
796 				u64 start, u64 last, u32 asid)
797 {
798 	struct vhost_dev *dev = &v->vdev;
799 	struct vhost_iotlb_map *map;
800 	struct page *page;
801 	unsigned long pfn, pinned;
802 
803 	while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
804 		pinned = PFN_DOWN(map->size);
805 		for (pfn = PFN_DOWN(map->addr);
806 		     pinned > 0; pfn++, pinned--) {
807 			page = pfn_to_page(pfn);
808 			if (map->perm & VHOST_ACCESS_WO)
809 				set_page_dirty_lock(page);
810 			unpin_user_page(page);
811 		}
812 		atomic64_sub(PFN_DOWN(map->size), &dev->mm->pinned_vm);
813 		vhost_vdpa_general_unmap(v, map, asid);
814 		vhost_iotlb_map_free(iotlb, map);
815 	}
816 }
817 
818 static void vhost_vdpa_va_unmap(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
819 				u64 start, u64 last, u32 asid)
820 {
821 	struct vhost_iotlb_map *map;
822 	struct vdpa_map_file *map_file;
823 
824 	while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
825 		map_file = (struct vdpa_map_file *)map->opaque;
826 		fput(map_file->file);
827 		kfree(map_file);
828 		vhost_vdpa_general_unmap(v, map, asid);
829 		vhost_iotlb_map_free(iotlb, map);
830 	}
831 }
832 
833 static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v,
834 				   struct vhost_iotlb *iotlb, u64 start,
835 				   u64 last, u32 asid)
836 {
837 	struct vdpa_device *vdpa = v->vdpa;
838 
839 	if (vdpa->use_va)
840 		return vhost_vdpa_va_unmap(v, iotlb, start, last, asid);
841 
842 	return vhost_vdpa_pa_unmap(v, iotlb, start, last, asid);
843 }
844 
845 static int perm_to_iommu_flags(u32 perm)
846 {
847 	int flags = 0;
848 
849 	switch (perm) {
850 	case VHOST_ACCESS_WO:
851 		flags |= IOMMU_WRITE;
852 		break;
853 	case VHOST_ACCESS_RO:
854 		flags |= IOMMU_READ;
855 		break;
856 	case VHOST_ACCESS_RW:
857 		flags |= (IOMMU_WRITE | IOMMU_READ);
858 		break;
859 	default:
860 		WARN(1, "invalidate vhost IOTLB permission\n");
861 		break;
862 	}
863 
864 	return flags | IOMMU_CACHE;
865 }
866 
867 static int vhost_vdpa_map(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
868 			  u64 iova, u64 size, u64 pa, u32 perm, void *opaque)
869 {
870 	struct vhost_dev *dev = &v->vdev;
871 	struct vdpa_device *vdpa = v->vdpa;
872 	const struct vdpa_config_ops *ops = vdpa->config;
873 	u32 asid = iotlb_to_asid(iotlb);
874 	int r = 0;
875 
876 	r = vhost_iotlb_add_range_ctx(iotlb, iova, iova + size - 1,
877 				      pa, perm, opaque);
878 	if (r)
879 		return r;
880 
881 	if (ops->dma_map) {
882 		r = ops->dma_map(vdpa, asid, iova, size, pa, perm, opaque);
883 	} else if (ops->set_map) {
884 		if (!v->in_batch)
885 			r = ops->set_map(vdpa, asid, iotlb);
886 	} else {
887 		r = iommu_map(v->domain, iova, pa, size,
888 			      perm_to_iommu_flags(perm), GFP_KERNEL);
889 	}
890 	if (r) {
891 		vhost_iotlb_del_range(iotlb, iova, iova + size - 1);
892 		return r;
893 	}
894 
895 	if (!vdpa->use_va)
896 		atomic64_add(PFN_DOWN(size), &dev->mm->pinned_vm);
897 
898 	return 0;
899 }
900 
901 static void vhost_vdpa_unmap(struct vhost_vdpa *v,
902 			     struct vhost_iotlb *iotlb,
903 			     u64 iova, u64 size)
904 {
905 	struct vdpa_device *vdpa = v->vdpa;
906 	const struct vdpa_config_ops *ops = vdpa->config;
907 	u32 asid = iotlb_to_asid(iotlb);
908 
909 	vhost_vdpa_iotlb_unmap(v, iotlb, iova, iova + size - 1, asid);
910 
911 	if (ops->set_map) {
912 		if (!v->in_batch)
913 			ops->set_map(vdpa, asid, iotlb);
914 	}
915 
916 }
917 
918 static int vhost_vdpa_va_map(struct vhost_vdpa *v,
919 			     struct vhost_iotlb *iotlb,
920 			     u64 iova, u64 size, u64 uaddr, u32 perm)
921 {
922 	struct vhost_dev *dev = &v->vdev;
923 	u64 offset, map_size, map_iova = iova;
924 	struct vdpa_map_file *map_file;
925 	struct vm_area_struct *vma;
926 	int ret = 0;
927 
928 	mmap_read_lock(dev->mm);
929 
930 	while (size) {
931 		vma = find_vma(dev->mm, uaddr);
932 		if (!vma) {
933 			ret = -EINVAL;
934 			break;
935 		}
936 		map_size = min(size, vma->vm_end - uaddr);
937 		if (!(vma->vm_file && (vma->vm_flags & VM_SHARED) &&
938 			!(vma->vm_flags & (VM_IO | VM_PFNMAP))))
939 			goto next;
940 
941 		map_file = kzalloc(sizeof(*map_file), GFP_KERNEL);
942 		if (!map_file) {
943 			ret = -ENOMEM;
944 			break;
945 		}
946 		offset = (vma->vm_pgoff << PAGE_SHIFT) + uaddr - vma->vm_start;
947 		map_file->offset = offset;
948 		map_file->file = get_file(vma->vm_file);
949 		ret = vhost_vdpa_map(v, iotlb, map_iova, map_size, uaddr,
950 				     perm, map_file);
951 		if (ret) {
952 			fput(map_file->file);
953 			kfree(map_file);
954 			break;
955 		}
956 next:
957 		size -= map_size;
958 		uaddr += map_size;
959 		map_iova += map_size;
960 	}
961 	if (ret)
962 		vhost_vdpa_unmap(v, iotlb, iova, map_iova - iova);
963 
964 	mmap_read_unlock(dev->mm);
965 
966 	return ret;
967 }
968 
969 static int vhost_vdpa_pa_map(struct vhost_vdpa *v,
970 			     struct vhost_iotlb *iotlb,
971 			     u64 iova, u64 size, u64 uaddr, u32 perm)
972 {
973 	struct vhost_dev *dev = &v->vdev;
974 	struct page **page_list;
975 	unsigned long list_size = PAGE_SIZE / sizeof(struct page *);
976 	unsigned int gup_flags = FOLL_LONGTERM;
977 	unsigned long npages, cur_base, map_pfn, last_pfn = 0;
978 	unsigned long lock_limit, sz2pin, nchunks, i;
979 	u64 start = iova;
980 	long pinned;
981 	int ret = 0;
982 
983 	/* Limit the use of memory for bookkeeping */
984 	page_list = (struct page **) __get_free_page(GFP_KERNEL);
985 	if (!page_list)
986 		return -ENOMEM;
987 
988 	if (perm & VHOST_ACCESS_WO)
989 		gup_flags |= FOLL_WRITE;
990 
991 	npages = PFN_UP(size + (iova & ~PAGE_MASK));
992 	if (!npages) {
993 		ret = -EINVAL;
994 		goto free;
995 	}
996 
997 	mmap_read_lock(dev->mm);
998 
999 	lock_limit = PFN_DOWN(rlimit(RLIMIT_MEMLOCK));
1000 	if (npages + atomic64_read(&dev->mm->pinned_vm) > lock_limit) {
1001 		ret = -ENOMEM;
1002 		goto unlock;
1003 	}
1004 
1005 	cur_base = uaddr & PAGE_MASK;
1006 	iova &= PAGE_MASK;
1007 	nchunks = 0;
1008 
1009 	while (npages) {
1010 		sz2pin = min_t(unsigned long, npages, list_size);
1011 		pinned = pin_user_pages(cur_base, sz2pin,
1012 					gup_flags, page_list);
1013 		if (sz2pin != pinned) {
1014 			if (pinned < 0) {
1015 				ret = pinned;
1016 			} else {
1017 				unpin_user_pages(page_list, pinned);
1018 				ret = -ENOMEM;
1019 			}
1020 			goto out;
1021 		}
1022 		nchunks++;
1023 
1024 		if (!last_pfn)
1025 			map_pfn = page_to_pfn(page_list[0]);
1026 
1027 		for (i = 0; i < pinned; i++) {
1028 			unsigned long this_pfn = page_to_pfn(page_list[i]);
1029 			u64 csize;
1030 
1031 			if (last_pfn && (this_pfn != last_pfn + 1)) {
1032 				/* Pin a contiguous chunk of memory */
1033 				csize = PFN_PHYS(last_pfn - map_pfn + 1);
1034 				ret = vhost_vdpa_map(v, iotlb, iova, csize,
1035 						     PFN_PHYS(map_pfn),
1036 						     perm, NULL);
1037 				if (ret) {
1038 					/*
1039 					 * Unpin the pages that are left unmapped
1040 					 * from this point on in the current
1041 					 * page_list. The remaining outstanding
1042 					 * ones which may stride across several
1043 					 * chunks will be covered in the common
1044 					 * error path subsequently.
1045 					 */
1046 					unpin_user_pages(&page_list[i],
1047 							 pinned - i);
1048 					goto out;
1049 				}
1050 
1051 				map_pfn = this_pfn;
1052 				iova += csize;
1053 				nchunks = 0;
1054 			}
1055 
1056 			last_pfn = this_pfn;
1057 		}
1058 
1059 		cur_base += PFN_PHYS(pinned);
1060 		npages -= pinned;
1061 	}
1062 
1063 	/* Pin the rest chunk */
1064 	ret = vhost_vdpa_map(v, iotlb, iova, PFN_PHYS(last_pfn - map_pfn + 1),
1065 			     PFN_PHYS(map_pfn), perm, NULL);
1066 out:
1067 	if (ret) {
1068 		if (nchunks) {
1069 			unsigned long pfn;
1070 
1071 			/*
1072 			 * Unpin the outstanding pages which are yet to be
1073 			 * mapped but haven't due to vdpa_map() or
1074 			 * pin_user_pages() failure.
1075 			 *
1076 			 * Mapped pages are accounted in vdpa_map(), hence
1077 			 * the corresponding unpinning will be handled by
1078 			 * vdpa_unmap().
1079 			 */
1080 			WARN_ON(!last_pfn);
1081 			for (pfn = map_pfn; pfn <= last_pfn; pfn++)
1082 				unpin_user_page(pfn_to_page(pfn));
1083 		}
1084 		vhost_vdpa_unmap(v, iotlb, start, size);
1085 	}
1086 unlock:
1087 	mmap_read_unlock(dev->mm);
1088 free:
1089 	free_page((unsigned long)page_list);
1090 	return ret;
1091 
1092 }
1093 
1094 static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
1095 					   struct vhost_iotlb *iotlb,
1096 					   struct vhost_iotlb_msg *msg)
1097 {
1098 	struct vdpa_device *vdpa = v->vdpa;
1099 
1100 	if (msg->iova < v->range.first || !msg->size ||
1101 	    msg->iova > U64_MAX - msg->size + 1 ||
1102 	    msg->iova + msg->size - 1 > v->range.last)
1103 		return -EINVAL;
1104 
1105 	if (vhost_iotlb_itree_first(iotlb, msg->iova,
1106 				    msg->iova + msg->size - 1))
1107 		return -EEXIST;
1108 
1109 	if (vdpa->use_va)
1110 		return vhost_vdpa_va_map(v, iotlb, msg->iova, msg->size,
1111 					 msg->uaddr, msg->perm);
1112 
1113 	return vhost_vdpa_pa_map(v, iotlb, msg->iova, msg->size, msg->uaddr,
1114 				 msg->perm);
1115 }
1116 
1117 static int vhost_vdpa_process_iotlb_msg(struct vhost_dev *dev, u32 asid,
1118 					struct vhost_iotlb_msg *msg)
1119 {
1120 	struct vhost_vdpa *v = container_of(dev, struct vhost_vdpa, vdev);
1121 	struct vdpa_device *vdpa = v->vdpa;
1122 	const struct vdpa_config_ops *ops = vdpa->config;
1123 	struct vhost_iotlb *iotlb = NULL;
1124 	struct vhost_vdpa_as *as = NULL;
1125 	int r = 0;
1126 
1127 	mutex_lock(&dev->mutex);
1128 
1129 	r = vhost_dev_check_owner(dev);
1130 	if (r)
1131 		goto unlock;
1132 
1133 	if (msg->type == VHOST_IOTLB_UPDATE ||
1134 	    msg->type == VHOST_IOTLB_BATCH_BEGIN) {
1135 		as = vhost_vdpa_find_alloc_as(v, asid);
1136 		if (!as) {
1137 			dev_err(&v->dev, "can't find and alloc asid %d\n",
1138 				asid);
1139 			r = -EINVAL;
1140 			goto unlock;
1141 		}
1142 		iotlb = &as->iotlb;
1143 	} else
1144 		iotlb = asid_to_iotlb(v, asid);
1145 
1146 	if ((v->in_batch && v->batch_asid != asid) || !iotlb) {
1147 		if (v->in_batch && v->batch_asid != asid) {
1148 			dev_info(&v->dev, "batch id %d asid %d\n",
1149 				 v->batch_asid, asid);
1150 		}
1151 		if (!iotlb)
1152 			dev_err(&v->dev, "no iotlb for asid %d\n", asid);
1153 		r = -EINVAL;
1154 		goto unlock;
1155 	}
1156 
1157 	switch (msg->type) {
1158 	case VHOST_IOTLB_UPDATE:
1159 		r = vhost_vdpa_process_iotlb_update(v, iotlb, msg);
1160 		break;
1161 	case VHOST_IOTLB_INVALIDATE:
1162 		vhost_vdpa_unmap(v, iotlb, msg->iova, msg->size);
1163 		break;
1164 	case VHOST_IOTLB_BATCH_BEGIN:
1165 		v->batch_asid = asid;
1166 		v->in_batch = true;
1167 		break;
1168 	case VHOST_IOTLB_BATCH_END:
1169 		if (v->in_batch && ops->set_map)
1170 			ops->set_map(vdpa, asid, iotlb);
1171 		v->in_batch = false;
1172 		break;
1173 	default:
1174 		r = -EINVAL;
1175 		break;
1176 	}
1177 unlock:
1178 	mutex_unlock(&dev->mutex);
1179 
1180 	return r;
1181 }
1182 
1183 static ssize_t vhost_vdpa_chr_write_iter(struct kiocb *iocb,
1184 					 struct iov_iter *from)
1185 {
1186 	struct file *file = iocb->ki_filp;
1187 	struct vhost_vdpa *v = file->private_data;
1188 	struct vhost_dev *dev = &v->vdev;
1189 
1190 	return vhost_chr_write_iter(dev, from);
1191 }
1192 
1193 static int vhost_vdpa_alloc_domain(struct vhost_vdpa *v)
1194 {
1195 	struct vdpa_device *vdpa = v->vdpa;
1196 	const struct vdpa_config_ops *ops = vdpa->config;
1197 	struct device *dma_dev = vdpa_get_dma_dev(vdpa);
1198 	const struct bus_type *bus;
1199 	int ret;
1200 
1201 	/* Device want to do DMA by itself */
1202 	if (ops->set_map || ops->dma_map)
1203 		return 0;
1204 
1205 	bus = dma_dev->bus;
1206 	if (!bus)
1207 		return -EFAULT;
1208 
1209 	if (!device_iommu_capable(dma_dev, IOMMU_CAP_CACHE_COHERENCY)) {
1210 		dev_warn_once(&v->dev,
1211 			      "Failed to allocate domain, device is not IOMMU cache coherent capable\n");
1212 		return -ENOTSUPP;
1213 	}
1214 
1215 	v->domain = iommu_domain_alloc(bus);
1216 	if (!v->domain)
1217 		return -EIO;
1218 
1219 	ret = iommu_attach_device(v->domain, dma_dev);
1220 	if (ret)
1221 		goto err_attach;
1222 
1223 	return 0;
1224 
1225 err_attach:
1226 	iommu_domain_free(v->domain);
1227 	v->domain = NULL;
1228 	return ret;
1229 }
1230 
1231 static void vhost_vdpa_free_domain(struct vhost_vdpa *v)
1232 {
1233 	struct vdpa_device *vdpa = v->vdpa;
1234 	struct device *dma_dev = vdpa_get_dma_dev(vdpa);
1235 
1236 	if (v->domain) {
1237 		iommu_detach_device(v->domain, dma_dev);
1238 		iommu_domain_free(v->domain);
1239 	}
1240 
1241 	v->domain = NULL;
1242 }
1243 
1244 static void vhost_vdpa_set_iova_range(struct vhost_vdpa *v)
1245 {
1246 	struct vdpa_iova_range *range = &v->range;
1247 	struct vdpa_device *vdpa = v->vdpa;
1248 	const struct vdpa_config_ops *ops = vdpa->config;
1249 
1250 	if (ops->get_iova_range) {
1251 		*range = ops->get_iova_range(vdpa);
1252 	} else if (v->domain && v->domain->geometry.force_aperture) {
1253 		range->first = v->domain->geometry.aperture_start;
1254 		range->last = v->domain->geometry.aperture_end;
1255 	} else {
1256 		range->first = 0;
1257 		range->last = ULLONG_MAX;
1258 	}
1259 }
1260 
1261 static void vhost_vdpa_cleanup(struct vhost_vdpa *v)
1262 {
1263 	struct vhost_vdpa_as *as;
1264 	u32 asid;
1265 
1266 	for (asid = 0; asid < v->vdpa->nas; asid++) {
1267 		as = asid_to_as(v, asid);
1268 		if (as)
1269 			vhost_vdpa_remove_as(v, asid);
1270 	}
1271 
1272 	vhost_vdpa_free_domain(v);
1273 	vhost_dev_cleanup(&v->vdev);
1274 	kfree(v->vdev.vqs);
1275 }
1276 
1277 static int vhost_vdpa_open(struct inode *inode, struct file *filep)
1278 {
1279 	struct vhost_vdpa *v;
1280 	struct vhost_dev *dev;
1281 	struct vhost_virtqueue **vqs;
1282 	int r, opened;
1283 	u32 i, nvqs;
1284 
1285 	v = container_of(inode->i_cdev, struct vhost_vdpa, cdev);
1286 
1287 	opened = atomic_cmpxchg(&v->opened, 0, 1);
1288 	if (opened)
1289 		return -EBUSY;
1290 
1291 	nvqs = v->nvqs;
1292 	r = vhost_vdpa_reset(v);
1293 	if (r)
1294 		goto err;
1295 
1296 	vqs = kmalloc_array(nvqs, sizeof(*vqs), GFP_KERNEL);
1297 	if (!vqs) {
1298 		r = -ENOMEM;
1299 		goto err;
1300 	}
1301 
1302 	dev = &v->vdev;
1303 	for (i = 0; i < nvqs; i++) {
1304 		vqs[i] = &v->vqs[i];
1305 		vqs[i]->handle_kick = handle_vq_kick;
1306 	}
1307 	vhost_dev_init(dev, vqs, nvqs, 0, 0, 0, false,
1308 		       vhost_vdpa_process_iotlb_msg);
1309 
1310 	r = vhost_vdpa_alloc_domain(v);
1311 	if (r)
1312 		goto err_alloc_domain;
1313 
1314 	vhost_vdpa_set_iova_range(v);
1315 
1316 	filep->private_data = v;
1317 
1318 	return 0;
1319 
1320 err_alloc_domain:
1321 	vhost_vdpa_cleanup(v);
1322 err:
1323 	atomic_dec(&v->opened);
1324 	return r;
1325 }
1326 
1327 static void vhost_vdpa_clean_irq(struct vhost_vdpa *v)
1328 {
1329 	u32 i;
1330 
1331 	for (i = 0; i < v->nvqs; i++)
1332 		vhost_vdpa_unsetup_vq_irq(v, i);
1333 }
1334 
/*
 * release() for the vhost-vdpa character device.
 *
 * Under dev->mutex, tear down the per-open state: IRQs, device state, the
 * vhost device, the bound mm, the config eventfd, address spaces and the
 * IOMMU domain. Then give back the single-opener slot and signal
 * v->completion. vhost_vdpa_remove() waits on that completion until the
 * device is no longer open.
 */
static int vhost_vdpa_release(struct inode *inode, struct file *filep)
{
	struct vhost_vdpa *v = filep->private_data;
	struct vhost_dev *d = &v->vdev;

	mutex_lock(&d->mutex);
	filep->private_data = NULL;
	vhost_vdpa_clean_irq(v);
	/* Best effort: the reset result is not checked on release. */
	vhost_vdpa_reset(v);
	vhost_dev_stop(&v->vdev);
	vhost_vdpa_unbind_mm(v);
	vhost_vdpa_config_put(v);
	/* Frees the address spaces, the IOMMU domain and the vq pointer array. */
	vhost_vdpa_cleanup(v);
	mutex_unlock(&d->mutex);

	/* Device is unused again: free the opener slot and wake remove(). */
	atomic_dec(&v->opened);
	complete(&v->completion);

	return 0;
}
1355 
1356 #ifdef CONFIG_MMU
1357 static vm_fault_t vhost_vdpa_fault(struct vm_fault *vmf)
1358 {
1359 	struct vhost_vdpa *v = vmf->vma->vm_file->private_data;
1360 	struct vdpa_device *vdpa = v->vdpa;
1361 	const struct vdpa_config_ops *ops = vdpa->config;
1362 	struct vdpa_notification_area notify;
1363 	struct vm_area_struct *vma = vmf->vma;
1364 	u16 index = vma->vm_pgoff;
1365 
1366 	notify = ops->get_vq_notification(vdpa, index);
1367 
1368 	vma->vm_page_prot = pgprot_noncached(vma->vm_page_prot);
1369 	if (remap_pfn_range(vma, vmf->address & PAGE_MASK,
1370 			    PFN_DOWN(notify.addr), PAGE_SIZE,
1371 			    vma->vm_page_prot))
1372 		return VM_FAULT_SIGBUS;
1373 
1374 	return VM_FAULT_NOPAGE;
1375 }
1376 
/*
 * VMA operations for doorbell mappings. vhost_vdpa_mmap() installs no
 * pages itself; they are mapped on first access by vhost_vdpa_fault().
 */
static const struct vm_operations_struct vhost_vdpa_vm_ops = {
	.fault = vhost_vdpa_fault,
};
1380 
1381 static int vhost_vdpa_mmap(struct file *file, struct vm_area_struct *vma)
1382 {
1383 	struct vhost_vdpa *v = vma->vm_file->private_data;
1384 	struct vdpa_device *vdpa = v->vdpa;
1385 	const struct vdpa_config_ops *ops = vdpa->config;
1386 	struct vdpa_notification_area notify;
1387 	unsigned long index = vma->vm_pgoff;
1388 
1389 	if (vma->vm_end - vma->vm_start != PAGE_SIZE)
1390 		return -EINVAL;
1391 	if ((vma->vm_flags & VM_SHARED) == 0)
1392 		return -EINVAL;
1393 	if (vma->vm_flags & VM_READ)
1394 		return -EINVAL;
1395 	if (index > 65535)
1396 		return -EINVAL;
1397 	if (!ops->get_vq_notification)
1398 		return -ENOTSUPP;
1399 
1400 	/* To be safe and easily modelled by userspace, We only
1401 	 * support the doorbell which sits on the page boundary and
1402 	 * does not share the page with other registers.
1403 	 */
1404 	notify = ops->get_vq_notification(vdpa, index);
1405 	if (notify.addr & (PAGE_SIZE - 1))
1406 		return -EINVAL;
1407 	if (vma->vm_end - vma->vm_start != notify.size)
1408 		return -ENOTSUPP;
1409 
1410 	vm_flags_set(vma, VM_IO | VM_PFNMAP | VM_DONTEXPAND | VM_DONTDUMP);
1411 	vma->vm_ops = &vhost_vdpa_vm_ops;
1412 	return 0;
1413 }
1414 #endif /* CONFIG_MMU */
1415 
/*
 * File operations for each vhost-vdpa-<minor> character device. The
 * doorbell mmap() is only offered when CONFIG_MMU is enabled.
 */
static const struct file_operations vhost_vdpa_fops = {
	.owner		= THIS_MODULE,
	.open		= vhost_vdpa_open,
	.release	= vhost_vdpa_release,
	.write_iter	= vhost_vdpa_chr_write_iter,
	.unlocked_ioctl	= vhost_vdpa_unlocked_ioctl,
#ifdef CONFIG_MMU
	.mmap		= vhost_vdpa_mmap,
#endif /* CONFIG_MMU */
	.compat_ioctl	= compat_ptr_ioctl,
};
1427 
1428 static void vhost_vdpa_release_dev(struct device *device)
1429 {
1430 	struct vhost_vdpa *v =
1431 	       container_of(device, struct vhost_vdpa, dev);
1432 
1433 	ida_simple_remove(&vhost_vdpa_ida, v->minor);
1434 	kfree(v->vqs);
1435 	kfree(v);
1436 }
1437 
1438 static int vhost_vdpa_probe(struct vdpa_device *vdpa)
1439 {
1440 	const struct vdpa_config_ops *ops = vdpa->config;
1441 	struct vhost_vdpa *v;
1442 	int minor;
1443 	int i, r;
1444 
1445 	/* We can't support platform IOMMU device with more than 1
1446 	 * group or as
1447 	 */
1448 	if (!ops->set_map && !ops->dma_map &&
1449 	    (vdpa->ngroups > 1 || vdpa->nas > 1))
1450 		return -EOPNOTSUPP;
1451 
1452 	v = kzalloc(sizeof(*v), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
1453 	if (!v)
1454 		return -ENOMEM;
1455 
1456 	minor = ida_simple_get(&vhost_vdpa_ida, 0,
1457 			       VHOST_VDPA_DEV_MAX, GFP_KERNEL);
1458 	if (minor < 0) {
1459 		kfree(v);
1460 		return minor;
1461 	}
1462 
1463 	atomic_set(&v->opened, 0);
1464 	v->minor = minor;
1465 	v->vdpa = vdpa;
1466 	v->nvqs = vdpa->nvqs;
1467 	v->virtio_id = ops->get_device_id(vdpa);
1468 
1469 	device_initialize(&v->dev);
1470 	v->dev.release = vhost_vdpa_release_dev;
1471 	v->dev.parent = &vdpa->dev;
1472 	v->dev.devt = MKDEV(MAJOR(vhost_vdpa_major), minor);
1473 	v->vqs = kmalloc_array(v->nvqs, sizeof(struct vhost_virtqueue),
1474 			       GFP_KERNEL);
1475 	if (!v->vqs) {
1476 		r = -ENOMEM;
1477 		goto err;
1478 	}
1479 
1480 	r = dev_set_name(&v->dev, "vhost-vdpa-%u", minor);
1481 	if (r)
1482 		goto err;
1483 
1484 	cdev_init(&v->cdev, &vhost_vdpa_fops);
1485 	v->cdev.owner = THIS_MODULE;
1486 
1487 	r = cdev_device_add(&v->cdev, &v->dev);
1488 	if (r)
1489 		goto err;
1490 
1491 	init_completion(&v->completion);
1492 	vdpa_set_drvdata(vdpa, v);
1493 
1494 	for (i = 0; i < VHOST_VDPA_IOTLB_BUCKETS; i++)
1495 		INIT_HLIST_HEAD(&v->as[i]);
1496 
1497 	return 0;
1498 
1499 err:
1500 	put_device(&v->dev);
1501 	ida_simple_remove(&vhost_vdpa_ida, v->minor);
1502 	return r;
1503 }
1504 
1505 static void vhost_vdpa_remove(struct vdpa_device *vdpa)
1506 {
1507 	struct vhost_vdpa *v = vdpa_get_drvdata(vdpa);
1508 	int opened;
1509 
1510 	cdev_device_del(&v->cdev, &v->dev);
1511 
1512 	do {
1513 		opened = atomic_cmpxchg(&v->opened, 0, 1);
1514 		if (!opened)
1515 			break;
1516 		wait_for_completion(&v->completion);
1517 	} while (1);
1518 
1519 	put_device(&v->dev);
1520 }
1521 
/*
 * vDPA bus driver. probe() creates a vhost-vdpa character device per bound
 * vDPA device; remove() tears it down.
 */
static struct vdpa_driver vhost_vdpa_driver = {
	.driver = {
		.name	= "vhost_vdpa",
	},
	.probe	= vhost_vdpa_probe,
	.remove	= vhost_vdpa_remove,
};
1529 
1530 static int __init vhost_vdpa_init(void)
1531 {
1532 	int r;
1533 
1534 	r = alloc_chrdev_region(&vhost_vdpa_major, 0, VHOST_VDPA_DEV_MAX,
1535 				"vhost-vdpa");
1536 	if (r)
1537 		goto err_alloc_chrdev;
1538 
1539 	r = vdpa_register_driver(&vhost_vdpa_driver);
1540 	if (r)
1541 		goto err_vdpa_register_driver;
1542 
1543 	return 0;
1544 
1545 err_vdpa_register_driver:
1546 	unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
1547 err_alloc_chrdev:
1548 	return r;
1549 }
1550 module_init(vhost_vdpa_init);
1551 
/*
 * Module exit: the reverse of vhost_vdpa_init(). Unregister the vDPA
 * driver first, then give back the char-device number range that the
 * per-device minors were allocated from.
 */
static void __exit vhost_vdpa_exit(void)
{
	vdpa_unregister_driver(&vhost_vdpa_driver);
	unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
}
module_exit(vhost_vdpa_exit);
1558 
/* Module metadata (as reported by modinfo). */
MODULE_VERSION("0.0.1");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Intel Corporation");
MODULE_DESCRIPTION("vDPA-based vhost backend for virtio");
1563