xref: /openbmc/linux/drivers/vhost/vdpa.c (revision e5242c5f)
1 // SPDX-License-Identifier: GPL-2.0
2 /*
3  * Copyright (C) 2018-2020 Intel Corporation.
4  * Copyright (C) 2020 Red Hat, Inc.
5  *
6  * Author: Tiwei Bie <tiwei.bie@intel.com>
7  *         Jason Wang <jasowang@redhat.com>
8  *
9  * Thanks to Michael S. Tsirkin for the valuable comments and
10  * suggestions, and thanks to Cunming Liang and Zhihong Wang for all
11  * their support.
12  */
13 
14 #include <linux/kernel.h>
15 #include <linux/module.h>
16 #include <linux/cdev.h>
17 #include <linux/device.h>
18 #include <linux/mm.h>
19 #include <linux/slab.h>
20 #include <linux/iommu.h>
21 #include <linux/uuid.h>
22 #include <linux/vdpa.h>
23 #include <linux/nospec.h>
24 #include <linux/vhost.h>
25 
26 #include "vhost.h"
27 
28 enum {
29 	VHOST_VDPA_BACKEND_FEATURES =
30 	(1ULL << VHOST_BACKEND_F_IOTLB_MSG_V2) |
31 	(1ULL << VHOST_BACKEND_F_IOTLB_BATCH) |
32 	(1ULL << VHOST_BACKEND_F_IOTLB_ASID),
33 };
34 
35 #define VHOST_VDPA_DEV_MAX (1U << MINORBITS)
36 
37 #define VHOST_VDPA_IOTLB_BUCKETS 16
38 
39 struct vhost_vdpa_as {
40 	struct hlist_node hash_link;
41 	struct vhost_iotlb iotlb;
42 	u32 id;
43 };
44 
45 struct vhost_vdpa {
46 	struct vhost_dev vdev;
47 	struct iommu_domain *domain;
48 	struct vhost_virtqueue *vqs;
49 	struct completion completion;
50 	struct vdpa_device *vdpa;
51 	struct hlist_head as[VHOST_VDPA_IOTLB_BUCKETS];
52 	struct device dev;
53 	struct cdev cdev;
54 	atomic_t opened;
55 	u32 nvqs;
56 	int virtio_id;
57 	int minor;
58 	struct eventfd_ctx *config_ctx;
59 	int in_batch;
60 	struct vdpa_iova_range range;
61 	u32 batch_asid;
62 };
63 
64 static DEFINE_IDA(vhost_vdpa_ida);
65 
66 static dev_t vhost_vdpa_major;
67 
68 static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v,
69 				   struct vhost_iotlb *iotlb, u64 start,
70 				   u64 last, u32 asid);
71 
72 static inline u32 iotlb_to_asid(struct vhost_iotlb *iotlb)
73 {
74 	struct vhost_vdpa_as *as = container_of(iotlb, struct
75 						vhost_vdpa_as, iotlb);
76 	return as->id;
77 }
78 
79 static struct vhost_vdpa_as *asid_to_as(struct vhost_vdpa *v, u32 asid)
80 {
81 	struct hlist_head *head = &v->as[asid % VHOST_VDPA_IOTLB_BUCKETS];
82 	struct vhost_vdpa_as *as;
83 
84 	hlist_for_each_entry(as, head, hash_link)
85 		if (as->id == asid)
86 			return as;
87 
88 	return NULL;
89 }
90 
91 static struct vhost_iotlb *asid_to_iotlb(struct vhost_vdpa *v, u32 asid)
92 {
93 	struct vhost_vdpa_as *as = asid_to_as(v, asid);
94 
95 	if (!as)
96 		return NULL;
97 
98 	return &as->iotlb;
99 }
100 
101 static struct vhost_vdpa_as *vhost_vdpa_alloc_as(struct vhost_vdpa *v, u32 asid)
102 {
103 	struct hlist_head *head = &v->as[asid % VHOST_VDPA_IOTLB_BUCKETS];
104 	struct vhost_vdpa_as *as;
105 
106 	if (asid_to_as(v, asid))
107 		return NULL;
108 
109 	if (asid >= v->vdpa->nas)
110 		return NULL;
111 
112 	as = kmalloc(sizeof(*as), GFP_KERNEL);
113 	if (!as)
114 		return NULL;
115 
116 	vhost_iotlb_init(&as->iotlb, 0, 0);
117 	as->id = asid;
118 	hlist_add_head(&as->hash_link, head);
119 
120 	return as;
121 }
122 
123 static struct vhost_vdpa_as *vhost_vdpa_find_alloc_as(struct vhost_vdpa *v,
124 						      u32 asid)
125 {
126 	struct vhost_vdpa_as *as = asid_to_as(v, asid);
127 
128 	if (as)
129 		return as;
130 
131 	return vhost_vdpa_alloc_as(v, asid);
132 }
133 
134 static int vhost_vdpa_remove_as(struct vhost_vdpa *v, u32 asid)
135 {
136 	struct vhost_vdpa_as *as = asid_to_as(v, asid);
137 
138 	if (!as)
139 		return -EINVAL;
140 
141 	hlist_del(&as->hash_link);
142 	vhost_vdpa_iotlb_unmap(v, &as->iotlb, 0ULL, 0ULL - 1, asid);
143 	kfree(as);
144 
145 	return 0;
146 }
147 
148 static void handle_vq_kick(struct vhost_work *work)
149 {
150 	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
151 						  poll.work);
152 	struct vhost_vdpa *v = container_of(vq->dev, struct vhost_vdpa, vdev);
153 	const struct vdpa_config_ops *ops = v->vdpa->config;
154 
155 	ops->kick_vq(v->vdpa, vq - v->vqs);
156 }
157 
158 static irqreturn_t vhost_vdpa_virtqueue_cb(void *private)
159 {
160 	struct vhost_virtqueue *vq = private;
161 	struct eventfd_ctx *call_ctx = vq->call_ctx.ctx;
162 
163 	if (call_ctx)
164 		eventfd_signal(call_ctx, 1);
165 
166 	return IRQ_HANDLED;
167 }
168 
169 static irqreturn_t vhost_vdpa_config_cb(void *private)
170 {
171 	struct vhost_vdpa *v = private;
172 	struct eventfd_ctx *config_ctx = v->config_ctx;
173 
174 	if (config_ctx)
175 		eventfd_signal(config_ctx, 1);
176 
177 	return IRQ_HANDLED;
178 }
179 
180 static void vhost_vdpa_setup_vq_irq(struct vhost_vdpa *v, u16 qid)
181 {
182 	struct vhost_virtqueue *vq = &v->vqs[qid];
183 	const struct vdpa_config_ops *ops = v->vdpa->config;
184 	struct vdpa_device *vdpa = v->vdpa;
185 	int ret, irq;
186 
187 	if (!ops->get_vq_irq)
188 		return;
189 
190 	irq = ops->get_vq_irq(vdpa, qid);
191 	if (irq < 0)
192 		return;
193 
194 	if (!vq->call_ctx.ctx)
195 		return;
196 
197 	vq->call_ctx.producer.irq = irq;
198 	ret = irq_bypass_register_producer(&vq->call_ctx.producer);
199 	if (unlikely(ret))
200 		dev_info(&v->dev, "vq %u, irq bypass producer (token %p) registration fails, ret = %d\n",
201 			 qid, vq->call_ctx.producer.token, ret);
202 }
203 
204 static void vhost_vdpa_unsetup_vq_irq(struct vhost_vdpa *v, u16 qid)
205 {
206 	struct vhost_virtqueue *vq = &v->vqs[qid];
207 
208 	irq_bypass_unregister_producer(&vq->call_ctx.producer);
209 }
210 
211 static int vhost_vdpa_reset(struct vhost_vdpa *v)
212 {
213 	struct vdpa_device *vdpa = v->vdpa;
214 
215 	v->in_batch = 0;
216 
217 	return vdpa_reset(vdpa);
218 }
219 
220 static long vhost_vdpa_bind_mm(struct vhost_vdpa *v)
221 {
222 	struct vdpa_device *vdpa = v->vdpa;
223 	const struct vdpa_config_ops *ops = vdpa->config;
224 
225 	if (!vdpa->use_va || !ops->bind_mm)
226 		return 0;
227 
228 	return ops->bind_mm(vdpa, v->vdev.mm);
229 }
230 
231 static void vhost_vdpa_unbind_mm(struct vhost_vdpa *v)
232 {
233 	struct vdpa_device *vdpa = v->vdpa;
234 	const struct vdpa_config_ops *ops = vdpa->config;
235 
236 	if (!vdpa->use_va || !ops->unbind_mm)
237 		return;
238 
239 	ops->unbind_mm(vdpa);
240 }
241 
242 static long vhost_vdpa_get_device_id(struct vhost_vdpa *v, u8 __user *argp)
243 {
244 	struct vdpa_device *vdpa = v->vdpa;
245 	const struct vdpa_config_ops *ops = vdpa->config;
246 	u32 device_id;
247 
248 	device_id = ops->get_device_id(vdpa);
249 
250 	if (copy_to_user(argp, &device_id, sizeof(device_id)))
251 		return -EFAULT;
252 
253 	return 0;
254 }
255 
256 static long vhost_vdpa_get_status(struct vhost_vdpa *v, u8 __user *statusp)
257 {
258 	struct vdpa_device *vdpa = v->vdpa;
259 	const struct vdpa_config_ops *ops = vdpa->config;
260 	u8 status;
261 
262 	status = ops->get_status(vdpa);
263 
264 	if (copy_to_user(statusp, &status, sizeof(status)))
265 		return -EFAULT;
266 
267 	return 0;
268 }
269 
270 static long vhost_vdpa_set_status(struct vhost_vdpa *v, u8 __user *statusp)
271 {
272 	struct vdpa_device *vdpa = v->vdpa;
273 	const struct vdpa_config_ops *ops = vdpa->config;
274 	u8 status, status_old;
275 	u32 nvqs = v->nvqs;
276 	int ret;
277 	u16 i;
278 
279 	if (copy_from_user(&status, statusp, sizeof(status)))
280 		return -EFAULT;
281 
282 	status_old = ops->get_status(vdpa);
283 
284 	/*
285 	 * Userspace shouldn't remove status bits unless it resets the
286 	 * status to 0.
287 	 */
288 	if (status != 0 && (status_old & ~status) != 0)
289 		return -EINVAL;
290 
291 	if ((status_old & VIRTIO_CONFIG_S_DRIVER_OK) && !(status & VIRTIO_CONFIG_S_DRIVER_OK))
292 		for (i = 0; i < nvqs; i++)
293 			vhost_vdpa_unsetup_vq_irq(v, i);
294 
295 	if (status == 0) {
296 		ret = vdpa_reset(vdpa);
297 		if (ret)
298 			return ret;
299 	} else
300 		vdpa_set_status(vdpa, status);
301 
302 	if ((status & VIRTIO_CONFIG_S_DRIVER_OK) && !(status_old & VIRTIO_CONFIG_S_DRIVER_OK))
303 		for (i = 0; i < nvqs; i++)
304 			vhost_vdpa_setup_vq_irq(v, i);
305 
306 	return 0;
307 }
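
/*
 * Illustrative sketch, not part of the driver: how userspace is expected
 * to drive the status ioctls given the "only add bits, or reset to 0"
 * rule enforced above ("fd" is a hypothetical open /dev/vhost-vdpa-N
 * descriptor; feature negotiation via VHOST_SET_FEATURES is omitted):
 *
 *	__u8 status = 0;
 *
 *	ioctl(fd, VHOST_VDPA_SET_STATUS, &status);	// full reset
 *	status = VIRTIO_CONFIG_S_ACKNOWLEDGE | VIRTIO_CONFIG_S_DRIVER;
 *	ioctl(fd, VHOST_VDPA_SET_STATUS, &status);	// add bits only
 *	status |= VIRTIO_CONFIG_S_FEATURES_OK;
 *	ioctl(fd, VHOST_VDPA_SET_STATUS, &status);
 *	status |= VIRTIO_CONFIG_S_DRIVER_OK;		// vq irqs set up here
 *	ioctl(fd, VHOST_VDPA_SET_STATUS, &status);
 *
 * Clearing e.g. DRIVER_OK alone would be rejected with -EINVAL; only
 * writing 0 may remove bits.
 */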
308 
309 static int vhost_vdpa_config_validate(struct vhost_vdpa *v,
310 				      struct vhost_vdpa_config *c)
311 {
312 	struct vdpa_device *vdpa = v->vdpa;
313 	size_t size = vdpa->config->get_config_size(vdpa);
314 
315 	if (c->len == 0 || c->off > size)
316 		return -EINVAL;
317 
318 	if (c->len > size - c->off)
319 		return -E2BIG;
320 
321 	return 0;
322 }
323 
324 static long vhost_vdpa_get_config(struct vhost_vdpa *v,
325 				  struct vhost_vdpa_config __user *c)
326 {
327 	struct vdpa_device *vdpa = v->vdpa;
328 	struct vhost_vdpa_config config;
329 	unsigned long size = offsetof(struct vhost_vdpa_config, buf);
330 	u8 *buf;
331 
332 	if (copy_from_user(&config, c, size))
333 		return -EFAULT;
334 	if (vhost_vdpa_config_validate(v, &config))
335 		return -EINVAL;
336 	buf = kvzalloc(config.len, GFP_KERNEL);
337 	if (!buf)
338 		return -ENOMEM;
339 
340 	vdpa_get_config(vdpa, config.off, buf, config.len);
341 
342 	if (copy_to_user(c->buf, buf, config.len)) {
343 		kvfree(buf);
344 		return -EFAULT;
345 	}
346 
347 	kvfree(buf);
348 	return 0;
349 }
350 
351 static long vhost_vdpa_set_config(struct vhost_vdpa *v,
352 				  struct vhost_vdpa_config __user *c)
353 {
354 	struct vdpa_device *vdpa = v->vdpa;
355 	struct vhost_vdpa_config config;
356 	unsigned long size = offsetof(struct vhost_vdpa_config, buf);
357 	u8 *buf;
358 
359 	if (copy_from_user(&config, c, size))
360 		return -EFAULT;
361 	if (vhost_vdpa_config_validate(v, &config))
362 		return -EINVAL;
363 
364 	buf = vmemdup_user(c->buf, config.len);
365 	if (IS_ERR(buf))
366 		return PTR_ERR(buf);
367 
368 	vdpa_set_config(vdpa, config.off, buf, config.len);
369 
370 	kvfree(buf);
371 	return 0;
372 }
373 
374 static bool vhost_vdpa_can_suspend(const struct vhost_vdpa *v)
375 {
376 	struct vdpa_device *vdpa = v->vdpa;
377 	const struct vdpa_config_ops *ops = vdpa->config;
378 
379 	return ops->suspend;
380 }
381 
382 static bool vhost_vdpa_can_resume(const struct vhost_vdpa *v)
383 {
384 	struct vdpa_device *vdpa = v->vdpa;
385 	const struct vdpa_config_ops *ops = vdpa->config;
386 
387 	return ops->resume;
388 }
389 
390 static long vhost_vdpa_get_features(struct vhost_vdpa *v, u64 __user *featurep)
391 {
392 	struct vdpa_device *vdpa = v->vdpa;
393 	const struct vdpa_config_ops *ops = vdpa->config;
394 	u64 features;
395 
396 	features = ops->get_device_features(vdpa);
397 
398 	if (copy_to_user(featurep, &features, sizeof(features)))
399 		return -EFAULT;
400 
401 	return 0;
402 }
403 
404 static u64 vhost_vdpa_get_backend_features(const struct vhost_vdpa *v)
405 {
406 	struct vdpa_device *vdpa = v->vdpa;
407 	const struct vdpa_config_ops *ops = vdpa->config;
408 
409 	if (!ops->get_backend_features)
410 		return 0;
411 	else
412 		return ops->get_backend_features(vdpa);
413 }
414 
415 static long vhost_vdpa_set_features(struct vhost_vdpa *v, u64 __user *featurep)
416 {
417 	struct vdpa_device *vdpa = v->vdpa;
418 	const struct vdpa_config_ops *ops = vdpa->config;
419 	struct vhost_dev *d = &v->vdev;
420 	u64 actual_features;
421 	u64 features;
422 	int i;
423 
424 	/*
425 	 * It's not allowed to change the features after they have
426 	 * been negotiated.
427 	 */
428 	if (ops->get_status(vdpa) & VIRTIO_CONFIG_S_FEATURES_OK)
429 		return -EBUSY;
430 
431 	if (copy_from_user(&features, featurep, sizeof(features)))
432 		return -EFAULT;
433 
434 	if (vdpa_set_features(vdpa, features))
435 		return -EINVAL;
436 
437 	/* let the vqs know what has been configured */
438 	actual_features = ops->get_driver_features(vdpa);
439 	for (i = 0; i < d->nvqs; ++i) {
440 		struct vhost_virtqueue *vq = d->vqs[i];
441 
442 		mutex_lock(&vq->mutex);
443 		vq->acked_features = actual_features;
444 		mutex_unlock(&vq->mutex);
445 	}
446 
447 	return 0;
448 }
449 
450 static long vhost_vdpa_get_vring_num(struct vhost_vdpa *v, u16 __user *argp)
451 {
452 	struct vdpa_device *vdpa = v->vdpa;
453 	const struct vdpa_config_ops *ops = vdpa->config;
454 	u16 num;
455 
456 	num = ops->get_vq_num_max(vdpa);
457 
458 	if (copy_to_user(argp, &num, sizeof(num)))
459 		return -EFAULT;
460 
461 	return 0;
462 }
463 
464 static void vhost_vdpa_config_put(struct vhost_vdpa *v)
465 {
466 	if (v->config_ctx) {
467 		eventfd_ctx_put(v->config_ctx);
468 		v->config_ctx = NULL;
469 	}
470 }
471 
472 static long vhost_vdpa_set_config_call(struct vhost_vdpa *v, u32 __user *argp)
473 {
474 	struct vdpa_callback cb;
475 	int fd;
476 	struct eventfd_ctx *ctx;
477 
478 	cb.callback = vhost_vdpa_config_cb;
479 	cb.private = v;
480 	if (copy_from_user(&fd, argp, sizeof(fd)))
481 		return  -EFAULT;
482 
483 	ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd);
484 	swap(ctx, v->config_ctx);
485 
486 	if (!IS_ERR_OR_NULL(ctx))
487 		eventfd_ctx_put(ctx);
488 
489 	if (IS_ERR(v->config_ctx)) {
490 		long ret = PTR_ERR(v->config_ctx);
491 
492 		v->config_ctx = NULL;
493 		return ret;
494 	}
495 
496 	v->vdpa->config->set_config_cb(v->vdpa, &cb);
497 
498 	return 0;
499 }
500 
501 static long vhost_vdpa_get_iova_range(struct vhost_vdpa *v, u32 __user *argp)
502 {
503 	struct vhost_vdpa_iova_range range = {
504 		.first = v->range.first,
505 		.last = v->range.last,
506 	};
507 
508 	if (copy_to_user(argp, &range, sizeof(range)))
509 		return -EFAULT;
510 	return 0;
511 }
512 
513 static long vhost_vdpa_get_config_size(struct vhost_vdpa *v, u32 __user *argp)
514 {
515 	struct vdpa_device *vdpa = v->vdpa;
516 	const struct vdpa_config_ops *ops = vdpa->config;
517 	u32 size;
518 
519 	size = ops->get_config_size(vdpa);
520 
521 	if (copy_to_user(argp, &size, sizeof(size)))
522 		return -EFAULT;
523 
524 	return 0;
525 }
526 
527 static long vhost_vdpa_get_vqs_count(struct vhost_vdpa *v, u32 __user *argp)
528 {
529 	struct vdpa_device *vdpa = v->vdpa;
530 
531 	if (copy_to_user(argp, &vdpa->nvqs, sizeof(vdpa->nvqs)))
532 		return -EFAULT;
533 
534 	return 0;
535 }
536 
537 /* After a successful return of this ioctl the device must not process more
538  * virtqueue descriptors. The device can answer reads or writes of config
539  * fields as if it were not suspended. In particular, writing to "queue_enable"
540  * with a value of 1 will not make the device start processing buffers.
541  */
542 static long vhost_vdpa_suspend(struct vhost_vdpa *v)
543 {
544 	struct vdpa_device *vdpa = v->vdpa;
545 	const struct vdpa_config_ops *ops = vdpa->config;
546 
547 	if (!ops->suspend)
548 		return -EOPNOTSUPP;
549 
550 	return ops->suspend(vdpa);
551 }
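
/*
 * Illustrative userspace sketch (an assumption about expected usage, not
 * driver code): before issuing VHOST_VDPA_SUSPEND, userspace should check
 * that the backend advertises VHOST_BACKEND_F_SUSPEND ("fd" is a
 * hypothetical open /dev/vhost-vdpa-N descriptor):
 *
 *	__u64 features;
 *
 *	if (!ioctl(fd, VHOST_GET_BACKEND_FEATURES, &features) &&
 *	    (features & (1ULL << VHOST_BACKEND_F_SUSPEND)))
 *		ioctl(fd, VHOST_VDPA_SUSPEND);
 */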
552 
553 /* After a successful return of this ioctl the device resumes processing
554  * virtqueue descriptors. The device becomes fully operational again, just as
555  * it was before it was suspended.
556  */
557 static long vhost_vdpa_resume(struct vhost_vdpa *v)
558 {
559 	struct vdpa_device *vdpa = v->vdpa;
560 	const struct vdpa_config_ops *ops = vdpa->config;
561 
562 	if (!ops->resume)
563 		return -EOPNOTSUPP;
564 
565 	return ops->resume(vdpa);
566 }
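
/*
 * Illustrative sketch under the same assumptions as above: a suspended
 * device is brought back with VHOST_VDPA_RESUME, gated on the backend
 * advertising VHOST_BACKEND_F_RESUME:
 *
 *	if (features & (1ULL << VHOST_BACKEND_F_RESUME))
 *		ioctl(fd, VHOST_VDPA_RESUME);
 */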
567 
568 static long vhost_vdpa_vring_ioctl(struct vhost_vdpa *v, unsigned int cmd,
569 				   void __user *argp)
570 {
571 	struct vdpa_device *vdpa = v->vdpa;
572 	const struct vdpa_config_ops *ops = vdpa->config;
573 	struct vdpa_vq_state vq_state;
574 	struct vdpa_callback cb;
575 	struct vhost_virtqueue *vq;
576 	struct vhost_vring_state s;
577 	u32 idx;
578 	long r;
579 
580 	r = get_user(idx, (u32 __user *)argp);
581 	if (r < 0)
582 		return r;
583 
584 	if (idx >= v->nvqs)
585 		return -ENOBUFS;
586 
587 	idx = array_index_nospec(idx, v->nvqs);
588 	vq = &v->vqs[idx];
589 
590 	switch (cmd) {
591 	case VHOST_VDPA_SET_VRING_ENABLE:
592 		if (copy_from_user(&s, argp, sizeof(s)))
593 			return -EFAULT;
594 		ops->set_vq_ready(vdpa, idx, s.num);
595 		return 0;
596 	case VHOST_VDPA_GET_VRING_GROUP:
597 		if (!ops->get_vq_group)
598 			return -EOPNOTSUPP;
599 		s.index = idx;
600 		s.num = ops->get_vq_group(vdpa, idx);
601 		if (s.num >= vdpa->ngroups)
602 			return -EIO;
603 		else if (copy_to_user(argp, &s, sizeof(s)))
604 			return -EFAULT;
605 		return 0;
606 	case VHOST_VDPA_SET_GROUP_ASID:
607 		if (copy_from_user(&s, argp, sizeof(s)))
608 			return -EFAULT;
609 		if (s.num >= vdpa->nas)
610 			return -EINVAL;
611 		if (!ops->set_group_asid)
612 			return -EOPNOTSUPP;
613 		return ops->set_group_asid(vdpa, idx, s.num);
614 	case VHOST_GET_VRING_BASE:
615 		r = ops->get_vq_state(v->vdpa, idx, &vq_state);
616 		if (r)
617 			return r;
618 
619 		if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) {
620 			vq->last_avail_idx = vq_state.packed.last_avail_idx |
621 					     (vq_state.packed.last_avail_counter << 15);
622 			vq->last_used_idx = vq_state.packed.last_used_idx |
623 					    (vq_state.packed.last_used_counter << 15);
624 		} else {
625 			vq->last_avail_idx = vq_state.split.avail_index;
626 		}
627 		break;
628 	case VHOST_SET_VRING_CALL:
629 		if (vq->call_ctx.ctx) {
630 			if (ops->get_status(vdpa) &
631 			    VIRTIO_CONFIG_S_DRIVER_OK)
632 				vhost_vdpa_unsetup_vq_irq(v, idx);
633 			vq->call_ctx.producer.token = NULL;
634 		}
635 		break;
636 	}
637 
638 	r = vhost_vring_ioctl(&v->vdev, cmd, argp);
639 	if (r)
640 		return r;
641 
642 	switch (cmd) {
643 	case VHOST_SET_VRING_ADDR:
644 		if (ops->set_vq_address(vdpa, idx,
645 					(u64)(uintptr_t)vq->desc,
646 					(u64)(uintptr_t)vq->avail,
647 					(u64)(uintptr_t)vq->used))
648 			r = -EINVAL;
649 		break;
650 
651 	case VHOST_SET_VRING_BASE:
652 		if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) {
653 			vq_state.packed.last_avail_idx = vq->last_avail_idx & 0x7fff;
654 			vq_state.packed.last_avail_counter = !!(vq->last_avail_idx & 0x8000);
655 			vq_state.packed.last_used_idx = vq->last_used_idx & 0x7fff;
656 			vq_state.packed.last_used_counter = !!(vq->last_used_idx & 0x8000);
657 		} else {
658 			vq_state.split.avail_index = vq->last_avail_idx;
659 		}
660 		r = ops->set_vq_state(vdpa, idx, &vq_state);
661 		break;
662 
663 	case VHOST_SET_VRING_CALL:
664 		if (vq->call_ctx.ctx) {
665 			cb.callback = vhost_vdpa_virtqueue_cb;
666 			cb.private = vq;
667 			cb.trigger = vq->call_ctx.ctx;
668 			vq->call_ctx.producer.token = vq->call_ctx.ctx;
669 			if (ops->get_status(vdpa) &
670 			    VIRTIO_CONFIG_S_DRIVER_OK)
671 				vhost_vdpa_setup_vq_irq(v, idx);
672 		} else {
673 			cb.callback = NULL;
674 			cb.private = NULL;
675 			cb.trigger = NULL;
676 		}
677 		ops->set_vq_cb(vdpa, idx, &cb);
678 		break;
679 
680 	case VHOST_SET_VRING_NUM:
681 		ops->set_vq_num(vdpa, idx, vq->num);
682 		break;
683 	}
684 
685 	return r;
686 }
687 
688 static long vhost_vdpa_unlocked_ioctl(struct file *filep,
689 				      unsigned int cmd, unsigned long arg)
690 {
691 	struct vhost_vdpa *v = filep->private_data;
692 	struct vhost_dev *d = &v->vdev;
693 	void __user *argp = (void __user *)arg;
694 	u64 __user *featurep = argp;
695 	u64 features;
696 	long r = 0;
697 
698 	if (cmd == VHOST_SET_BACKEND_FEATURES) {
699 		if (copy_from_user(&features, featurep, sizeof(features)))
700 			return -EFAULT;
701 		if (features & ~(VHOST_VDPA_BACKEND_FEATURES |
702 				 BIT_ULL(VHOST_BACKEND_F_SUSPEND) |
703 				 BIT_ULL(VHOST_BACKEND_F_RESUME) |
704 				 BIT_ULL(VHOST_BACKEND_F_ENABLE_AFTER_DRIVER_OK)))
705 			return -EOPNOTSUPP;
706 		if ((features & BIT_ULL(VHOST_BACKEND_F_SUSPEND)) &&
707 		     !vhost_vdpa_can_suspend(v))
708 			return -EOPNOTSUPP;
709 		if ((features & BIT_ULL(VHOST_BACKEND_F_RESUME)) &&
710 		     !vhost_vdpa_can_resume(v))
711 			return -EOPNOTSUPP;
712 		vhost_set_backend_features(&v->vdev, features);
713 		return 0;
714 	}
715 
716 	mutex_lock(&d->mutex);
717 
718 	switch (cmd) {
719 	case VHOST_VDPA_GET_DEVICE_ID:
720 		r = vhost_vdpa_get_device_id(v, argp);
721 		break;
722 	case VHOST_VDPA_GET_STATUS:
723 		r = vhost_vdpa_get_status(v, argp);
724 		break;
725 	case VHOST_VDPA_SET_STATUS:
726 		r = vhost_vdpa_set_status(v, argp);
727 		break;
728 	case VHOST_VDPA_GET_CONFIG:
729 		r = vhost_vdpa_get_config(v, argp);
730 		break;
731 	case VHOST_VDPA_SET_CONFIG:
732 		r = vhost_vdpa_set_config(v, argp);
733 		break;
734 	case VHOST_GET_FEATURES:
735 		r = vhost_vdpa_get_features(v, argp);
736 		break;
737 	case VHOST_SET_FEATURES:
738 		r = vhost_vdpa_set_features(v, argp);
739 		break;
740 	case VHOST_VDPA_GET_VRING_NUM:
741 		r = vhost_vdpa_get_vring_num(v, argp);
742 		break;
743 	case VHOST_VDPA_GET_GROUP_NUM:
744 		if (copy_to_user(argp, &v->vdpa->ngroups,
745 				 sizeof(v->vdpa->ngroups)))
746 			r = -EFAULT;
747 		break;
748 	case VHOST_VDPA_GET_AS_NUM:
749 		if (copy_to_user(argp, &v->vdpa->nas, sizeof(v->vdpa->nas)))
750 			r = -EFAULT;
751 		break;
752 	case VHOST_SET_LOG_BASE:
753 	case VHOST_SET_LOG_FD:
754 		r = -ENOIOCTLCMD;
755 		break;
756 	case VHOST_VDPA_SET_CONFIG_CALL:
757 		r = vhost_vdpa_set_config_call(v, argp);
758 		break;
759 	case VHOST_GET_BACKEND_FEATURES:
760 		features = VHOST_VDPA_BACKEND_FEATURES;
761 		if (vhost_vdpa_can_suspend(v))
762 			features |= BIT_ULL(VHOST_BACKEND_F_SUSPEND);
763 		if (vhost_vdpa_can_resume(v))
764 			features |= BIT_ULL(VHOST_BACKEND_F_RESUME);
765 		features |= vhost_vdpa_get_backend_features(v);
766 		if (copy_to_user(featurep, &features, sizeof(features)))
767 			r = -EFAULT;
768 		break;
769 	case VHOST_VDPA_GET_IOVA_RANGE:
770 		r = vhost_vdpa_get_iova_range(v, argp);
771 		break;
772 	case VHOST_VDPA_GET_CONFIG_SIZE:
773 		r = vhost_vdpa_get_config_size(v, argp);
774 		break;
775 	case VHOST_VDPA_GET_VQS_COUNT:
776 		r = vhost_vdpa_get_vqs_count(v, argp);
777 		break;
778 	case VHOST_VDPA_SUSPEND:
779 		r = vhost_vdpa_suspend(v);
780 		break;
781 	case VHOST_VDPA_RESUME:
782 		r = vhost_vdpa_resume(v);
783 		break;
784 	default:
785 		r = vhost_dev_ioctl(&v->vdev, cmd, argp);
786 		if (r == -ENOIOCTLCMD)
787 			r = vhost_vdpa_vring_ioctl(v, cmd, argp);
788 		break;
789 	}
790 
791 	if (r)
792 		goto out;
793 
794 	switch (cmd) {
795 	case VHOST_SET_OWNER:
796 		r = vhost_vdpa_bind_mm(v);
797 		if (r)
798 			vhost_dev_reset_owner(d, NULL);
799 		break;
800 	}
801 out:
802 	mutex_unlock(&d->mutex);
803 	return r;
804 }
805 static void vhost_vdpa_general_unmap(struct vhost_vdpa *v,
806 				     struct vhost_iotlb_map *map, u32 asid)
807 {
808 	struct vdpa_device *vdpa = v->vdpa;
809 	const struct vdpa_config_ops *ops = vdpa->config;
810 	if (ops->dma_map) {
811 		ops->dma_unmap(vdpa, asid, map->start, map->size);
812 	} else if (ops->set_map == NULL) {
813 		iommu_unmap(v->domain, map->start, map->size);
814 	}
815 }
816 
817 static void vhost_vdpa_pa_unmap(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
818 				u64 start, u64 last, u32 asid)
819 {
820 	struct vhost_dev *dev = &v->vdev;
821 	struct vhost_iotlb_map *map;
822 	struct page *page;
823 	unsigned long pfn, pinned;
824 
825 	while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
826 		pinned = PFN_DOWN(map->size);
827 		for (pfn = PFN_DOWN(map->addr);
828 		     pinned > 0; pfn++, pinned--) {
829 			page = pfn_to_page(pfn);
830 			if (map->perm & VHOST_ACCESS_WO)
831 				set_page_dirty_lock(page);
832 			unpin_user_page(page);
833 		}
834 		atomic64_sub(PFN_DOWN(map->size), &dev->mm->pinned_vm);
835 		vhost_vdpa_general_unmap(v, map, asid);
836 		vhost_iotlb_map_free(iotlb, map);
837 	}
838 }
839 
840 static void vhost_vdpa_va_unmap(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
841 				u64 start, u64 last, u32 asid)
842 {
843 	struct vhost_iotlb_map *map;
844 	struct vdpa_map_file *map_file;
845 
846 	while ((map = vhost_iotlb_itree_first(iotlb, start, last)) != NULL) {
847 		map_file = (struct vdpa_map_file *)map->opaque;
848 		fput(map_file->file);
849 		kfree(map_file);
850 		vhost_vdpa_general_unmap(v, map, asid);
851 		vhost_iotlb_map_free(iotlb, map);
852 	}
853 }
854 
855 static void vhost_vdpa_iotlb_unmap(struct vhost_vdpa *v,
856 				   struct vhost_iotlb *iotlb, u64 start,
857 				   u64 last, u32 asid)
858 {
859 	struct vdpa_device *vdpa = v->vdpa;
860 
861 	if (vdpa->use_va)
862 		return vhost_vdpa_va_unmap(v, iotlb, start, last, asid);
863 
864 	return vhost_vdpa_pa_unmap(v, iotlb, start, last, asid);
865 }
866 
867 static int perm_to_iommu_flags(u32 perm)
868 {
869 	int flags = 0;
870 
871 	switch (perm) {
872 	case VHOST_ACCESS_WO:
873 		flags |= IOMMU_WRITE;
874 		break;
875 	case VHOST_ACCESS_RO:
876 		flags |= IOMMU_READ;
877 		break;
878 	case VHOST_ACCESS_RW:
879 		flags |= (IOMMU_WRITE | IOMMU_READ);
880 		break;
881 	default:
882 		WARN(1, "invalid vhost IOTLB permission\n");
883 		break;
884 	}
885 
886 	return flags | IOMMU_CACHE;
887 }
888 
889 static int vhost_vdpa_map(struct vhost_vdpa *v, struct vhost_iotlb *iotlb,
890 			  u64 iova, u64 size, u64 pa, u32 perm, void *opaque)
891 {
892 	struct vhost_dev *dev = &v->vdev;
893 	struct vdpa_device *vdpa = v->vdpa;
894 	const struct vdpa_config_ops *ops = vdpa->config;
895 	u32 asid = iotlb_to_asid(iotlb);
896 	int r = 0;
897 
898 	r = vhost_iotlb_add_range_ctx(iotlb, iova, iova + size - 1,
899 				      pa, perm, opaque);
900 	if (r)
901 		return r;
902 
903 	if (ops->dma_map) {
904 		r = ops->dma_map(vdpa, asid, iova, size, pa, perm, opaque);
905 	} else if (ops->set_map) {
906 		if (!v->in_batch)
907 			r = ops->set_map(vdpa, asid, iotlb);
908 	} else {
909 		r = iommu_map(v->domain, iova, pa, size,
910 			      perm_to_iommu_flags(perm), GFP_KERNEL);
911 	}
912 	if (r) {
913 		vhost_iotlb_del_range(iotlb, iova, iova + size - 1);
914 		return r;
915 	}
916 
917 	if (!vdpa->use_va)
918 		atomic64_add(PFN_DOWN(size), &dev->mm->pinned_vm);
919 
920 	return 0;
921 }
922 
923 static void vhost_vdpa_unmap(struct vhost_vdpa *v,
924 			     struct vhost_iotlb *iotlb,
925 			     u64 iova, u64 size)
926 {
927 	struct vdpa_device *vdpa = v->vdpa;
928 	const struct vdpa_config_ops *ops = vdpa->config;
929 	u32 asid = iotlb_to_asid(iotlb);
930 
931 	vhost_vdpa_iotlb_unmap(v, iotlb, iova, iova + size - 1, asid);
932 
933 	if (ops->set_map) {
934 		if (!v->in_batch)
935 			ops->set_map(vdpa, asid, iotlb);
936 	}
937 
938 }
939 
940 static int vhost_vdpa_va_map(struct vhost_vdpa *v,
941 			     struct vhost_iotlb *iotlb,
942 			     u64 iova, u64 size, u64 uaddr, u32 perm)
943 {
944 	struct vhost_dev *dev = &v->vdev;
945 	u64 offset, map_size, map_iova = iova;
946 	struct vdpa_map_file *map_file;
947 	struct vm_area_struct *vma;
948 	int ret = 0;
949 
950 	mmap_read_lock(dev->mm);
951 
952 	while (size) {
953 		vma = find_vma(dev->mm, uaddr);
954 		if (!vma) {
955 			ret = -EINVAL;
956 			break;
957 		}
958 		map_size = min(size, vma->vm_end - uaddr);
959 		if (!(vma->vm_file && (vma->vm_flags & VM_SHARED) &&
960 			!(vma->vm_flags & (VM_IO | VM_PFNMAP))))
961 			goto next;
962 
963 		map_file = kzalloc(sizeof(*map_file), GFP_KERNEL);
964 		if (!map_file) {
965 			ret = -ENOMEM;
966 			break;
967 		}
968 		offset = (vma->vm_pgoff << PAGE_SHIFT) + uaddr - vma->vm_start;
969 		map_file->offset = offset;
970 		map_file->file = get_file(vma->vm_file);
971 		ret = vhost_vdpa_map(v, iotlb, map_iova, map_size, uaddr,
972 				     perm, map_file);
973 		if (ret) {
974 			fput(map_file->file);
975 			kfree(map_file);
976 			break;
977 		}
978 next:
979 		size -= map_size;
980 		uaddr += map_size;
981 		map_iova += map_size;
982 	}
983 	if (ret)
984 		vhost_vdpa_unmap(v, iotlb, iova, map_iova - iova);
985 
986 	mmap_read_unlock(dev->mm);
987 
988 	return ret;
989 }
990 
991 static int vhost_vdpa_pa_map(struct vhost_vdpa *v,
992 			     struct vhost_iotlb *iotlb,
993 			     u64 iova, u64 size, u64 uaddr, u32 perm)
994 {
995 	struct vhost_dev *dev = &v->vdev;
996 	struct page **page_list;
997 	unsigned long list_size = PAGE_SIZE / sizeof(struct page *);
998 	unsigned int gup_flags = FOLL_LONGTERM;
999 	unsigned long npages, cur_base, map_pfn, last_pfn = 0;
1000 	unsigned long lock_limit, sz2pin, nchunks, i;
1001 	u64 start = iova;
1002 	long pinned;
1003 	int ret = 0;
1004 
1005 	/* Limit the use of memory for bookkeeping */
1006 	page_list = (struct page **) __get_free_page(GFP_KERNEL);
1007 	if (!page_list)
1008 		return -ENOMEM;
1009 
1010 	if (perm & VHOST_ACCESS_WO)
1011 		gup_flags |= FOLL_WRITE;
1012 
1013 	npages = PFN_UP(size + (iova & ~PAGE_MASK));
1014 	if (!npages) {
1015 		ret = -EINVAL;
1016 		goto free;
1017 	}
1018 
1019 	mmap_read_lock(dev->mm);
1020 
1021 	lock_limit = PFN_DOWN(rlimit(RLIMIT_MEMLOCK));
1022 	if (npages + atomic64_read(&dev->mm->pinned_vm) > lock_limit) {
1023 		ret = -ENOMEM;
1024 		goto unlock;
1025 	}
1026 
1027 	cur_base = uaddr & PAGE_MASK;
1028 	iova &= PAGE_MASK;
1029 	nchunks = 0;
1030 
1031 	while (npages) {
1032 		sz2pin = min_t(unsigned long, npages, list_size);
1033 		pinned = pin_user_pages(cur_base, sz2pin,
1034 					gup_flags, page_list);
1035 		if (sz2pin != pinned) {
1036 			if (pinned < 0) {
1037 				ret = pinned;
1038 			} else {
1039 				unpin_user_pages(page_list, pinned);
1040 				ret = -ENOMEM;
1041 			}
1042 			goto out;
1043 		}
1044 		nchunks++;
1045 
1046 		if (!last_pfn)
1047 			map_pfn = page_to_pfn(page_list[0]);
1048 
1049 		for (i = 0; i < pinned; i++) {
1050 			unsigned long this_pfn = page_to_pfn(page_list[i]);
1051 			u64 csize;
1052 
1053 			if (last_pfn && (this_pfn != last_pfn + 1)) {
1054 				/* Map the contiguous chunk pinned so far */
1055 				csize = PFN_PHYS(last_pfn - map_pfn + 1);
1056 				ret = vhost_vdpa_map(v, iotlb, iova, csize,
1057 						     PFN_PHYS(map_pfn),
1058 						     perm, NULL);
1059 				if (ret) {
1060 					/*
1061 					 * Unpin the pages that are left unmapped
1062 					 * from this point on in the current
1063 					 * page_list. The remaining outstanding
1064 					 * ones which may stride across several
1065 					 * chunks will be covered in the common
1066 					 * error path subsequently.
1067 					 */
1068 					unpin_user_pages(&page_list[i],
1069 							 pinned - i);
1070 					goto out;
1071 				}
1072 
1073 				map_pfn = this_pfn;
1074 				iova += csize;
1075 				nchunks = 0;
1076 			}
1077 
1078 			last_pfn = this_pfn;
1079 		}
1080 
1081 		cur_base += PFN_PHYS(pinned);
1082 		npages -= pinned;
1083 	}
1084 
1085 	/* Map the remaining chunk */
1086 	ret = vhost_vdpa_map(v, iotlb, iova, PFN_PHYS(last_pfn - map_pfn + 1),
1087 			     PFN_PHYS(map_pfn), perm, NULL);
1088 out:
1089 	if (ret) {
1090 		if (nchunks) {
1091 			unsigned long pfn;
1092 
1093 			/*
1094 			 * Unpin the outstanding pages that should have been
1095 			 * mapped but were not, due to vdpa_map() or
1096 			 * pin_user_pages() failure.
1097 			 *
1098 			 * Mapped pages are accounted in vdpa_map(), hence
1099 			 * the corresponding unpinning will be handled by
1100 			 * vdpa_unmap().
1101 			 */
1102 			WARN_ON(!last_pfn);
1103 			for (pfn = map_pfn; pfn <= last_pfn; pfn++)
1104 				unpin_user_page(pfn_to_page(pfn));
1105 		}
1106 		vhost_vdpa_unmap(v, iotlb, start, size);
1107 	}
1108 unlock:
1109 	mmap_read_unlock(dev->mm);
1110 free:
1111 	free_page((unsigned long)page_list);
1112 	return ret;
1113 
1114 }
1115 
1116 static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v,
1117 					   struct vhost_iotlb *iotlb,
1118 					   struct vhost_iotlb_msg *msg)
1119 {
1120 	struct vdpa_device *vdpa = v->vdpa;
1121 
1122 	if (msg->iova < v->range.first || !msg->size ||
1123 	    msg->iova > U64_MAX - msg->size + 1 ||
1124 	    msg->iova + msg->size - 1 > v->range.last)
1125 		return -EINVAL;
1126 
1127 	if (vhost_iotlb_itree_first(iotlb, msg->iova,
1128 				    msg->iova + msg->size - 1))
1129 		return -EEXIST;
1130 
1131 	if (vdpa->use_va)
1132 		return vhost_vdpa_va_map(v, iotlb, msg->iova, msg->size,
1133 					 msg->uaddr, msg->perm);
1134 
1135 	return vhost_vdpa_pa_map(v, iotlb, msg->iova, msg->size, msg->uaddr,
1136 				 msg->perm);
1137 }
1138 
1139 static int vhost_vdpa_process_iotlb_msg(struct vhost_dev *dev, u32 asid,
1140 					struct vhost_iotlb_msg *msg)
1141 {
1142 	struct vhost_vdpa *v = container_of(dev, struct vhost_vdpa, vdev);
1143 	struct vdpa_device *vdpa = v->vdpa;
1144 	const struct vdpa_config_ops *ops = vdpa->config;
1145 	struct vhost_iotlb *iotlb = NULL;
1146 	struct vhost_vdpa_as *as = NULL;
1147 	int r = 0;
1148 
1149 	mutex_lock(&dev->mutex);
1150 
1151 	r = vhost_dev_check_owner(dev);
1152 	if (r)
1153 		goto unlock;
1154 
1155 	if (msg->type == VHOST_IOTLB_UPDATE ||
1156 	    msg->type == VHOST_IOTLB_BATCH_BEGIN) {
1157 		as = vhost_vdpa_find_alloc_as(v, asid);
1158 		if (!as) {
1159 			dev_err(&v->dev, "can't find and alloc asid %d\n",
1160 				asid);
1161 			r = -EINVAL;
1162 			goto unlock;
1163 		}
1164 		iotlb = &as->iotlb;
1165 	} else
1166 		iotlb = asid_to_iotlb(v, asid);
1167 
1168 	if ((v->in_batch && v->batch_asid != asid) || !iotlb) {
1169 		if (v->in_batch && v->batch_asid != asid) {
1170 			dev_info(&v->dev, "batch id %d asid %d\n",
1171 				 v->batch_asid, asid);
1172 		}
1173 		if (!iotlb)
1174 			dev_err(&v->dev, "no iotlb for asid %d\n", asid);
1175 		r = -EINVAL;
1176 		goto unlock;
1177 	}
1178 
1179 	switch (msg->type) {
1180 	case VHOST_IOTLB_UPDATE:
1181 		r = vhost_vdpa_process_iotlb_update(v, iotlb, msg);
1182 		break;
1183 	case VHOST_IOTLB_INVALIDATE:
1184 		vhost_vdpa_unmap(v, iotlb, msg->iova, msg->size);
1185 		break;
1186 	case VHOST_IOTLB_BATCH_BEGIN:
1187 		v->batch_asid = asid;
1188 		v->in_batch = true;
1189 		break;
1190 	case VHOST_IOTLB_BATCH_END:
1191 		if (v->in_batch && ops->set_map)
1192 			ops->set_map(vdpa, asid, iotlb);
1193 		v->in_batch = false;
1194 		break;
1195 	default:
1196 		r = -EINVAL;
1197 		break;
1198 	}
1199 unlock:
1200 	mutex_unlock(&dev->mutex);
1201 
1202 	return r;
1203 }
1204 
1205 static ssize_t vhost_vdpa_chr_write_iter(struct kiocb *iocb,
1206 					 struct iov_iter *from)
1207 {
1208 	struct file *file = iocb->ki_filp;
1209 	struct vhost_vdpa *v = file->private_data;
1210 	struct vhost_dev *dev = &v->vdev;
1211 
1212 	return vhost_chr_write_iter(dev, from);
1213 }
1214 
1215 static int vhost_vdpa_alloc_domain(struct vhost_vdpa *v)
1216 {
1217 	struct vdpa_device *vdpa = v->vdpa;
1218 	const struct vdpa_config_ops *ops = vdpa->config;
1219 	struct device *dma_dev = vdpa_get_dma_dev(vdpa);
1220 	const struct bus_type *bus;
1221 	int ret;
1222 
1223 	/* The device wants to do DMA by itself */
1224 	if (ops->set_map || ops->dma_map)
1225 		return 0;
1226 
1227 	bus = dma_dev->bus;
1228 	if (!bus)
1229 		return -EFAULT;
1230 
1231 	if (!device_iommu_capable(dma_dev, IOMMU_CAP_CACHE_COHERENCY)) {
1232 		dev_warn_once(&v->dev,
1233 			      "Failed to allocate domain, device is not IOMMU cache coherent capable\n");
1234 		return -ENOTSUPP;
1235 	}
1236 
1237 	v->domain = iommu_domain_alloc(bus);
1238 	if (!v->domain)
1239 		return -EIO;
1240 
1241 	ret = iommu_attach_device(v->domain, dma_dev);
1242 	if (ret)
1243 		goto err_attach;
1244 
1245 	return 0;
1246 
1247 err_attach:
1248 	iommu_domain_free(v->domain);
1249 	v->domain = NULL;
1250 	return ret;
1251 }
1252 
1253 static void vhost_vdpa_free_domain(struct vhost_vdpa *v)
1254 {
1255 	struct vdpa_device *vdpa = v->vdpa;
1256 	struct device *dma_dev = vdpa_get_dma_dev(vdpa);
1257 
1258 	if (v->domain) {
1259 		iommu_detach_device(v->domain, dma_dev);
1260 		iommu_domain_free(v->domain);
1261 	}
1262 
1263 	v->domain = NULL;
1264 }
1265 
1266 static void vhost_vdpa_set_iova_range(struct vhost_vdpa *v)
1267 {
1268 	struct vdpa_iova_range *range = &v->range;
1269 	struct vdpa_device *vdpa = v->vdpa;
1270 	const struct vdpa_config_ops *ops = vdpa->config;
1271 
1272 	if (ops->get_iova_range) {
1273 		*range = ops->get_iova_range(vdpa);
1274 	} else if (v->domain && v->domain->geometry.force_aperture) {
1275 		range->first = v->domain->geometry.aperture_start;
1276 		range->last = v->domain->geometry.aperture_end;
1277 	} else {
1278 		range->first = 0;
1279 		range->last = ULLONG_MAX;
1280 	}
1281 }
1282 
1283 static void vhost_vdpa_cleanup(struct vhost_vdpa *v)
1284 {
1285 	struct vhost_vdpa_as *as;
1286 	u32 asid;
1287 
1288 	for (asid = 0; asid < v->vdpa->nas; asid++) {
1289 		as = asid_to_as(v, asid);
1290 		if (as)
1291 			vhost_vdpa_remove_as(v, asid);
1292 	}
1293 
1294 	vhost_vdpa_free_domain(v);
1295 	vhost_dev_cleanup(&v->vdev);
1296 	kfree(v->vdev.vqs);
1297 }
1298 
1299 static int vhost_vdpa_open(struct inode *inode, struct file *filep)
1300 {
1301 	struct vhost_vdpa *v;
1302 	struct vhost_dev *dev;
1303 	struct vhost_virtqueue **vqs;
1304 	int r, opened;
1305 	u32 i, nvqs;
1306 
1307 	v = container_of(inode->i_cdev, struct vhost_vdpa, cdev);
1308 
1309 	opened = atomic_cmpxchg(&v->opened, 0, 1);
1310 	if (opened)
1311 		return -EBUSY;
1312 
1313 	nvqs = v->nvqs;
1314 	r = vhost_vdpa_reset(v);
1315 	if (r)
1316 		goto err;
1317 
1318 	vqs = kmalloc_array(nvqs, sizeof(*vqs), GFP_KERNEL);
1319 	if (!vqs) {
1320 		r = -ENOMEM;
1321 		goto err;
1322 	}
1323 
1324 	dev = &v->vdev;
1325 	for (i = 0; i < nvqs; i++) {
1326 		vqs[i] = &v->vqs[i];
1327 		vqs[i]->handle_kick = handle_vq_kick;
1328 		vqs[i]->call_ctx.ctx = NULL;
1329 	}
1330 	vhost_dev_init(dev, vqs, nvqs, 0, 0, 0, false,
1331 		       vhost_vdpa_process_iotlb_msg);
1332 
1333 	r = vhost_vdpa_alloc_domain(v);
1334 	if (r)
1335 		goto err_alloc_domain;
1336 
1337 	vhost_vdpa_set_iova_range(v);
1338 
1339 	filep->private_data = v;
1340 
1341 	return 0;
1342 
1343 err_alloc_domain:
1344 	vhost_vdpa_cleanup(v);
1345 err:
1346 	atomic_dec(&v->opened);
1347 	return r;
1348 }
1349 
1350 static void vhost_vdpa_clean_irq(struct vhost_vdpa *v)
1351 {
1352 	u32 i;
1353 
1354 	for (i = 0; i < v->nvqs; i++)
1355 		vhost_vdpa_unsetup_vq_irq(v, i);
1356 }
1357 
1358 static int vhost_vdpa_release(struct inode *inode, struct file *filep)
1359 {
1360 	struct vhost_vdpa *v = filep->private_data;
1361 	struct vhost_dev *d = &v->vdev;
1362 
1363 	mutex_lock(&d->mutex);
1364 	filep->private_data = NULL;
1365 	vhost_vdpa_clean_irq(v);
1366 	vhost_vdpa_reset(v);
1367 	vhost_dev_stop(&v->vdev);
1368 	vhost_vdpa_unbind_mm(v);
1369 	vhost_vdpa_config_put(v);
1370 	vhost_vdpa_cleanup(v);
1371 	mutex_unlock(&d->mutex);
1372 
1373 	atomic_dec(&v->opened);
1374 	complete(&v->completion);
1375 
1376 	return 0;
1377 }
1378 
1379 #ifdef CONFIG_MMU
1380 static vm_fault_t vhost_vdpa_fault(struct vm_fault *vmf)
1381 {
1382 	struct vhost_vdpa *v = vmf->vma->vm_file->private_data;
1383 	struct vdpa_device *vdpa = v->vdpa;
1384 	const struct vdpa_config_ops *ops = vdpa->config;
1385 	struct vdpa_notification_area notify;
1386 	struct vm_area_struct *vma = vmf->vma;
1387 	u16 index = vma->vm_pgoff;
1388 
1389 	notify = ops->get_vq_notification(vdpa, index);
1390 
1391 	return vmf_insert_pfn(vma, vmf->address & PAGE_MASK, PFN_DOWN(notify.addr));
1392 }
1393 
1394 static const struct vm_operations_struct vhost_vdpa_vm_ops = {
1395 	.fault = vhost_vdpa_fault,
1396 };
1397 
1398 static int vhost_vdpa_mmap(struct file *file, struct vm_area_struct *vma)
1399 {
1400 	struct vhost_vdpa *v = vma->vm_file->private_data;
1401 	struct vdpa_device *vdpa = v->vdpa;
1402 	const struct vdpa_config_ops *ops = vdpa->config;
1403 	struct vdpa_notification_area notify;
1404 	unsigned long index = vma->vm_pgoff;
1405 
1406 	if (vma->vm_end - vma->vm_start != PAGE_SIZE)
1407 		return -EINVAL;
1408 	if ((vma->vm_flags & VM_SHARED) == 0)
1409 		return -EINVAL;
1410 	if (vma->vm_flags & VM_READ)
1411 		return -EINVAL;
1412 	if (index > 65535)
1413 		return -EINVAL;
1414 	if (!ops->get_vq_notification)
1415 		return -ENOTSUPP;
1416 
1417 	/* To be safe and easily modelled by userspace, we only
1418 	 * support a doorbell that sits on a page boundary and
1419 	 * does not share the page with other registers.
1420 	 */
1421 	notify = ops->get_vq_notification(vdpa, index);
1422 	if (notify.addr & (PAGE_SIZE - 1))
1423 		return -EINVAL;
1424 	if (vma->vm_end - vma->vm_start != notify.size)
1425 		return -ENOTSUPP;
1426 
1427 	vm_flags_set(vma, VM_IO | VM_PFNMAP | VM_DONTEXPAND | VM_DONTDUMP);
1428 	vma->vm_ops = &vhost_vdpa_vm_ops;
1429 	return 0;
1430 }
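
/*
 * Illustrative userspace sketch (assumptions only, not driver code):
 * because the queue index is taken from vm_pgoff, the doorbell page of a
 * hypothetical virtqueue "qid" is mapped at offset qid * page_size, as a
 * single shared, write-only page ("fd" is an open /dev/vhost-vdpa-N
 * descriptor):
 *
 *	long page_size = sysconf(_SC_PAGESIZE);
 *	void *db = mmap(NULL, page_size, PROT_WRITE, MAP_SHARED,
 *			fd, qid * page_size);
 *
 * A store to the mapped page then notifies (kicks) virtqueue "qid"
 * directly, bypassing the kick eventfd path.
 */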
1431 #endif /* CONFIG_MMU */
1432 
1433 static const struct file_operations vhost_vdpa_fops = {
1434 	.owner		= THIS_MODULE,
1435 	.open		= vhost_vdpa_open,
1436 	.release	= vhost_vdpa_release,
1437 	.write_iter	= vhost_vdpa_chr_write_iter,
1438 	.unlocked_ioctl	= vhost_vdpa_unlocked_ioctl,
1439 #ifdef CONFIG_MMU
1440 	.mmap		= vhost_vdpa_mmap,
1441 #endif /* CONFIG_MMU */
1442 	.compat_ioctl	= compat_ptr_ioctl,
1443 };
1444 
1445 static void vhost_vdpa_release_dev(struct device *device)
1446 {
1447 	struct vhost_vdpa *v =
1448 	       container_of(device, struct vhost_vdpa, dev);
1449 
1450 	ida_simple_remove(&vhost_vdpa_ida, v->minor);
1451 	kfree(v->vqs);
1452 	kfree(v);
1453 }
1454 
1455 static int vhost_vdpa_probe(struct vdpa_device *vdpa)
1456 {
1457 	const struct vdpa_config_ops *ops = vdpa->config;
1458 	struct vhost_vdpa *v;
1459 	int minor;
1460 	int i, r;
1461 
1462 	/* We can't support a platform IOMMU device with more than one
1463 	 * group or address space.
1464 	 */
1465 	if (!ops->set_map && !ops->dma_map &&
1466 	    (vdpa->ngroups > 1 || vdpa->nas > 1))
1467 		return -EOPNOTSUPP;
1468 
1469 	v = kzalloc(sizeof(*v), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
1470 	if (!v)
1471 		return -ENOMEM;
1472 
1473 	minor = ida_simple_get(&vhost_vdpa_ida, 0,
1474 			       VHOST_VDPA_DEV_MAX, GFP_KERNEL);
1475 	if (minor < 0) {
1476 		kfree(v);
1477 		return minor;
1478 	}
1479 
1480 	atomic_set(&v->opened, 0);
1481 	v->minor = minor;
1482 	v->vdpa = vdpa;
1483 	v->nvqs = vdpa->nvqs;
1484 	v->virtio_id = ops->get_device_id(vdpa);
1485 
1486 	device_initialize(&v->dev);
1487 	v->dev.release = vhost_vdpa_release_dev;
1488 	v->dev.parent = &vdpa->dev;
1489 	v->dev.devt = MKDEV(MAJOR(vhost_vdpa_major), minor);
1490 	v->vqs = kmalloc_array(v->nvqs, sizeof(struct vhost_virtqueue),
1491 			       GFP_KERNEL);
1492 	if (!v->vqs) {
1493 		r = -ENOMEM;
1494 		goto err;
1495 	}
1496 
1497 	r = dev_set_name(&v->dev, "vhost-vdpa-%u", minor);
1498 	if (r)
1499 		goto err;
1500 
1501 	cdev_init(&v->cdev, &vhost_vdpa_fops);
1502 	v->cdev.owner = THIS_MODULE;
1503 
1504 	r = cdev_device_add(&v->cdev, &v->dev);
1505 	if (r)
1506 		goto err;
1507 
1508 	init_completion(&v->completion);
1509 	vdpa_set_drvdata(vdpa, v);
1510 
1511 	for (i = 0; i < VHOST_VDPA_IOTLB_BUCKETS; i++)
1512 		INIT_HLIST_HEAD(&v->as[i]);
1513 
1514 	return 0;
1515 
1516 err:
1517 	put_device(&v->dev);
1518 	return r;
1519 }
1520 
1521 static void vhost_vdpa_remove(struct vdpa_device *vdpa)
1522 {
1523 	struct vhost_vdpa *v = vdpa_get_drvdata(vdpa);
1524 	int opened;
1525 
1526 	cdev_device_del(&v->cdev, &v->dev);
1527 
1528 	do {
1529 		opened = atomic_cmpxchg(&v->opened, 0, 1);
1530 		if (!opened)
1531 			break;
1532 		wait_for_completion(&v->completion);
1533 	} while (1);
1534 
1535 	put_device(&v->dev);
1536 }
1537 
1538 static struct vdpa_driver vhost_vdpa_driver = {
1539 	.driver = {
1540 		.name	= "vhost_vdpa",
1541 	},
1542 	.probe	= vhost_vdpa_probe,
1543 	.remove	= vhost_vdpa_remove,
1544 };
1545 
1546 static int __init vhost_vdpa_init(void)
1547 {
1548 	int r;
1549 
1550 	r = alloc_chrdev_region(&vhost_vdpa_major, 0, VHOST_VDPA_DEV_MAX,
1551 				"vhost-vdpa");
1552 	if (r)
1553 		goto err_alloc_chrdev;
1554 
1555 	r = vdpa_register_driver(&vhost_vdpa_driver);
1556 	if (r)
1557 		goto err_vdpa_register_driver;
1558 
1559 	return 0;
1560 
1561 err_vdpa_register_driver:
1562 	unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
1563 err_alloc_chrdev:
1564 	return r;
1565 }
1566 module_init(vhost_vdpa_init);
1567 
1568 static void __exit vhost_vdpa_exit(void)
1569 {
1570 	vdpa_unregister_driver(&vhost_vdpa_driver);
1571 	unregister_chrdev_region(vhost_vdpa_major, VHOST_VDPA_DEV_MAX);
1572 }
1573 module_exit(vhost_vdpa_exit);
1574 
1575 MODULE_VERSION("0.0.1");
1576 MODULE_LICENSE("GPL v2");
1577 MODULE_AUTHOR("Intel Corporation");
1578 MODULE_DESCRIPTION("vDPA-based vhost backend for virtio");
1579