// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (C) 2009 Red Hat, Inc.
 * Copyright (C) 2006 Rusty Russell IBM Corporation
 *
 * Author: Michael S. Tsirkin <mst@redhat.com>
 *
 * Inspiration, some code, and most witty comments come from
 * Documentation/virtual/lguest/lguest.c, by Rusty Russell
 *
 * Generic code for virtio server in host kernel.
 */

#include <linux/eventfd.h>
#include <linux/vhost.h>
#include <linux/uio.h>
#include <linux/mm.h>
#include <linux/miscdevice.h>
#include <linux/mutex.h>
#include <linux/poll.h>
#include <linux/file.h>
#include <linux/highmem.h>
#include <linux/slab.h>
#include <linux/vmalloc.h>
#include <linux/kthread.h>
#include <linux/module.h>
#include <linux/sort.h>
#include <linux/sched/mm.h>
#include <linux/sched/signal.h>
#include <linux/sched/vhost_task.h>
#include <linux/interval_tree_generic.h>
#include <linux/nospec.h>
#include <linux/kcov.h>

#include "vhost.h"

static ushort max_mem_regions = 64;
module_param(max_mem_regions, ushort, 0444);
MODULE_PARM_DESC(max_mem_regions,
	"Maximum number of memory regions in memory map. (default: 64)");
static int max_iotlb_entries = 2048;
module_param(max_iotlb_entries, int, 0444);
MODULE_PARM_DESC(max_iotlb_entries,
	"Maximum number of iotlb entries. (default: 2048)");

enum {
	VHOST_MEMORY_F_LOG = 0x1,
};
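
/*
 * With VIRTIO_RING_F_EVENT_IDX the virtio spec places used_event at the
 * tail of the avail ring and avail_event at the tail of the used ring,
 * which is why the two macros below dereference the "opposite" ring.
 */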
#define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num])
#define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num])

#ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
{
	vq->user_be = !virtio_legacy_is_little_endian();
}

static void vhost_enable_cross_endian_big(struct vhost_virtqueue *vq)
{
	vq->user_be = true;
}

static void vhost_enable_cross_endian_little(struct vhost_virtqueue *vq)
{
	vq->user_be = false;
}

static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
{
	struct vhost_vring_state s;

	if (vq->private_data)
		return -EBUSY;

	if (copy_from_user(&s, argp, sizeof(s)))
		return -EFAULT;

	if (s.num != VHOST_VRING_LITTLE_ENDIAN &&
	    s.num != VHOST_VRING_BIG_ENDIAN)
		return -EINVAL;

	if (s.num == VHOST_VRING_BIG_ENDIAN)
		vhost_enable_cross_endian_big(vq);
	else
		vhost_enable_cross_endian_little(vq);

	return 0;
}

static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
				   int __user *argp)
{
	struct vhost_vring_state s = {
		.index = idx,
		.num = vq->user_be
	};

	if (copy_to_user(argp, &s, sizeof(s)))
		return -EFAULT;

	return 0;
}

static void vhost_init_is_le(struct vhost_virtqueue *vq)
{
	/* Note for legacy virtio: user_be is initialized at reset time
	 * according to the host endianness. If userspace does not set an
	 * explicit endianness, the default behavior is native endian, as
	 * expected by legacy virtio.
	 */
	vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1) || !vq->user_be;
}
#else
static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
{
}

static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
{
	return -ENOIOCTLCMD;
}

static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
				   int __user *argp)
{
	return -ENOIOCTLCMD;
}

static void vhost_init_is_le(struct vhost_virtqueue *vq)
{
	vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1)
		|| virtio_legacy_is_little_endian();
}
#endif /* CONFIG_VHOST_CROSS_ENDIAN_LEGACY */

static void vhost_reset_is_le(struct vhost_virtqueue *vq)
{
	vhost_init_is_le(vq);
}
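
/*
 * A flush works by queueing a vhost_flush_struct work and sleeping on its
 * completion: since a worker runs items in queueing order, once the flush
 * work runs, everything queued before it has already executed.
 */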
struct vhost_flush_struct {
	struct vhost_work work;
	struct completion wait_event;
};

static void vhost_flush_work(struct vhost_work *work)
{
	struct vhost_flush_struct *s;

	s = container_of(work, struct vhost_flush_struct, work);
	complete(&s->wait_event);
}
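
/*
 * vhost_poll glues a vhost_work to a file's wait queue: vhost_poll_func
 * registers the wait entry during vfs_poll(), and vhost_poll_wakeup later
 * queues the work on the vq's worker (or runs it inline when the device
 * does not use a worker) whenever an event matching poll->mask fires.
 */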
static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
			    poll_table *pt)
{
	struct vhost_poll *poll;

	poll = container_of(pt, struct vhost_poll, table);
	poll->wqh = wqh;
	add_wait_queue(wqh, &poll->wait);
}

static int vhost_poll_wakeup(wait_queue_entry_t *wait, unsigned mode, int sync,
			     void *key)
{
	struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait);
	struct vhost_work *work = &poll->work;

	if (!(key_to_poll(key) & poll->mask))
		return 0;

	if (!poll->dev->use_worker)
		work->fn(work);
	else
		vhost_poll_queue(poll);

	return 0;
}

void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
{
	clear_bit(VHOST_WORK_QUEUED, &work->flags);
	work->fn = fn;
}
EXPORT_SYMBOL_GPL(vhost_work_init);

/* Init poll structure */
void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
		     __poll_t mask, struct vhost_dev *dev,
		     struct vhost_virtqueue *vq)
{
	init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup);
	init_poll_funcptr(&poll->table, vhost_poll_func);
	poll->mask = mask;
	poll->dev = dev;
	poll->wqh = NULL;
	poll->vq = vq;

	vhost_work_init(&poll->work, fn);
}
EXPORT_SYMBOL_GPL(vhost_poll_init);

/* Start polling a file. We add ourselves to file's wait queue. The caller must
 * keep a reference to a file until after vhost_poll_stop is called. */
int vhost_poll_start(struct vhost_poll *poll, struct file *file)
{
	__poll_t mask;

	if (poll->wqh)
		return 0;

	mask = vfs_poll(file, &poll->table);
	if (mask)
		vhost_poll_wakeup(&poll->wait, 0, 0, poll_to_key(mask));
	if (mask & EPOLLERR) {
		vhost_poll_stop(poll);
		return -EINVAL;
	}

	return 0;
}
EXPORT_SYMBOL_GPL(vhost_poll_start);

/* Stop polling a file. After this function returns, it becomes safe to drop the
 * file reference. You must also flush afterwards. */
void vhost_poll_stop(struct vhost_poll *poll)
{
	if (poll->wqh) {
		remove_wait_queue(poll->wqh, &poll->wait);
		poll->wqh = NULL;
	}
}
EXPORT_SYMBOL_GPL(vhost_poll_stop);

static void vhost_worker_queue(struct vhost_worker *worker,
			       struct vhost_work *work)
{
	if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
		/* We can only add the work to the list after we're
		 * sure it was not in the list.
		 * test_and_set_bit() implies a memory barrier.
		 */
		llist_add(&work->node, &worker->work_list);
		vhost_task_wake(worker->vtsk);
	}
}

bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work)
{
	struct vhost_worker *worker;
	bool queued = false;

	rcu_read_lock();
	worker = rcu_dereference(vq->worker);
	if (worker) {
		queued = true;
		vhost_worker_queue(worker, work);
	}
	rcu_read_unlock();

	return queued;
}
EXPORT_SYMBOL_GPL(vhost_vq_work_queue);

void vhost_vq_flush(struct vhost_virtqueue *vq)
{
	struct vhost_flush_struct flush;

	init_completion(&flush.wait_event);
	vhost_work_init(&flush.work, vhost_flush_work);

	if (vhost_vq_work_queue(vq, &flush.work))
		wait_for_completion(&flush.wait_event);
}
EXPORT_SYMBOL_GPL(vhost_vq_flush);

/**
 * __vhost_worker_flush - flush a worker
 * @worker: worker to flush
 *
 * The worker's flush_mutex must be held.
 */
static void __vhost_worker_flush(struct vhost_worker *worker)
{
	struct vhost_flush_struct flush;

	if (!worker->attachment_cnt || worker->killed)
		return;

	init_completion(&flush.wait_event);
	vhost_work_init(&flush.work, vhost_flush_work);

	vhost_worker_queue(worker, &flush.work);
	/*
	 * Drop mutex in case our worker is killed and it needs to take the
	 * mutex to force cleanup.
	 */
	mutex_unlock(&worker->mutex);
	wait_for_completion(&flush.wait_event);
	mutex_lock(&worker->mutex);
}

static void vhost_worker_flush(struct vhost_worker *worker)
{
	mutex_lock(&worker->mutex);
	__vhost_worker_flush(worker);
	mutex_unlock(&worker->mutex);
}

void vhost_dev_flush(struct vhost_dev *dev)
{
	struct vhost_worker *worker;
	unsigned long i;

	xa_for_each(&dev->worker_xa, i, worker)
		vhost_worker_flush(worker);
}
EXPORT_SYMBOL_GPL(vhost_dev_flush);

/* A lockless hint for busy polling code to exit the loop */
bool vhost_vq_has_work(struct vhost_virtqueue *vq)
{
	struct vhost_worker *worker;
	bool has_work = false;

	rcu_read_lock();
	worker = rcu_dereference(vq->worker);
	if (worker && !llist_empty(&worker->work_list))
		has_work = true;
	rcu_read_unlock();

	return has_work;
}
EXPORT_SYMBOL_GPL(vhost_vq_has_work);

void vhost_poll_queue(struct vhost_poll *poll)
{
	vhost_vq_work_queue(poll->vq, &poll->work);
}
EXPORT_SYMBOL_GPL(vhost_poll_queue);

static void __vhost_vq_meta_reset(struct vhost_virtqueue *vq)
{
	int j;

	for (j = 0; j < VHOST_NUM_ADDRS; j++)
		vq->meta_iotlb[j] = NULL;
}

static void vhost_vq_meta_reset(struct vhost_dev *d)
{
	int i;

	for (i = 0; i < d->nvqs; ++i)
		__vhost_vq_meta_reset(d->vqs[i]);
}

static void vhost_vring_call_reset(struct vhost_vring_call *call_ctx)
{
	call_ctx->ctx = NULL;
	memset(&call_ctx->producer, 0x0, sizeof(struct irq_bypass_producer));
}

bool vhost_vq_is_setup(struct vhost_virtqueue *vq)
{
	return vq->avail && vq->desc && vq->used && vhost_vq_access_ok(vq);
}
EXPORT_SYMBOL_GPL(vhost_vq_is_setup);

static void vhost_vq_reset(struct vhost_dev *dev,
			   struct vhost_virtqueue *vq)
{
	vq->num = 1;
	vq->desc = NULL;
	vq->avail = NULL;
	vq->used = NULL;
	vq->last_avail_idx = 0;
	vq->avail_idx = 0;
	vq->last_used_idx = 0;
	vq->signalled_used = 0;
	vq->signalled_used_valid = false;
	vq->used_flags = 0;
	vq->log_used = false;
	vq->log_addr = -1ull;
	vq->private_data = NULL;
	vq->acked_features = 0;
	vq->acked_backend_features = 0;
	vq->log_base = NULL;
	vq->error_ctx = NULL;
	vq->kick = NULL;
	vq->log_ctx = NULL;
	vhost_disable_cross_endian(vq);
	vhost_reset_is_le(vq);
	vq->busyloop_timeout = 0;
	vq->umem = NULL;
	vq->iotlb = NULL;
	rcu_assign_pointer(vq->worker, NULL);
	vhost_vring_call_reset(&vq->call_ctx);
	__vhost_vq_meta_reset(vq);
}
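
/*
 * Main loop body of the worker's vhost_task. Runs everything currently
 * queued in FIFO order and reports whether any work was found, which lets
 * the vhost_task core decide whether to sleep before calling us again.
 */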
static bool vhost_run_work_list(void *data)
{
	struct vhost_worker *worker = data;
	struct vhost_work *work, *work_next;
	struct llist_node *node;

	node = llist_del_all(&worker->work_list);
	if (node) {
		__set_current_state(TASK_RUNNING);

		node = llist_reverse_order(node);
		/* make sure flag is seen after deletion */
		smp_wmb();
		llist_for_each_entry_safe(work, work_next, node, node) {
			clear_bit(VHOST_WORK_QUEUED, &work->flags);
			kcov_remote_start_common(worker->kcov_handle);
			work->fn(work);
			kcov_remote_stop();
			cond_resched();
		}
	}

	return !!node;
}

static void vhost_worker_killed(void *data)
{
	struct vhost_worker *worker = data;
	struct vhost_dev *dev = worker->dev;
	struct vhost_virtqueue *vq;
	int i, attach_cnt = 0;

	mutex_lock(&worker->mutex);
	worker->killed = true;

	for (i = 0; i < dev->nvqs; i++) {
		vq = dev->vqs[i];

		mutex_lock(&vq->mutex);
		if (worker ==
		    rcu_dereference_check(vq->worker,
					  lockdep_is_held(&vq->mutex))) {
			rcu_assign_pointer(vq->worker, NULL);
			attach_cnt++;
		}
		mutex_unlock(&vq->mutex);
	}

	worker->attachment_cnt -= attach_cnt;
	if (attach_cnt)
		synchronize_rcu();
	/*
	 * Finish vhost_worker_flush calls and any other works that snuck in
	 * before the synchronize_rcu.
	 */
	vhost_run_work_list(worker);
	mutex_unlock(&worker->mutex);
}

static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
{
	kfree(vq->indirect);
	vq->indirect = NULL;
	kfree(vq->log);
	vq->log = NULL;
	kfree(vq->heads);
	vq->heads = NULL;
}

/* Helper to allocate iovec buffers for all vqs. */
static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
{
	struct vhost_virtqueue *vq;
	int i;

	for (i = 0; i < dev->nvqs; ++i) {
		vq = dev->vqs[i];
		vq->indirect = kmalloc_array(UIO_MAXIOV,
					     sizeof(*vq->indirect),
					     GFP_KERNEL);
		vq->log = kmalloc_array(dev->iov_limit, sizeof(*vq->log),
					GFP_KERNEL);
		vq->heads = kmalloc_array(dev->iov_limit, sizeof(*vq->heads),
					  GFP_KERNEL);
		if (!vq->indirect || !vq->log || !vq->heads)
			goto err_nomem;
	}
	return 0;

err_nomem:
	for (; i >= 0; --i)
		vhost_vq_free_iovecs(dev->vqs[i]);
	return -ENOMEM;
}

static void vhost_dev_free_iovecs(struct vhost_dev *dev)
{
	int i;

	for (i = 0; i < dev->nvqs; ++i)
		vhost_vq_free_iovecs(dev->vqs[i]);
}

bool vhost_exceeds_weight(struct vhost_virtqueue *vq,
			  int pkts, int total_len)
{
	struct vhost_dev *dev = vq->dev;

	if ((dev->byte_weight && total_len >= dev->byte_weight) ||
	    pkts >= dev->weight) {
		vhost_poll_queue(&vq->poll);
		return true;
	}

	return false;
}
EXPORT_SYMBOL_GPL(vhost_exceeds_weight);
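
/*
 * Ring size helpers: with VIRTIO_RING_F_EVENT_IDX negotiated, the avail
 * and used rings each carry one extra trailing __virtio16 (used_event or
 * avail_event), hence the additional two bytes.
 */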
static size_t vhost_get_avail_size(struct vhost_virtqueue *vq,
				   unsigned int num)
{
	size_t event __maybe_unused =
	       vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;

	return size_add(struct_size(vq->avail, ring, num), event);
}

static size_t vhost_get_used_size(struct vhost_virtqueue *vq,
				  unsigned int num)
{
	size_t event __maybe_unused =
	       vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;

	return size_add(struct_size(vq->used, ring, num), event);
}

static size_t vhost_get_desc_size(struct vhost_virtqueue *vq,
				  unsigned int num)
{
	return sizeof(*vq->desc) * num;
}

void vhost_dev_init(struct vhost_dev *dev,
		    struct vhost_virtqueue **vqs, int nvqs,
		    int iov_limit, int weight, int byte_weight,
		    bool use_worker,
		    int (*msg_handler)(struct vhost_dev *dev, u32 asid,
				       struct vhost_iotlb_msg *msg))
{
	struct vhost_virtqueue *vq;
	int i;

	dev->vqs = vqs;
	dev->nvqs = nvqs;
	mutex_init(&dev->mutex);
	dev->log_ctx = NULL;
	dev->umem = NULL;
	dev->iotlb = NULL;
	dev->mm = NULL;
	dev->iov_limit = iov_limit;
	dev->weight = weight;
	dev->byte_weight = byte_weight;
	dev->use_worker = use_worker;
	dev->msg_handler = msg_handler;
	init_waitqueue_head(&dev->wait);
	INIT_LIST_HEAD(&dev->read_list);
	INIT_LIST_HEAD(&dev->pending_list);
	spin_lock_init(&dev->iotlb_lock);
	xa_init_flags(&dev->worker_xa, XA_FLAGS_ALLOC);

	for (i = 0; i < dev->nvqs; ++i) {
		vq = dev->vqs[i];
		vq->log = NULL;
		vq->indirect = NULL;
		vq->heads = NULL;
		vq->dev = dev;
		mutex_init(&vq->mutex);
		vhost_vq_reset(dev, vq);
		if (vq->handle_kick)
			vhost_poll_init(&vq->poll, vq->handle_kick,
					EPOLLIN, dev, vq);
	}
}
EXPORT_SYMBOL_GPL(vhost_dev_init);

/* Caller should have device mutex */
long vhost_dev_check_owner(struct vhost_dev *dev)
{
	/* Are you the owner? If not, I don't think you mean to do that */
	return dev->mm == current->mm ? 0 : -EPERM;
}
EXPORT_SYMBOL_GPL(vhost_dev_check_owner);

/* Caller should have device mutex */
bool vhost_dev_has_owner(struct vhost_dev *dev)
{
	return dev->mm;
}
EXPORT_SYMBOL_GPL(vhost_dev_has_owner);

static void vhost_attach_mm(struct vhost_dev *dev)
{
	/* No owner, become one */
	if (dev->use_worker) {
		dev->mm = get_task_mm(current);
	} else {
		/* vDPA device does not use a worker thread, so there's
		 * no need to hold the address space for mm. This helps
		 * avoid a deadlock in the case of mmap(), which may hold
		 * the refcnt of the file and depend on the release method
		 * to remove the vma.
		 */
		dev->mm = current->mm;
		mmgrab(dev->mm);
	}
}
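
/*
 * Pairs with vhost_attach_mm() above: mmput() drops the full reference
 * taken by get_task_mm(), mmdrop() drops the bare mm_count reference
 * taken by mmgrab().
 */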
static void vhost_detach_mm(struct vhost_dev *dev)
{
	if (!dev->mm)
		return;

	if (dev->use_worker)
		mmput(dev->mm);
	else
		mmdrop(dev->mm);

	dev->mm = NULL;
}

static void vhost_worker_destroy(struct vhost_dev *dev,
				 struct vhost_worker *worker)
{
	if (!worker)
		return;

	WARN_ON(!llist_empty(&worker->work_list));
	xa_erase(&dev->worker_xa, worker->id);
	vhost_task_stop(worker->vtsk);
	kfree(worker);
}

static void vhost_workers_free(struct vhost_dev *dev)
{
	struct vhost_worker *worker;
	unsigned long i;

	if (!dev->use_worker)
		return;

	for (i = 0; i < dev->nvqs; i++)
		rcu_assign_pointer(dev->vqs[i]->worker, NULL);
	/*
	 * Free the default worker we created and clean up workers userspace
	 * created but couldn't clean up (it forgot or crashed).
	 */
	xa_for_each(&dev->worker_xa, i, worker)
		vhost_worker_destroy(dev, worker);
	xa_destroy(&dev->worker_xa);
}

static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
{
	struct vhost_worker *worker;
	struct vhost_task *vtsk;
	char name[TASK_COMM_LEN];
	int ret;
	u32 id;

	worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT);
	if (!worker)
		return NULL;

	worker->dev = dev;
	snprintf(name, sizeof(name), "vhost-%d", current->pid);

	vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed,
				 worker, name);
	if (!vtsk)
		goto free_worker;

	mutex_init(&worker->mutex);
	init_llist_head(&worker->work_list);
	worker->kcov_handle = kcov_common_handle();
	worker->vtsk = vtsk;

	vhost_task_start(vtsk);

	ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
	if (ret < 0)
		goto stop_worker;
	worker->id = id;

	return worker;

stop_worker:
	vhost_task_stop(vtsk);
free_worker:
	kfree(worker);
	return NULL;
}

/* Caller must have device mutex */
static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
				     struct vhost_worker *worker)
{
	struct vhost_worker *old_worker;

	mutex_lock(&worker->mutex);
	if (worker->killed) {
		mutex_unlock(&worker->mutex);
		return;
	}

	mutex_lock(&vq->mutex);

	old_worker = rcu_dereference_check(vq->worker,
					   lockdep_is_held(&vq->mutex));
	rcu_assign_pointer(vq->worker, worker);
	worker->attachment_cnt++;

	if (!old_worker) {
		mutex_unlock(&vq->mutex);
		mutex_unlock(&worker->mutex);
		return;
	}
	mutex_unlock(&vq->mutex);
	mutex_unlock(&worker->mutex);

	/*
	 * Take the worker mutex to make sure we see the work queued from
	 * device wide flushes, which don't use RCU for execution.
	 */
	mutex_lock(&old_worker->mutex);
	if (old_worker->killed) {
		mutex_unlock(&old_worker->mutex);
		return;
	}

	/*
	 * We don't want to call synchronize_rcu for every vq during setup
	 * because it will slow down VM startup. If we haven't done
	 * VHOST_SET_VRING_KICK and not done the driver specific
	 * SET_ENDPOINT/RUNNING then we can skip the sync since there will
	 * not be any works queued for scsi and net.
	 */
	mutex_lock(&vq->mutex);
	if (!vhost_vq_get_backend(vq) && !vq->kick) {
		mutex_unlock(&vq->mutex);

		old_worker->attachment_cnt--;
		mutex_unlock(&old_worker->mutex);
		/*
		 * vsock can queue anytime after VHOST_VSOCK_SET_GUEST_CID.
		 * Warn if it adds support for multiple workers but forgets to
		 * handle the early queueing case.
		 */
		WARN_ON(!old_worker->attachment_cnt &&
			!llist_empty(&old_worker->work_list));
		return;
	}
	mutex_unlock(&vq->mutex);

	/* Make sure new vq queue/flush/poll calls see the new worker */
	synchronize_rcu();
	/* Make sure whatever was queued gets run */
	__vhost_worker_flush(old_worker);
	old_worker->attachment_cnt--;
	mutex_unlock(&old_worker->mutex);
}

/* Caller must have device mutex */
static int vhost_vq_attach_worker(struct vhost_virtqueue *vq,
				  struct vhost_vring_worker *info)
{
	unsigned long index = info->worker_id;
	struct vhost_dev *dev = vq->dev;
	struct vhost_worker *worker;

	if (!dev->use_worker)
		return -EINVAL;

	worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT);
	if (!worker || worker->id != info->worker_id)
		return -ENODEV;

	__vhost_vq_attach_worker(vq, worker);
	return 0;
}

/* Caller must have device mutex */
static int vhost_new_worker(struct vhost_dev *dev,
			    struct vhost_worker_state *info)
{
	struct vhost_worker *worker;

	worker = vhost_worker_create(dev);
	if (!worker)
		return -ENOMEM;

	info->worker_id = worker->id;
	return 0;
}

/* Caller must have device mutex */
static int vhost_free_worker(struct vhost_dev *dev,
			     struct vhost_worker_state *info)
{
	unsigned long index = info->worker_id;
	struct vhost_worker *worker;

	worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT);
	if (!worker || worker->id != info->worker_id)
		return -ENODEV;

	mutex_lock(&worker->mutex);
	if (worker->attachment_cnt || worker->killed) {
		mutex_unlock(&worker->mutex);
		return -EBUSY;
	}
	/*
	 * A flush might have raced and snuck in before attachment_cnt was set
	 * to zero. Make sure flushes are flushed from the queue before
	 * freeing.
	 */
	__vhost_worker_flush(worker);
	mutex_unlock(&worker->mutex);

	vhost_worker_destroy(dev, worker);
	return 0;
}

static int vhost_get_vq_from_user(struct vhost_dev *dev, void __user *argp,
				  struct vhost_virtqueue **vq, u32 *id)
{
	u32 __user *idxp = argp;
	u32 idx;
	long r;

	r = get_user(idx, idxp);
	if (r < 0)
		return r;

	if (idx >= dev->nvqs)
		return -ENOBUFS;

	idx = array_index_nospec(idx, dev->nvqs);

	*vq = dev->vqs[idx];
	*id = idx;
	return 0;
}

/* Caller must have device mutex */
long vhost_worker_ioctl(struct vhost_dev *dev, unsigned int ioctl,
			void __user *argp)
{
	struct vhost_vring_worker ring_worker;
	struct vhost_worker_state state;
	struct vhost_worker *worker;
	struct vhost_virtqueue *vq;
	long ret;
	u32 idx;

	if (!dev->use_worker)
		return -EINVAL;

	if (!vhost_dev_has_owner(dev))
		return -EINVAL;

	ret = vhost_dev_check_owner(dev);
	if (ret)
		return ret;

	switch (ioctl) {
	/* dev worker ioctls */
	case VHOST_NEW_WORKER:
		ret = vhost_new_worker(dev, &state);
		if (!ret && copy_to_user(argp, &state, sizeof(state)))
			ret = -EFAULT;
		return ret;
	case VHOST_FREE_WORKER:
		if (copy_from_user(&state, argp, sizeof(state)))
			return -EFAULT;
		return vhost_free_worker(dev, &state);
	/* vring worker ioctls */
	case VHOST_ATTACH_VRING_WORKER:
	case VHOST_GET_VRING_WORKER:
		break;
	default:
		return -ENOIOCTLCMD;
	}

	ret = vhost_get_vq_from_user(dev, argp, &vq, &idx);
	if (ret)
		return ret;

	switch (ioctl) {
	case VHOST_ATTACH_VRING_WORKER:
		if (copy_from_user(&ring_worker, argp, sizeof(ring_worker))) {
			ret = -EFAULT;
			break;
		}

		ret = vhost_vq_attach_worker(vq, &ring_worker);
		break;
	case VHOST_GET_VRING_WORKER:
		worker = rcu_dereference_check(vq->worker,
					       lockdep_is_held(&dev->mutex));
		if (!worker) {
			ret = -EINVAL;
			break;
		}

		ring_worker.index = idx;
		ring_worker.worker_id = worker->id;

		if (copy_to_user(argp, &ring_worker, sizeof(ring_worker)))
			ret = -EFAULT;
		break;
	default:
		ret = -ENOIOCTLCMD;
		break;
	}

	return ret;
}
EXPORT_SYMBOL_GPL(vhost_worker_ioctl);

/* Caller should have device mutex */
long vhost_dev_set_owner(struct vhost_dev *dev)
{
	struct vhost_worker *worker;
	int err, i;

	/* Is there an owner already? */
	if (vhost_dev_has_owner(dev)) {
		err = -EBUSY;
		goto err_mm;
	}

	vhost_attach_mm(dev);

	err = vhost_dev_alloc_iovecs(dev);
	if (err)
		goto err_iovecs;

	if (dev->use_worker) {
		/*
		 * This should be done last, because vsock can queue work
		 * before VHOST_SET_OWNER so it simplifies the failure path
		 * below since we don't have to worry about vsock queueing
		 * while we free the worker.
		 */
		worker = vhost_worker_create(dev);
		if (!worker) {
			err = -ENOMEM;
			goto err_worker;
		}

		for (i = 0; i < dev->nvqs; i++)
			__vhost_vq_attach_worker(dev->vqs[i], worker);
	}

	return 0;

err_worker:
	vhost_dev_free_iovecs(dev);
err_iovecs:
	vhost_detach_mm(dev);
err_mm:
	return err;
}
EXPORT_SYMBOL_GPL(vhost_dev_set_owner);

static struct vhost_iotlb *iotlb_alloc(void)
{
	return vhost_iotlb_alloc(max_iotlb_entries,
				 VHOST_IOTLB_FLAG_RETIRE);
}

struct vhost_iotlb *vhost_dev_reset_owner_prepare(void)
{
	return iotlb_alloc();
}
EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);

/* Caller should have device mutex */
void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_iotlb *umem)
{
	int i;

	vhost_dev_cleanup(dev);

	dev->umem = umem;
	/* We don't need VQ locks below since vhost_dev_cleanup makes sure
	 * VQs aren't running.
	 */
	for (i = 0; i < dev->nvqs; ++i)
		dev->vqs[i]->umem = umem;
}
EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);

void vhost_dev_stop(struct vhost_dev *dev)
{
	int i;

	for (i = 0; i < dev->nvqs; ++i) {
		if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick)
			vhost_poll_stop(&dev->vqs[i]->poll);
	}

	vhost_dev_flush(dev);
}
EXPORT_SYMBOL_GPL(vhost_dev_stop);

void vhost_clear_msg(struct vhost_dev *dev)
{
	struct vhost_msg_node *node, *n;

	spin_lock(&dev->iotlb_lock);

	list_for_each_entry_safe(node, n, &dev->read_list, node) {
		list_del(&node->node);
		kfree(node);
	}

	list_for_each_entry_safe(node, n, &dev->pending_list, node) {
		list_del(&node->node);
		kfree(node);
	}

	spin_unlock(&dev->iotlb_lock);
}
EXPORT_SYMBOL_GPL(vhost_clear_msg);

void vhost_dev_cleanup(struct vhost_dev *dev)
{
	int i;

	for (i = 0; i < dev->nvqs; ++i) {
		if (dev->vqs[i]->error_ctx)
			eventfd_ctx_put(dev->vqs[i]->error_ctx);
		if (dev->vqs[i]->kick)
			fput(dev->vqs[i]->kick);
		if (dev->vqs[i]->call_ctx.ctx)
			eventfd_ctx_put(dev->vqs[i]->call_ctx.ctx);
		vhost_vq_reset(dev, dev->vqs[i]);
	}
	vhost_dev_free_iovecs(dev);
	if (dev->log_ctx)
		eventfd_ctx_put(dev->log_ctx);
	dev->log_ctx = NULL;
	/* No one will access memory at this point */
	vhost_iotlb_free(dev->umem);
	dev->umem = NULL;
	vhost_iotlb_free(dev->iotlb);
	dev->iotlb = NULL;
	vhost_clear_msg(dev);
	wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
	vhost_workers_free(dev);
	vhost_detach_mm(dev);
}
EXPORT_SYMBOL_GPL(vhost_dev_cleanup);
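
/*
 * The dirty log is a userspace bitmap with one bit per VHOST_PAGE_SIZE
 * page, so the byte covering guest address addr lives at
 * log_base + addr / VHOST_PAGE_SIZE / 8.
 */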
static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
{
	u64 a = addr / VHOST_PAGE_SIZE / 8;

	/* Make sure 64 bit math will not overflow. */
	if (a > ULONG_MAX - (unsigned long)log_base ||
	    a + (unsigned long)log_base > ULONG_MAX)
		return false;

	return access_ok(log_base + a,
			 (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
}

/* Make sure 64 bit math will not overflow. */
static bool vhost_overflow(u64 uaddr, u64 size)
{
	if (uaddr > ULONG_MAX || size > ULONG_MAX)
		return true;

	if (!size)
		return false;

	return uaddr > ULONG_MAX - size + 1;
}

/* Caller should have vq mutex and device mutex. */
static bool vq_memory_access_ok(void __user *log_base, struct vhost_iotlb *umem,
				int log_all)
{
	struct vhost_iotlb_map *map;

	if (!umem)
		return false;

	list_for_each_entry(map, &umem->list, link) {
		unsigned long a = map->addr;

		if (vhost_overflow(map->addr, map->size))
			return false;

		if (!access_ok((void __user *)a, map->size))
			return false;
		else if (log_all && !log_access_ok(log_base,
						   map->start,
						   map->size))
			return false;
	}
	return true;
}

static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq,
					       u64 addr, unsigned int size,
					       int type)
{
	const struct vhost_iotlb_map *map = vq->meta_iotlb[type];

	if (!map)
		return NULL;

	return (void __user *)(uintptr_t)(map->addr + addr - map->start);
}

/* Can we switch to this memory table? */
/* Caller should have device mutex but not vq mutex */
static bool memory_access_ok(struct vhost_dev *d, struct vhost_iotlb *umem,
			     int log_all)
{
	int i;

	for (i = 0; i < d->nvqs; ++i) {
		bool ok;
		bool log;

		mutex_lock(&d->vqs[i]->mutex);
		log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
		/* If ring is inactive, will check when it's enabled. */
		if (d->vqs[i]->private_data)
			ok = vq_memory_access_ok(d->vqs[i]->log_base,
						 umem, log);
		else
			ok = true;
		mutex_unlock(&d->vqs[i]->mutex);
		if (!ok)
			return false;
	}
	return true;
}

static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
			  struct iovec iov[], int iov_size, int access);

static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to,
			      const void *from, unsigned size)
{
	int ret;

	if (!vq->iotlb)
		return __copy_to_user(to, from, size);
	else {
		/* This function should be called after iotlb
		 * prefetch, which means we're sure that all vq
		 * metadata can be accessed through the iotlb. So
		 * -EAGAIN should not happen in this case.
		 */
		struct iov_iter t;
		void __user *uaddr = vhost_vq_meta_fetch(vq,
				     (u64)(uintptr_t)to, size,
				     VHOST_ADDR_USED);

		if (uaddr)
			return __copy_to_user(uaddr, from, size);

		ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
				     ARRAY_SIZE(vq->iotlb_iov),
				     VHOST_ACCESS_WO);
		if (ret < 0)
			goto out;
		iov_iter_init(&t, ITER_DEST, vq->iotlb_iov, ret, size);
		ret = copy_to_iter(from, size, &t);
		if (ret == size)
			ret = 0;
	}
out:
	return ret;
}

static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
				void __user *from, unsigned size)
{
	int ret;

	if (!vq->iotlb)
		return __copy_from_user(to, from, size);
	else {
		/* This function should be called after iotlb
		 * prefetch, which means we're sure that the vq
		 * metadata can be accessed through the iotlb. So
		 * -EAGAIN should not happen in this case.
		 */
		void __user *uaddr = vhost_vq_meta_fetch(vq,
				     (u64)(uintptr_t)from, size,
				     VHOST_ADDR_DESC);
		struct iov_iter f;

		if (uaddr)
			return __copy_from_user(to, uaddr, size);

		ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov,
				     ARRAY_SIZE(vq->iotlb_iov),
				     VHOST_ACCESS_RO);
		if (ret < 0) {
			vq_err(vq, "IOTLB translation failure: uaddr "
			       "%p size 0x%llx\n", from,
			       (unsigned long long) size);
			goto out;
		}
		iov_iter_init(&f, ITER_SOURCE, vq->iotlb_iov, ret, size);
		ret = copy_from_iter(to, size, &f);
		if (ret == size)
			ret = 0;
	}

out:
	return ret;
}

static void __user *__vhost_get_user_slow(struct vhost_virtqueue *vq,
					  void __user *addr, unsigned int size,
					  int type)
{
	int ret;

	ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
			     ARRAY_SIZE(vq->iotlb_iov),
			     VHOST_ACCESS_RO);
	if (ret < 0) {
		vq_err(vq, "IOTLB translation failure: uaddr "
		       "%p size 0x%llx\n", addr,
		       (unsigned long long) size);
		return NULL;
	}

	if (ret != 1 || vq->iotlb_iov[0].iov_len != size) {
		vq_err(vq, "Non atomic userspace memory access: uaddr "
		       "%p size 0x%llx\n", addr,
		       (unsigned long long) size);
		return NULL;
	}

	return vq->iotlb_iov[0].iov_base;
}

/* This function should be called after iotlb
 * prefetch, which means we're sure that the vq
 * metadata can be accessed through the iotlb. So
 * -EAGAIN should not happen in this case.
 */
static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq,
					    void __user *addr, unsigned int size,
					    int type)
{
	void __user *uaddr = vhost_vq_meta_fetch(vq,
			     (u64)(uintptr_t)addr, size, type);
	if (uaddr)
		return uaddr;

	return __vhost_get_user_slow(vq, addr, size, type);
}
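
/*
 * Fast path: without an iotlb the ring address is a plain, pre-validated
 * userspace pointer. With an iotlb, translate via the cached metadata
 * mapping (falling back to a full translation) and fail with -EFAULT if
 * no mapping can be found.
 */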
#define vhost_put_user(vq, x, ptr)		\
({ \
	int ret; \
	if (!vq->iotlb) { \
		ret = __put_user(x, ptr); \
	} else { \
		__typeof__(ptr) to = \
			(__typeof__(ptr)) __vhost_get_user(vq, ptr, \
					  sizeof(*ptr), VHOST_ADDR_USED); \
		if (to != NULL) \
			ret = __put_user(x, to); \
		else \
			ret = -EFAULT; \
	} \
	ret; \
})

static inline int vhost_put_avail_event(struct vhost_virtqueue *vq)
{
	return vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
			      vhost_avail_event(vq));
}

static inline int vhost_put_used(struct vhost_virtqueue *vq,
				 struct vring_used_elem *head, int idx,
				 int count)
{
	return vhost_copy_to_user(vq, vq->used->ring + idx, head,
				  count * sizeof(*head));
}

static inline int vhost_put_used_flags(struct vhost_virtqueue *vq)
{
	return vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
			      &vq->used->flags);
}

static inline int vhost_put_used_idx(struct vhost_virtqueue *vq)
{
	return vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
			      &vq->used->idx);
}

#define vhost_get_user(vq, x, ptr, type)		\
({ \
	int ret; \
	if (!vq->iotlb) { \
		ret = __get_user(x, ptr); \
	} else { \
		__typeof__(ptr) from = \
			(__typeof__(ptr)) __vhost_get_user(vq, ptr, \
							   sizeof(*ptr), \
							   type); \
		if (from != NULL) \
			ret = __get_user(x, from); \
		else \
			ret = -EFAULT; \
	} \
	ret; \
})

#define vhost_get_avail(vq, x, ptr) \
	vhost_get_user(vq, x, ptr, VHOST_ADDR_AVAIL)

#define vhost_get_used(vq, x, ptr) \
	vhost_get_user(vq, x, ptr, VHOST_ADDR_USED)

static void vhost_dev_lock_vqs(struct vhost_dev *d)
{
	int i = 0;

	for (i = 0; i < d->nvqs; ++i)
		mutex_lock_nested(&d->vqs[i]->mutex, i);
}

static void vhost_dev_unlock_vqs(struct vhost_dev *d)
{
	int i = 0;

	for (i = 0; i < d->nvqs; ++i)
		mutex_unlock(&d->vqs[i]->mutex);
}

static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq,
				      __virtio16 *idx)
{
	return vhost_get_avail(vq, *idx, &vq->avail->idx);
}

static inline int vhost_get_avail_head(struct vhost_virtqueue *vq,
				       __virtio16 *head, int idx)
{
	return vhost_get_avail(vq, *head,
			       &vq->avail->ring[idx & (vq->num - 1)]);
}

static inline int vhost_get_avail_flags(struct vhost_virtqueue *vq,
					__virtio16 *flags)
{
	return vhost_get_avail(vq, *flags, &vq->avail->flags);
}

static inline int vhost_get_used_event(struct vhost_virtqueue *vq,
				       __virtio16 *event)
{
	return vhost_get_avail(vq, *event, vhost_used_event(vq));
}

static inline int vhost_get_used_idx(struct vhost_virtqueue *vq,
				     __virtio16 *idx)
{
	return vhost_get_used(vq, *idx, &vq->used->idx);
}

static inline int vhost_get_desc(struct vhost_virtqueue *vq,
				 struct vring_desc *desc, int idx)
{
	return vhost_copy_from_user(vq, desc, vq->desc + idx, sizeof(*desc));
}
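
/*
 * When an IOTLB update arrives, wake every vq whose pending miss request
 * is covered by the new mapping so it can retry the translation.
 */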
static void vhost_iotlb_notify_vq(struct vhost_dev *d,
				  struct vhost_iotlb_msg *msg)
{
	struct vhost_msg_node *node, *n;

	spin_lock(&d->iotlb_lock);

	list_for_each_entry_safe(node, n, &d->pending_list, node) {
		struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;

		if (msg->iova <= vq_msg->iova &&
		    msg->iova + msg->size - 1 >= vq_msg->iova &&
		    vq_msg->type == VHOST_IOTLB_MISS) {
			vhost_poll_queue(&node->vq->poll);
			list_del(&node->node);
			kfree(node);
		}
	}

	spin_unlock(&d->iotlb_lock);
}

static bool umem_access_ok(u64 uaddr, u64 size, int access)
{
	unsigned long a = uaddr;

	/* Make sure 64 bit math will not overflow. */
	if (vhost_overflow(uaddr, size))
		return false;

	if ((access & VHOST_ACCESS_RO) &&
	    !access_ok((void __user *)a, size))
		return false;
	if ((access & VHOST_ACCESS_WO) &&
	    !access_ok((void __user *)a, size))
		return false;
	return true;
}

static int vhost_process_iotlb_msg(struct vhost_dev *dev, u32 asid,
				   struct vhost_iotlb_msg *msg)
{
	int ret = 0;

	if (asid != 0)
		return -EINVAL;

	mutex_lock(&dev->mutex);
	vhost_dev_lock_vqs(dev);
	switch (msg->type) {
	case VHOST_IOTLB_UPDATE:
		if (!dev->iotlb) {
			ret = -EFAULT;
			break;
		}
		if (!umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
			ret = -EFAULT;
			break;
		}
		vhost_vq_meta_reset(dev);
		if (vhost_iotlb_add_range(dev->iotlb, msg->iova,
					  msg->iova + msg->size - 1,
					  msg->uaddr, msg->perm)) {
			ret = -ENOMEM;
			break;
		}
		vhost_iotlb_notify_vq(dev, msg);
		break;
	case VHOST_IOTLB_INVALIDATE:
		if (!dev->iotlb) {
			ret = -EFAULT;
			break;
		}
		vhost_vq_meta_reset(dev);
		vhost_iotlb_del_range(dev->iotlb, msg->iova,
				      msg->iova + msg->size - 1);
		break;
	default:
		ret = -EINVAL;
		break;
	}

	vhost_dev_unlock_vqs(dev);
	mutex_unlock(&dev->mutex);

	return ret;
}
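
/*
 * Userspace feeds IOTLB messages through the chardev. A V1 vhost_msg has a
 * (possibly padded) type field followed by the iotlb payload; a V2 message
 * carries an extra u32 that holds the ASID when VHOST_BACKEND_F_IOTLB_ASID
 * is negotiated and is skipped as reserved otherwise.
 */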
ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
			     struct iov_iter *from)
{
	struct vhost_iotlb_msg msg;
	size_t offset;
	int type, ret;
	u32 asid = 0;

	ret = copy_from_iter(&type, sizeof(type), from);
	if (ret != sizeof(type)) {
		ret = -EINVAL;
		goto done;
	}

	switch (type) {
	case VHOST_IOTLB_MSG:
		/* There may be a hole after type for the V1 message,
		 * so skip it here.
		 */
		offset = offsetof(struct vhost_msg, iotlb) - sizeof(int);
		break;
	case VHOST_IOTLB_MSG_V2:
		if (vhost_backend_has_feature(dev->vqs[0],
					      VHOST_BACKEND_F_IOTLB_ASID)) {
			ret = copy_from_iter(&asid, sizeof(asid), from);
			if (ret != sizeof(asid)) {
				ret = -EINVAL;
				goto done;
			}
			offset = 0;
		} else
			offset = sizeof(__u32);
		break;
	default:
		ret = -EINVAL;
		goto done;
	}

	iov_iter_advance(from, offset);
	ret = copy_from_iter(&msg, sizeof(msg), from);
	if (ret != sizeof(msg)) {
		ret = -EINVAL;
		goto done;
	}

	if (msg.type == VHOST_IOTLB_UPDATE && msg.size == 0) {
		ret = -EINVAL;
		goto done;
	}

	if (dev->msg_handler)
		ret = dev->msg_handler(dev, asid, &msg);
	else
		ret = vhost_process_iotlb_msg(dev, asid, &msg);
	if (ret) {
		ret = -EFAULT;
		goto done;
	}

	ret = (type == VHOST_IOTLB_MSG) ? sizeof(struct vhost_msg) :
	      sizeof(struct vhost_msg_v2);
done:
	return ret;
}
EXPORT_SYMBOL(vhost_chr_write_iter);

__poll_t vhost_chr_poll(struct file *file, struct vhost_dev *dev,
			poll_table *wait)
{
	__poll_t mask = 0;

	poll_wait(file, &dev->wait, wait);

	if (!list_empty(&dev->read_list))
		mask |= EPOLLIN | EPOLLRDNORM;

	return mask;
}
EXPORT_SYMBOL(vhost_chr_poll);

ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
			    int noblock)
{
	DEFINE_WAIT(wait);
	struct vhost_msg_node *node;
	ssize_t ret = 0;
	unsigned size = sizeof(struct vhost_msg);

	if (iov_iter_count(to) < size)
		return 0;

	while (1) {
		if (!noblock)
			prepare_to_wait(&dev->wait, &wait,
					TASK_INTERRUPTIBLE);

		node = vhost_dequeue_msg(dev, &dev->read_list);
		if (node)
			break;
		if (noblock) {
			ret = -EAGAIN;
			break;
		}
		if (signal_pending(current)) {
			ret = -ERESTARTSYS;
			break;
		}
		if (!dev->iotlb) {
			ret = -EBADFD;
			break;
		}

		schedule();
	}

	if (!noblock)
		finish_wait(&dev->wait, &wait);

	if (node) {
		struct vhost_iotlb_msg *msg;
		void *start = &node->msg;

		switch (node->msg.type) {
		case VHOST_IOTLB_MSG:
			size = sizeof(node->msg);
			msg = &node->msg.iotlb;
			break;
		case VHOST_IOTLB_MSG_V2:
			size = sizeof(node->msg_v2);
			msg = &node->msg_v2.iotlb;
			break;
		default:
			BUG();
			break;
		}

		ret = copy_to_iter(start, size, to);
		if (ret != size || msg->type != VHOST_IOTLB_MISS) {
			kfree(node);
			return ret;
		}
		vhost_enqueue_msg(dev, &dev->pending_list, node);
	}

	return ret;
}
EXPORT_SYMBOL_GPL(vhost_chr_read_iter);

static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
{
	struct vhost_dev *dev = vq->dev;
	struct vhost_msg_node *node;
	struct vhost_iotlb_msg *msg;
	bool v2 = vhost_backend_has_feature(vq, VHOST_BACKEND_F_IOTLB_MSG_V2);

	node = vhost_new_msg(vq, v2 ? VHOST_IOTLB_MSG_V2 : VHOST_IOTLB_MSG);
	if (!node)
		return -ENOMEM;

	if (v2) {
		node->msg_v2.type = VHOST_IOTLB_MSG_V2;
		msg = &node->msg_v2.iotlb;
	} else {
		msg = &node->msg.iotlb;
	}

	msg->type = VHOST_IOTLB_MISS;
	msg->iova = iova;
	msg->perm = access;

	vhost_enqueue_msg(dev, &dev->read_list, node);

	return 0;
}

static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
			 vring_desc_t __user *desc,
			 vring_avail_t __user *avail,
			 vring_used_t __user *used)
{
	/* If an IOTLB device is present, the vring addresses are
	 * GIOVAs. Access validation occurs at prefetch time. */
	if (vq->iotlb)
		return true;

	return access_ok(desc, vhost_get_desc_size(vq, num)) &&
	       access_ok(avail, vhost_get_avail_size(vq, num)) &&
	       access_ok(used, vhost_get_used_size(vq, num));
}

static void vhost_vq_meta_update(struct vhost_virtqueue *vq,
				 const struct vhost_iotlb_map *map,
				 int type)
{
	int access = (type == VHOST_ADDR_USED) ?
		     VHOST_ACCESS_WO : VHOST_ACCESS_RO;

	if (likely(map->perm & access))
		vq->meta_iotlb[type] = map;
}
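
/*
 * Walk the iotlb mappings covering [addr, addr + len). A hole triggers an
 * IOTLB miss request to userspace; a permission mismatch just fails. If a
 * single mapping spans the whole range, cache it in meta_iotlb for the
 * fast paths above.
 */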
iotlb_access_ok(struct vhost_virtqueue * vq,int access,u64 addr,u64 len,int type)1686 static bool iotlb_access_ok(struct vhost_virtqueue *vq,
1687 int access, u64 addr, u64 len, int type)
1688 {
1689 const struct vhost_iotlb_map *map;
1690 struct vhost_iotlb *umem = vq->iotlb;
1691 u64 s = 0, size, orig_addr = addr, last = addr + len - 1;
1692
1693 if (vhost_vq_meta_fetch(vq, addr, len, type))
1694 return true;
1695
1696 while (len > s) {
1697 map = vhost_iotlb_itree_first(umem, addr, last);
1698 if (map == NULL || map->start > addr) {
1699 vhost_iotlb_miss(vq, addr, access);
1700 return false;
1701 } else if (!(map->perm & access)) {
1702 /* Report the possible access violation by
1703 * request another translation from userspace.
1704 */
1705 return false;
1706 }
1707
1708 size = map->size - addr + map->start;
1709
1710 if (orig_addr == addr && size >= len)
1711 vhost_vq_meta_update(vq, map, type);
1712
1713 s += size;
1714 addr += size;
1715 }
1716
1717 return true;
1718 }
1719
vq_meta_prefetch(struct vhost_virtqueue * vq)1720 int vq_meta_prefetch(struct vhost_virtqueue *vq)
1721 {
1722 unsigned int num = vq->num;
1723
1724 if (!vq->iotlb)
1725 return 1;
1726
1727 return iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->desc,
1728 vhost_get_desc_size(vq, num), VHOST_ADDR_DESC) &&
1729 iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->avail,
1730 vhost_get_avail_size(vq, num),
1731 VHOST_ADDR_AVAIL) &&
1732 iotlb_access_ok(vq, VHOST_MAP_WO, (u64)(uintptr_t)vq->used,
1733 vhost_get_used_size(vq, num), VHOST_ADDR_USED);
1734 }
1735 EXPORT_SYMBOL_GPL(vq_meta_prefetch);
1736
1737 /* Can we log writes? */
1738 /* Caller should have device mutex but not vq mutex */
vhost_log_access_ok(struct vhost_dev * dev)1739 bool vhost_log_access_ok(struct vhost_dev *dev)
1740 {
1741 return memory_access_ok(dev, dev->umem, 1);
1742 }
1743 EXPORT_SYMBOL_GPL(vhost_log_access_ok);
1744
vq_log_used_access_ok(struct vhost_virtqueue * vq,void __user * log_base,bool log_used,u64 log_addr)1745 static bool vq_log_used_access_ok(struct vhost_virtqueue *vq,
1746 void __user *log_base,
1747 bool log_used,
1748 u64 log_addr)
1749 {
1750 /* If an IOTLB device is present, log_addr is a GIOVA that
1751 * will never be logged by log_used(). */
1752 if (vq->iotlb)
1753 return true;
1754
1755 return !log_used || log_access_ok(log_base, log_addr,
1756 vhost_get_used_size(vq, vq->num));
1757 }
1758
1759 /* Verify access for write logging. */
1760 /* Caller should have vq mutex and device mutex */
vq_log_access_ok(struct vhost_virtqueue * vq,void __user * log_base)1761 static bool vq_log_access_ok(struct vhost_virtqueue *vq,
1762 void __user *log_base)
1763 {
1764 return vq_memory_access_ok(log_base, vq->umem,
1765 vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
1766 vq_log_used_access_ok(vq, log_base, vq->log_used, vq->log_addr);
1767 }
1768
1769 /* Can we start vq? */
1770 /* Caller should have vq mutex and device mutex */
vhost_vq_access_ok(struct vhost_virtqueue * vq)1771 bool vhost_vq_access_ok(struct vhost_virtqueue *vq)
1772 {
1773 if (!vq_log_access_ok(vq, vq->log_base))
1774 return false;
1775
1776 return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used);
1777 }
1778 EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
1779
vhost_set_memory(struct vhost_dev * d,struct vhost_memory __user * m)1780 static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
1781 {
1782 struct vhost_memory mem, *newmem;
1783 struct vhost_memory_region *region;
1784 struct vhost_iotlb *newumem, *oldumem;
1785 unsigned long size = offsetof(struct vhost_memory, regions);
1786 int i;
1787
1788 if (copy_from_user(&mem, m, size))
1789 return -EFAULT;
1790 if (mem.padding)
1791 return -EOPNOTSUPP;
1792 if (mem.nregions > max_mem_regions)
1793 return -E2BIG;
1794 newmem = kvzalloc(struct_size(newmem, regions, mem.nregions),
1795 GFP_KERNEL);
1796 if (!newmem)
1797 return -ENOMEM;
1798
1799 memcpy(newmem, &mem, size);
1800 if (copy_from_user(newmem->regions, m->regions,
1801 flex_array_size(newmem, regions, mem.nregions))) {
1802 kvfree(newmem);
1803 return -EFAULT;
1804 }
1805
1806 newumem = iotlb_alloc();
1807 if (!newumem) {
1808 kvfree(newmem);
1809 return -ENOMEM;
1810 }
1811
1812 for (region = newmem->regions;
1813 region < newmem->regions + mem.nregions;
1814 region++) {
1815 if (vhost_iotlb_add_range(newumem,
1816 region->guest_phys_addr,
1817 region->guest_phys_addr +
1818 region->memory_size - 1,
1819 region->userspace_addr,
1820 VHOST_MAP_RW))
1821 goto err;
1822 }
1823
1824 if (!memory_access_ok(d, newumem, 0))
1825 goto err;
1826
1827 oldumem = d->umem;
1828 d->umem = newumem;
1829
1830 /* All memory accesses are done under some VQ mutex. */
1831 for (i = 0; i < d->nvqs; ++i) {
1832 mutex_lock(&d->vqs[i]->mutex);
1833 d->vqs[i]->umem = newumem;
1834 mutex_unlock(&d->vqs[i]->mutex);
1835 }
1836
1837 kvfree(newmem);
1838 vhost_iotlb_free(oldumem);
1839 return 0;
1840
1841 err:
1842 vhost_iotlb_free(newumem);
1843 kvfree(newmem);
1844 return -EFAULT;
1845 }
1846
vhost_vring_set_num(struct vhost_dev * d,struct vhost_virtqueue * vq,void __user * argp)1847 static long vhost_vring_set_num(struct vhost_dev *d,
1848 struct vhost_virtqueue *vq,
1849 void __user *argp)
1850 {
1851 struct vhost_vring_state s;
1852
1853 /* Resizing ring with an active backend?
1854 * You don't want to do that. */
1855 if (vq->private_data)
1856 return -EBUSY;
1857
1858 if (copy_from_user(&s, argp, sizeof s))
1859 return -EFAULT;
1860
1861 if (!s.num || s.num > 0xffff || (s.num & (s.num - 1)))
1862 return -EINVAL;
1863 vq->num = s.num;
1864
1865 return 0;
1866 }

static long vhost_vring_set_addr(struct vhost_dev *d,
				 struct vhost_virtqueue *vq,
				 void __user *argp)
{
	struct vhost_vring_addr a;

	if (copy_from_user(&a, argp, sizeof a))
		return -EFAULT;
	if (a.flags & ~(0x1 << VHOST_VRING_F_LOG))
		return -EOPNOTSUPP;

	/* For 32bit, verify that the top 32bits of the user
	   data are set to zero. */
	if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr ||
	    (u64)(unsigned long)a.used_user_addr != a.used_user_addr ||
	    (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr)
		return -EFAULT;

	/* Make sure it's safe to cast pointers to vring types. */
	BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE);
	BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE);
	if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) ||
	    (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) ||
	    (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1)))
		return -EINVAL;

	/* We only verify access here if backend is configured.
	 * If it is not, we don't verify, as the size might not have
	 * been set up yet.  We will verify when backend is configured. */
	if (vq->private_data) {
		if (!vq_access_ok(vq, vq->num,
			(void __user *)(unsigned long)a.desc_user_addr,
			(void __user *)(unsigned long)a.avail_user_addr,
			(void __user *)(unsigned long)a.used_user_addr))
			return -EINVAL;

		/* Also validate log access for used ring if enabled. */
		if (!vq_log_used_access_ok(vq, vq->log_base,
				a.flags & (0x1 << VHOST_VRING_F_LOG),
				a.log_guest_addr))
			return -EINVAL;
	}

	vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
	vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
	vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
	vq->log_addr = a.log_guest_addr;
	vq->used = (void __user *)(unsigned long)a.used_user_addr;

	return 0;
}
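
/*
 * Illustrative only: for a legacy split ring the three addresses above
 * are normally carved out of one contiguous allocation, laid out as
 * descriptor table, then avail ring, then (typically page-aligned) used
 * ring.  A hedged userspace sketch, with num and base assumed to come
 * from the caller:
 *
 *	size_t desc_sz	= num * 16;		// struct vring_desc entries
 *	size_t avail_sz = 6 + 2 * num;		// flags, idx, ring[], used_event
 *	struct vhost_vring_addr addr = {
 *		.index		 = 0,
 *		.desc_user_addr	 = (uint64_t)base,
 *		.avail_user_addr = (uint64_t)base + desc_sz,
 *		.used_user_addr	 = ((uint64_t)base + desc_sz + avail_sz
 *				    + 4095) & ~4095ULL,
 *	};
 *
 *	ioctl(vhost_fd, VHOST_SET_VRING_ADDR, &addr);
 */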

static long vhost_vring_set_num_addr(struct vhost_dev *d,
				     struct vhost_virtqueue *vq,
				     unsigned int ioctl,
				     void __user *argp)
{
	long r;

	mutex_lock(&vq->mutex);

	switch (ioctl) {
	case VHOST_SET_VRING_NUM:
		r = vhost_vring_set_num(d, vq, argp);
		break;
	case VHOST_SET_VRING_ADDR:
		r = vhost_vring_set_addr(d, vq, argp);
		break;
	default:
		BUG();
	}

	mutex_unlock(&vq->mutex);

	return r;
}

long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
{
	struct file *eventfp, *filep = NULL;
	bool pollstart = false, pollstop = false;
	struct eventfd_ctx *ctx = NULL;
	struct vhost_virtqueue *vq;
	struct vhost_vring_state s;
	struct vhost_vring_file f;
	u32 idx;
	long r;

	r = vhost_get_vq_from_user(d, argp, &vq, &idx);
	if (r < 0)
		return r;

	if (ioctl == VHOST_SET_VRING_NUM ||
	    ioctl == VHOST_SET_VRING_ADDR) {
		return vhost_vring_set_num_addr(d, vq, ioctl, argp);
	}

	mutex_lock(&vq->mutex);

	switch (ioctl) {
	case VHOST_SET_VRING_BASE:
		/* Moving base with an active backend?
		 * You don't want to do that. */
		if (vq->private_data) {
			r = -EBUSY;
			break;
		}
		if (copy_from_user(&s, argp, sizeof s)) {
			r = -EFAULT;
			break;
		}
		if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) {
			vq->last_avail_idx = s.num & 0xffff;
			vq->last_used_idx = (s.num >> 16) & 0xffff;
		} else {
			if (s.num > 0xffff) {
				r = -EINVAL;
				break;
			}
			vq->last_avail_idx = s.num;
		}
		/* Forget the cached index value. */
		vq->avail_idx = vq->last_avail_idx;
		break;
	case VHOST_GET_VRING_BASE:
		s.index = idx;
		if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED))
			s.num = (u32)vq->last_avail_idx | ((u32)vq->last_used_idx << 16);
		else
			s.num = vq->last_avail_idx;
		if (copy_to_user(argp, &s, sizeof s))
			r = -EFAULT;
		break;
	case VHOST_SET_VRING_KICK:
		if (copy_from_user(&f, argp, sizeof f)) {
			r = -EFAULT;
			break;
		}
		eventfp = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_fget(f.fd);
		if (IS_ERR(eventfp)) {
			r = PTR_ERR(eventfp);
			break;
		}
		if (eventfp != vq->kick) {
			pollstop = (filep = vq->kick) != NULL;
			pollstart = (vq->kick = eventfp) != NULL;
		} else
			filep = eventfp;
		break;
	case VHOST_SET_VRING_CALL:
		if (copy_from_user(&f, argp, sizeof f)) {
			r = -EFAULT;
			break;
		}
		ctx = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(f.fd);
		if (IS_ERR(ctx)) {
			r = PTR_ERR(ctx);
			break;
		}

		swap(ctx, vq->call_ctx.ctx);
		break;
	case VHOST_SET_VRING_ERR:
		if (copy_from_user(&f, argp, sizeof f)) {
			r = -EFAULT;
			break;
		}
		ctx = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(f.fd);
		if (IS_ERR(ctx)) {
			r = PTR_ERR(ctx);
			break;
		}
		swap(ctx, vq->error_ctx);
		break;
	case VHOST_SET_VRING_ENDIAN:
		r = vhost_set_vring_endian(vq, argp);
		break;
	case VHOST_GET_VRING_ENDIAN:
		r = vhost_get_vring_endian(vq, idx, argp);
		break;
	case VHOST_SET_VRING_BUSYLOOP_TIMEOUT:
		if (copy_from_user(&s, argp, sizeof(s))) {
			r = -EFAULT;
			break;
		}
		vq->busyloop_timeout = s.num;
		break;
	case VHOST_GET_VRING_BUSYLOOP_TIMEOUT:
		s.index = idx;
		s.num = vq->busyloop_timeout;
		if (copy_to_user(argp, &s, sizeof(s)))
			r = -EFAULT;
		break;
	default:
		r = -ENOIOCTLCMD;
	}

	if (pollstop && vq->handle_kick)
		vhost_poll_stop(&vq->poll);

	if (!IS_ERR_OR_NULL(ctx))
		eventfd_ctx_put(ctx);
	if (filep)
		fput(filep);

	if (pollstart && vq->handle_kick)
		r = vhost_poll_start(&vq->poll, vq->kick);

	mutex_unlock(&vq->mutex);

	if (pollstop && vq->handle_kick)
		vhost_dev_flush(vq->poll.dev);
	return r;
}
EXPORT_SYMBOL_GPL(vhost_vring_ioctl);
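
/*
 * Illustrative only: a hedged sketch of how a VMM typically wires the
 * kick and call paths handled above.  The guest-notification plumbing
 * (e.g. KVM ioeventfd/irqfd bound to the same eventfds) is assumed and
 * omitted here.
 *
 *	struct vhost_vring_file kick = { .index = 0, .fd = eventfd(0, 0) };
 *	struct vhost_vring_file call = { .index = 0, .fd = eventfd(0, 0) };
 *
 *	ioctl(vhost_fd, VHOST_SET_VRING_KICK, &kick);	// guest -> vhost
 *	ioctl(vhost_fd, VHOST_SET_VRING_CALL, &call);	// vhost -> guest
 *	// Passing .fd = VHOST_FILE_UNBIND instead detaches the eventfd.
 */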

int vhost_init_device_iotlb(struct vhost_dev *d)
{
	struct vhost_iotlb *niotlb, *oiotlb;
	int i;

	niotlb = iotlb_alloc();
	if (!niotlb)
		return -ENOMEM;

	oiotlb = d->iotlb;
	d->iotlb = niotlb;

	for (i = 0; i < d->nvqs; ++i) {
		struct vhost_virtqueue *vq = d->vqs[i];

		mutex_lock(&vq->mutex);
		vq->iotlb = niotlb;
		__vhost_vq_meta_reset(vq);
		mutex_unlock(&vq->mutex);
	}

	vhost_iotlb_free(oiotlb);

	return 0;
}
EXPORT_SYMBOL_GPL(vhost_init_device_iotlb);

/* Caller must have device mutex */
long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
{
	struct eventfd_ctx *ctx;
	u64 p;
	long r;
	int i, fd;

	/* If you are not the owner, you can become one */
	if (ioctl == VHOST_SET_OWNER) {
		r = vhost_dev_set_owner(d);
		goto done;
	}

	/* You must be the owner to do anything else */
	r = vhost_dev_check_owner(d);
	if (r)
		goto done;

	switch (ioctl) {
	case VHOST_SET_MEM_TABLE:
		r = vhost_set_memory(d, argp);
		break;
	case VHOST_SET_LOG_BASE:
		if (copy_from_user(&p, argp, sizeof p)) {
			r = -EFAULT;
			break;
		}
		if ((u64)(unsigned long)p != p) {
			r = -EFAULT;
			break;
		}
		for (i = 0; i < d->nvqs; ++i) {
			struct vhost_virtqueue *vq;
			void __user *base = (void __user *)(unsigned long)p;
			vq = d->vqs[i];
			mutex_lock(&vq->mutex);
			/* If ring is inactive, will check when it's enabled. */
			if (vq->private_data && !vq_log_access_ok(vq, base))
				r = -EFAULT;
			else
				vq->log_base = base;
			mutex_unlock(&vq->mutex);
		}
		break;
	case VHOST_SET_LOG_FD:
		r = get_user(fd, (int __user *)argp);
		if (r < 0)
			break;
		ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd);
		if (IS_ERR(ctx)) {
			r = PTR_ERR(ctx);
			break;
		}
		swap(ctx, d->log_ctx);
		for (i = 0; i < d->nvqs; ++i) {
			mutex_lock(&d->vqs[i]->mutex);
			d->vqs[i]->log_ctx = d->log_ctx;
			mutex_unlock(&d->vqs[i]->mutex);
		}
		if (ctx)
			eventfd_ctx_put(ctx);
		break;
	default:
		r = -ENOIOCTLCMD;
		break;
	}
done:
	return r;
}
EXPORT_SYMBOL_GPL(vhost_dev_ioctl);

/* TODO: This is really inefficient. We need something like get_user()
 * (instruction directly accesses the data, with an exception table entry
 * returning -EFAULT). See Documentation/arch/x86/exception-tables.rst.
 */
static int set_bit_to_user(int nr, void __user *addr)
{
	unsigned long log = (unsigned long)addr;
	struct page *page;
	void *base;
	int bit = nr + (log % PAGE_SIZE) * 8;
	int r;

	r = pin_user_pages_fast(log, 1, FOLL_WRITE, &page);
	if (r < 0)
		return r;
	BUG_ON(r != 1);
	base = kmap_atomic(page);
	set_bit(bit, base);
	kunmap_atomic(base);
	unpin_user_pages_dirty_lock(&page, 1, true);
	return 0;
}

static int log_write(void __user *log_base,
		     u64 write_address, u64 write_length)
{
	u64 write_page = write_address / VHOST_PAGE_SIZE;
	int r;

	if (!write_length)
		return 0;
	write_length += write_address % VHOST_PAGE_SIZE;
	for (;;) {
		u64 base = (u64)(unsigned long)log_base;
		u64 log = base + write_page / 8;
		int bit = write_page % 8;
		if ((u64)(unsigned long)log != log)
			return -EFAULT;
		r = set_bit_to_user(bit, (void __user *)(unsigned long)log);
		if (r < 0)
			return r;
		if (write_length <= VHOST_PAGE_SIZE)
			break;
		write_length -= VHOST_PAGE_SIZE;
		write_page += 1;
	}
	return r;
}
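
/*
 * A worked example of the bitmap math above, assuming
 * VHOST_PAGE_SIZE == 4096: a write to guest address 0x12345 of length
 * 0x2000 starts in page 0x12 and, once the length is extended by the
 * in-page offset 0x345, the loop marks pages 0x12, 0x13 and 0x14.  For
 * each dirty page N, the bit that gets set lives at byte
 * log_base + N / 8, bit N % 8 - i.e. one bit per VHOST_PAGE_SIZE bytes
 * of guest memory.
 */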

static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
{
	struct vhost_iotlb *umem = vq->umem;
	struct vhost_iotlb_map *u;
	u64 start, end, l, min;
	int r;
	bool hit = false;

	while (len) {
		min = len;
		/* More than one GPA can be mapped into a single HVA, so
		 * iterate over all possible umems here to be safe.
		 */
		list_for_each_entry(u, &umem->list, link) {
			if (u->addr > hva - 1 + len ||
			    u->addr - 1 + u->size < hva)
				continue;
			start = max(u->addr, hva);
			end = min(u->addr - 1 + u->size, hva - 1 + len);
			l = end - start + 1;
			r = log_write(vq->log_base,
				      u->start + start - u->addr,
				      l);
			if (r < 0)
				return r;
			hit = true;
			min = min(l, min);
		}

		if (!hit)
			return -EFAULT;

		len -= min;
		hva += min;
	}

	return 0;
}

static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
{
	struct iovec *iov = vq->log_iov;
	int i, ret;

	if (!vq->iotlb)
		return log_write(vq->log_base, vq->log_addr + used_offset, len);

	ret = translate_desc(vq, (uintptr_t)vq->used + used_offset,
			     len, iov, 64, VHOST_ACCESS_WO);
	if (ret < 0)
		return ret;

	for (i = 0; i < ret; i++) {
		ret = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
				    iov[i].iov_len);
		if (ret)
			return ret;
	}

	return 0;
}

int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
		    unsigned int log_num, u64 len, struct iovec *iov, int count)
{
	int i, r;

	/* Make sure data written is seen before log. */
	smp_wmb();

	if (vq->iotlb) {
		for (i = 0; i < count; i++) {
			r = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
					  iov[i].iov_len);
			if (r < 0)
				return r;
		}
		return 0;
	}

	for (i = 0; i < log_num; ++i) {
		u64 l = min(log[i].len, len);
		r = log_write(vq->log_base, log[i].addr, l);
		if (r < 0)
			return r;
		len -= l;
		if (!len) {
			if (vq->log_ctx)
				eventfd_signal(vq->log_ctx, 1);
			return 0;
		}
	}
	/* Length written exceeds what we have stored. This is a bug. */
	BUG();
	return 0;
}
EXPORT_SYMBOL_GPL(vhost_log_write);
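
/*
 * Illustrative only: a hedged sketch of the userspace side of the dirty
 * logging implemented above, as used during live migration.  Sizes and
 * the migration loop are assumptions for the example.
 *
 *	// One bit per VHOST_PAGE_SIZE page of guest memory.
 *	size_t log_size = (guest_ram_size / 4096 + 7) / 8;
 *	uint64_t log_base = (uint64_t)calloc(1, log_size);
 *
 *	ioctl(vhost_fd, VHOST_SET_LOG_BASE, &log_base);
 *	// ...negotiate VHOST_F_LOG_ALL in the acked features, then
 *	// periodically scan and clear the bitmap, resending guest pages
 *	// whose bits were set by the writes logged above.
 */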

static int vhost_update_used_flags(struct vhost_virtqueue *vq)
{
	void __user *used;
	if (vhost_put_used_flags(vq))
		return -EFAULT;
	if (unlikely(vq->log_used)) {
		/* Make sure the flag is seen before log. */
		smp_wmb();
		/* Log used flag write. */
		used = &vq->used->flags;
		log_used(vq, (used - (void __user *)vq->used),
			 sizeof vq->used->flags);
		if (vq->log_ctx)
			eventfd_signal(vq->log_ctx, 1);
	}
	return 0;
}

static int vhost_update_avail_event(struct vhost_virtqueue *vq)
{
	if (vhost_put_avail_event(vq))
		return -EFAULT;
	if (unlikely(vq->log_used)) {
		void __user *used;
		/* Make sure the event is seen before log. */
		smp_wmb();
		/* Log avail event write */
		used = vhost_avail_event(vq);
		log_used(vq, (used - (void __user *)vq->used),
			 sizeof *vhost_avail_event(vq));
		if (vq->log_ctx)
			eventfd_signal(vq->log_ctx, 1);
	}
	return 0;
}

int vhost_vq_init_access(struct vhost_virtqueue *vq)
{
	__virtio16 last_used_idx;
	int r;
	bool is_le = vq->is_le;

	if (!vq->private_data)
		return 0;

	vhost_init_is_le(vq);

	r = vhost_update_used_flags(vq);
	if (r)
		goto err;
	vq->signalled_used_valid = false;
	if (!vq->iotlb &&
	    !access_ok(&vq->used->idx, sizeof vq->used->idx)) {
		r = -EFAULT;
		goto err;
	}
	r = vhost_get_used_idx(vq, &last_used_idx);
	if (r) {
		vq_err(vq, "Can't access used idx at %p\n",
		       &vq->used->idx);
		goto err;
	}
	vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx);
	return 0;

err:
	vq->is_le = is_le;
	return r;
}
EXPORT_SYMBOL_GPL(vhost_vq_init_access);

static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
			  struct iovec iov[], int iov_size, int access)
{
	const struct vhost_iotlb_map *map;
	struct vhost_dev *dev = vq->dev;
	struct vhost_iotlb *umem = dev->iotlb ? dev->iotlb : dev->umem;
	struct iovec *_iov;
	u64 s = 0, last = addr + len - 1;
	int ret = 0;

	while ((u64)len > s) {
		u64 size;
		if (unlikely(ret >= iov_size)) {
			ret = -ENOBUFS;
			break;
		}

		map = vhost_iotlb_itree_first(umem, addr, last);
		if (map == NULL || map->start > addr) {
			if (umem != dev->iotlb) {
				ret = -EFAULT;
				break;
			}
			ret = -EAGAIN;
			break;
		} else if (!(map->perm & access)) {
			ret = -EPERM;
			break;
		}

		_iov = iov + ret;
		size = map->size - addr + map->start;
		_iov->iov_len = min((u64)len - s, size);
		_iov->iov_base = (void __user *)(unsigned long)
				 (map->addr + addr - map->start);
		s += size;
		addr += size;
		++ret;
	}

	if (ret == -EAGAIN)
		vhost_iotlb_miss(vq, addr, access);
	return ret;
}
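
/*
 * A worked example of the translation loop above, with assumed numbers:
 * given a single map covering guest addresses [0x1000, 0x2fff] at
 * userspace address 0x7f0000001000, translating addr = 0x1800,
 * len = 0x1000 finds the map (map->start = 0x1000), computes
 * size = 0x1800 bytes remaining from addr to the end of the map, and
 * emits one iovec with iov_base = 0x7f0000001800 and iov_len = 0x1000.
 * A buffer straddling two maps would emit one iovec per map; a hole
 * stops the walk with -EFAULT, or with -EAGAIN plus an IOTLB miss
 * request when an IOTLB is in use.
 */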

/* Each buffer in the virtqueues is actually a chain of descriptors.  This
 * function returns the next descriptor in the chain,
 * or -1U if we're at the end. */
static unsigned next_desc(struct vhost_virtqueue *vq, struct vring_desc *desc)
{
	unsigned int next;

	/* If this descriptor says it doesn't chain, we're done. */
	if (!(desc->flags & cpu_to_vhost16(vq, VRING_DESC_F_NEXT)))
		return -1U;

	/* Check they're not leading us off end of descriptors. */
	next = vhost16_to_cpu(vq, READ_ONCE(desc->next));
	return next;
}

static int get_indirect(struct vhost_virtqueue *vq,
			struct iovec iov[], unsigned int iov_size,
			unsigned int *out_num, unsigned int *in_num,
			struct vhost_log *log, unsigned int *log_num,
			struct vring_desc *indirect)
{
	struct vring_desc desc;
	unsigned int i = 0, count, found = 0;
	u32 len = vhost32_to_cpu(vq, indirect->len);
	struct iov_iter from;
	int ret, access;

	/* Sanity check */
	if (unlikely(len % sizeof desc)) {
		vq_err(vq, "Invalid length in indirect descriptor: "
		       "len 0x%llx not multiple of 0x%zx\n",
		       (unsigned long long)len,
		       sizeof desc);
		return -EINVAL;
	}

	ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect,
			     UIO_MAXIOV, VHOST_ACCESS_RO);
	if (unlikely(ret < 0)) {
		if (ret != -EAGAIN)
			vq_err(vq, "Translation failure %d in indirect.\n", ret);
		return ret;
	}
	iov_iter_init(&from, ITER_SOURCE, vq->indirect, ret, len);
	count = len / sizeof desc;
	/* Buffers are chained via a 16 bit next field, so
	 * we can have at most 2^16 of these. */
	if (unlikely(count > USHRT_MAX + 1)) {
		vq_err(vq, "Indirect buffer length too big: %d\n",
		       indirect->len);
		return -E2BIG;
	}

	do {
		unsigned iov_count = *in_num + *out_num;
		if (unlikely(++found > count)) {
			vq_err(vq, "Loop detected: last one at %u "
			       "indirect size %u\n",
			       i, count);
			return -EINVAL;
		}
		if (unlikely(!copy_from_iter_full(&desc, sizeof(desc), &from))) {
			vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n",
			       i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
			return -EINVAL;
		}
		if (unlikely(desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT))) {
			vq_err(vq, "Nested indirect descriptor: idx %d, %zx\n",
			       i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc);
			return -EINVAL;
		}

		if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
			access = VHOST_ACCESS_WO;
		else
			access = VHOST_ACCESS_RO;

		ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
				     vhost32_to_cpu(vq, desc.len), iov + iov_count,
				     iov_size - iov_count, access);
		if (unlikely(ret < 0)) {
			if (ret != -EAGAIN)
				vq_err(vq, "Translation failure %d indirect idx %d\n",
				       ret, i);
			return ret;
		}
		/* If this is an input descriptor, increment that count. */
		if (access == VHOST_ACCESS_WO) {
			*in_num += ret;
			if (unlikely(log && ret)) {
				log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
				log[*log_num].len = vhost32_to_cpu(vq, desc.len);
				++*log_num;
			}
		} else {
			/* If it's an output descriptor, they're all supposed
			 * to come before any input descriptors. */
			if (unlikely(*in_num)) {
				vq_err(vq, "Indirect descriptor "
				       "has out after in: idx %d\n", i);
				return -EINVAL;
			}
			*out_num += ret;
		}
	} while ((i = next_desc(vq, &desc)) != -1);
	return 0;
}

/* This looks in the virtqueue for the first available buffer, and converts
 * it to an iovec for convenient access.  Since descriptors consist of some
 * number of output then some number of input descriptors, it's actually two
 * iovecs, but we pack them into one and note how many of each there were.
 *
 * This function returns the descriptor number found, or vq->num (which is
 * never a valid descriptor number) if none was found.  A negative code is
 * returned on error. */
int vhost_get_vq_desc(struct vhost_virtqueue *vq,
		      struct iovec iov[], unsigned int iov_size,
		      unsigned int *out_num, unsigned int *in_num,
		      struct vhost_log *log, unsigned int *log_num)
{
	struct vring_desc desc;
	unsigned int i, head, found = 0;
	u16 last_avail_idx;
	__virtio16 avail_idx;
	__virtio16 ring_head;
	int ret, access;

	/* Check it isn't doing very strange things with descriptor numbers. */
	last_avail_idx = vq->last_avail_idx;

	if (vq->avail_idx == vq->last_avail_idx) {
		if (unlikely(vhost_get_avail_idx(vq, &avail_idx))) {
			vq_err(vq, "Failed to access avail idx at %p\n",
			       &vq->avail->idx);
			return -EFAULT;
		}
		vq->avail_idx = vhost16_to_cpu(vq, avail_idx);

		if (unlikely((u16)(vq->avail_idx - last_avail_idx) > vq->num)) {
			vq_err(vq, "Guest moved avail index from %u to %u",
			       last_avail_idx, vq->avail_idx);
			return -EFAULT;
		}

		/* If there's nothing new since last we looked, return
		 * invalid.
		 */
		if (vq->avail_idx == last_avail_idx)
			return vq->num;

		/* Only get avail ring entries after they have been
		 * exposed by guest.
		 */
		smp_rmb();
	}

	/* Grab the next descriptor number they're advertising, and increment
	 * the index we've seen. */
	if (unlikely(vhost_get_avail_head(vq, &ring_head, last_avail_idx))) {
		vq_err(vq, "Failed to read head: idx %d address %p\n",
		       last_avail_idx,
		       &vq->avail->ring[last_avail_idx % vq->num]);
		return -EFAULT;
	}

	head = vhost16_to_cpu(vq, ring_head);

	/* If their number is silly, that's an error. */
	if (unlikely(head >= vq->num)) {
		vq_err(vq, "Guest says index %u > %u is available",
		       head, vq->num);
		return -EINVAL;
	}

	/* When we start there are neither input nor output descriptors. */
	*out_num = *in_num = 0;
	if (unlikely(log))
		*log_num = 0;

	i = head;
	do {
		unsigned iov_count = *in_num + *out_num;
		if (unlikely(i >= vq->num)) {
			vq_err(vq, "Desc index is %u > %u, head = %u",
			       i, vq->num, head);
			return -EINVAL;
		}
		if (unlikely(++found > vq->num)) {
			vq_err(vq, "Loop detected: last one at %u "
			       "vq size %u head %u\n",
			       i, vq->num, head);
			return -EINVAL;
		}
		ret = vhost_get_desc(vq, &desc, i);
		if (unlikely(ret)) {
			vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
			       i, vq->desc + i);
			return -EFAULT;
		}
		if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT)) {
			ret = get_indirect(vq, iov, iov_size,
					   out_num, in_num,
					   log, log_num, &desc);
			if (unlikely(ret < 0)) {
				if (ret != -EAGAIN)
					vq_err(vq, "Failure detected "
					       "in indirect descriptor at idx %d\n", i);
				return ret;
			}
			continue;
		}

		if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE))
			access = VHOST_ACCESS_WO;
		else
			access = VHOST_ACCESS_RO;
		ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr),
				     vhost32_to_cpu(vq, desc.len), iov + iov_count,
				     iov_size - iov_count, access);
		if (unlikely(ret < 0)) {
			if (ret != -EAGAIN)
				vq_err(vq, "Translation failure %d descriptor idx %d\n",
				       ret, i);
			return ret;
		}
		if (access == VHOST_ACCESS_WO) {
			/* If this is an input descriptor,
			 * increment that count. */
			*in_num += ret;
			if (unlikely(log && ret)) {
				log[*log_num].addr = vhost64_to_cpu(vq, desc.addr);
				log[*log_num].len = vhost32_to_cpu(vq, desc.len);
				++*log_num;
			}
		} else {
			/* If it's an output descriptor, they're all supposed
			 * to come before any input descriptors. */
			if (unlikely(*in_num)) {
				vq_err(vq, "Descriptor has out after in: "
				       "idx %d\n", i);
				return -EINVAL;
			}
			*out_num += ret;
		}
	} while ((i = next_desc(vq, &desc)) != -1);

	/* On success, increment avail index. */
	vq->last_avail_idx++;

	/* Assume notifications from guest are disabled at this point,
	 * if they aren't we would need to update avail_event index. */
	BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY));
	return head;
}
EXPORT_SYMBOL_GPL(vhost_get_vq_desc);
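
/*
 * Illustrative only: a hedged sketch of how a vhost device's kick
 * handler typically consumes buffers with the helpers above (compare
 * the handle_tx()/handle_rx() loops in drivers/vhost/net.c).
 * process_buf() and the surrounding declarations are assumptions for
 * the example.
 *
 *	vhost_disable_notify(dev, vq);
 *	for (;;) {
 *		head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
 *					 &out, &in, NULL, NULL);
 *		if (head < 0)
 *			break;			// error already logged
 *		if (head == vq->num) {		// ring drained
 *			if (unlikely(vhost_enable_notify(dev, vq))) {
 *				vhost_disable_notify(dev, vq);
 *				continue;	// buffers slipped in, re-poll
 *			}
 *			break;
 *		}
 *		len = process_buf(vq->iov, out, in);
 *		vhost_add_used_and_signal(dev, vq, head, len);
 *	}
 */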

/* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */
void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)
{
	vq->last_avail_idx -= n;
}
EXPORT_SYMBOL_GPL(vhost_discard_vq_desc);

/* After we've used one of their buffers, we tell them about it.  We'll then
 * want to notify the guest, using eventfd. */
int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
{
	struct vring_used_elem heads = {
		cpu_to_vhost32(vq, head),
		cpu_to_vhost32(vq, len)
	};

	return vhost_add_used_n(vq, &heads, 1);
}
EXPORT_SYMBOL_GPL(vhost_add_used);

static int __vhost_add_used_n(struct vhost_virtqueue *vq,
			      struct vring_used_elem *heads,
			      unsigned count)
{
	vring_used_elem_t __user *used;
	u16 old, new;
	int start;

	start = vq->last_used_idx & (vq->num - 1);
	used = vq->used->ring + start;
	if (vhost_put_used(vq, heads, start, count)) {
		vq_err(vq, "Failed to write used");
		return -EFAULT;
	}
	if (unlikely(vq->log_used)) {
		/* Make sure data is seen before log. */
		smp_wmb();
		/* Log used ring entry write. */
		log_used(vq, ((void __user *)used - (void __user *)vq->used),
			 count * sizeof *used);
	}
	old = vq->last_used_idx;
	new = (vq->last_used_idx += count);
	/* If the driver never bothers to signal in a very long while,
	 * used index might wrap around. If that happens, invalidate
	 * signalled_used index we stored. TODO: make sure driver
	 * signals at least once in 2^16 and remove this. */
	if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old)))
		vq->signalled_used_valid = false;
	return 0;
}

/* After we've used one of their buffers, we tell them about it.  We'll then
 * want to notify the guest, using eventfd. */
int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
		     unsigned count)
{
	int start, n, r;

	start = vq->last_used_idx & (vq->num - 1);
	n = vq->num - start;
	if (n < count) {
		r = __vhost_add_used_n(vq, heads, n);
		if (r < 0)
			return r;
		heads += n;
		count -= n;
	}
	r = __vhost_add_used_n(vq, heads, count);

	/* Make sure buffer is written before we update index. */
	smp_wmb();
	if (vhost_put_used_idx(vq)) {
		vq_err(vq, "Failed to increment used idx");
		return -EFAULT;
	}
	if (unlikely(vq->log_used)) {
		/* Make sure used idx is seen before log. */
		smp_wmb();
		/* Log used index update. */
		log_used(vq, offsetof(struct vring_used, idx),
			 sizeof vq->used->idx);
		if (vq->log_ctx)
			eventfd_signal(vq->log_ctx, 1);
	}
	return r;
}
EXPORT_SYMBOL_GPL(vhost_add_used_n);

static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
	__u16 old, new;
	__virtio16 event;
	bool v;
	/* Flush out used index updates. This is paired
	 * with the barrier that the Guest executes when enabling
	 * interrupts. */
	smp_mb();

	if (vhost_has_feature(vq, VIRTIO_F_NOTIFY_ON_EMPTY) &&
	    unlikely(vq->avail_idx == vq->last_avail_idx))
		return true;

	if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
		__virtio16 flags;
		if (vhost_get_avail_flags(vq, &flags)) {
			vq_err(vq, "Failed to get flags");
			return true;
		}
		return !(flags & cpu_to_vhost16(vq, VRING_AVAIL_F_NO_INTERRUPT));
	}
	old = vq->signalled_used;
	v = vq->signalled_used_valid;
	new = vq->signalled_used = vq->last_used_idx;
	vq->signalled_used_valid = true;

	if (unlikely(!v))
		return true;

	if (vhost_get_used_event(vq, &event)) {
		vq_err(vq, "Failed to get used event idx");
		return true;
	}
	return vring_need_event(vhost16_to_cpu(vq, event), new, old);
}
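
/*
 * A worked example of the event-idx test above: vring_need_event(event,
 * new, old) fires iff the guest's used_event value was just crossed,
 * i.e. (u16)(new - event - 1) < (u16)(new - old).  With old = 10,
 * new = 14 and used_event = 12, we get (u16)(14 - 12 - 1) = 1 <
 * (u16)(14 - 10) = 4, so the guest is signalled; with used_event = 20,
 * (u16)(14 - 20 - 1) wraps to 65529 and the interrupt is suppressed.
 * The unsigned 16-bit arithmetic keeps the comparison correct across
 * index wrap-around.
 */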

/* This actually signals the guest, using eventfd. */
void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
	/* Signal the Guest to tell them we used something up. */
	if (vq->call_ctx.ctx && vhost_notify(dev, vq))
		eventfd_signal(vq->call_ctx.ctx, 1);
}
EXPORT_SYMBOL_GPL(vhost_signal);

/* And here's the combo meal deal. Supersize me! */
void vhost_add_used_and_signal(struct vhost_dev *dev,
			       struct vhost_virtqueue *vq,
			       unsigned int head, int len)
{
	vhost_add_used(vq, head, len);
	vhost_signal(dev, vq);
}
EXPORT_SYMBOL_GPL(vhost_add_used_and_signal);

/* multi-buffer version of vhost_add_used_and_signal */
void vhost_add_used_and_signal_n(struct vhost_dev *dev,
				 struct vhost_virtqueue *vq,
				 struct vring_used_elem *heads, unsigned count)
{
	vhost_add_used_n(vq, heads, count);
	vhost_signal(dev, vq);
}
EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n);

/* return true if we're sure that available ring is empty */
bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
	__virtio16 avail_idx;
	int r;

	if (vq->avail_idx != vq->last_avail_idx)
		return false;

	r = vhost_get_avail_idx(vq, &avail_idx);
	if (unlikely(r))
		return false;

	vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
	if (vq->avail_idx != vq->last_avail_idx) {
		/* Since we have updated avail_idx, the following
		 * call to vhost_get_vq_desc() will read available
		 * ring entries. Make sure that read happens after
		 * the avail_idx read.
		 */
		smp_rmb();
		return false;
	}

	return true;
}
EXPORT_SYMBOL_GPL(vhost_vq_avail_empty);

/* OK, now we need to know about added descriptors. */
bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
	__virtio16 avail_idx;
	int r;

	if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
		return false;
	vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
	if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
		r = vhost_update_used_flags(vq);
		if (r) {
			vq_err(vq, "Failed to enable notification at %p: %d\n",
			       &vq->used->flags, r);
			return false;
		}
	} else {
		r = vhost_update_avail_event(vq);
		if (r) {
			vq_err(vq, "Failed to update avail event index at %p: %d\n",
			       vhost_avail_event(vq), r);
			return false;
		}
	}
	/* They could have slipped one in as we were doing that: make
	 * sure it's written, then check again. */
	smp_mb();
	r = vhost_get_avail_idx(vq, &avail_idx);
	if (r) {
		vq_err(vq, "Failed to check avail idx at %p: %d\n",
		       &vq->avail->idx, r);
		return false;
	}

	vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
	if (vq->avail_idx != vq->last_avail_idx) {
		/* Since we have updated avail_idx, the following
		 * call to vhost_get_vq_desc() will read available
		 * ring entries. Make sure that read happens after
		 * the avail_idx read.
		 */
		smp_rmb();
		return true;
	}

	return false;
}
EXPORT_SYMBOL_GPL(vhost_enable_notify);
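
/*
 * A hedged sketch of the race that the smp_mb()-then-recheck above
 * closes.  Without it, this interleaving would strand a buffer until
 * the next guest kick:
 *
 *	guest				vhost worker
 *	-----				------------
 *					sees ring empty
 *	adds buffer, avail_idx++
 *	reads used_flags: NO_NOTIFY
 *	skips the eventfd kick
 *					clears NO_NOTIFY, sleeps forever
 *
 * Re-reading avail_idx after the flag/event write, and returning true
 * so the caller re-polls (as in the loop sketched after
 * vhost_get_vq_desc), guarantees the update is seen on at least one
 * side.
 */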

/* We don't need to be notified again. */
void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
	int r;

	if (vq->used_flags & VRING_USED_F_NO_NOTIFY)
		return;
	vq->used_flags |= VRING_USED_F_NO_NOTIFY;
	if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
		r = vhost_update_used_flags(vq);
		if (r)
			vq_err(vq, "Failed to disable notification at %p: %d\n",
			       &vq->used->flags, r);
	}
}
EXPORT_SYMBOL_GPL(vhost_disable_notify);

/* Create a new message. */
struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
{
	/* Make sure all padding within the structure is initialized. */
	struct vhost_msg_node *node = kzalloc(sizeof(*node), GFP_KERNEL);
	if (!node)
		return NULL;

	node->vq = vq;
	node->msg.type = type;
	return node;
}
EXPORT_SYMBOL_GPL(vhost_new_msg);

void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head,
		       struct vhost_msg_node *node)
{
	spin_lock(&dev->iotlb_lock);
	list_add_tail(&node->node, head);
	spin_unlock(&dev->iotlb_lock);

	wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
}
EXPORT_SYMBOL_GPL(vhost_enqueue_msg);

struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
					 struct list_head *head)
{
	struct vhost_msg_node *node = NULL;

	spin_lock(&dev->iotlb_lock);
	if (!list_empty(head)) {
		node = list_first_entry(head, struct vhost_msg_node,
					node);
		list_del(&node->node);
	}
	spin_unlock(&dev->iotlb_lock);

	return node;
}
EXPORT_SYMBOL_GPL(vhost_dequeue_msg);
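
/*
 * Illustrative only: the message queue above backs the IOTLB userspace
 * protocol.  A hedged sketch of the daemon side, assuming the v2
 * message format was enabled via the backend features and that
 * resolve() and region_size are placeholders supplied by the caller:
 *
 *	struct vhost_msg_v2 msg;
 *
 *	read(vhost_fd, &msg, sizeof(msg));	// miss queued by vhost
 *	if (msg.iotlb.type == VHOST_IOTLB_MISS) {
 *		msg.iotlb.type	= VHOST_IOTLB_UPDATE;
 *		msg.iotlb.uaddr = resolve(msg.iotlb.iova);
 *		msg.iotlb.size	= region_size;
 *		msg.iotlb.perm	= VHOST_ACCESS_RW;
 *		write(vhost_fd, &msg, sizeof(msg));	// fills the IOTLB
 *	}
 */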

void vhost_set_backend_features(struct vhost_dev *dev, u64 features)
{
	struct vhost_virtqueue *vq;
	int i;

	mutex_lock(&dev->mutex);
	for (i = 0; i < dev->nvqs; ++i) {
		vq = dev->vqs[i];
		mutex_lock(&vq->mutex);
		vq->acked_backend_features = features;
		mutex_unlock(&vq->mutex);
	}
	mutex_unlock(&dev->mutex);
}
EXPORT_SYMBOL_GPL(vhost_set_backend_features);

static int __init vhost_init(void)
{
	return 0;
}

static void __exit vhost_exit(void)
{
}

module_init(vhost_init);
module_exit(vhost_exit);

MODULE_VERSION("0.0.1");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Michael S. Tsirkin");
MODULE_DESCRIPTION("Host kernel accelerator for virtio");