1 /* Copyright (C) 2009 Red Hat, Inc. 2 * Copyright (C) 2006 Rusty Russell IBM Corporation 3 * 4 * Author: Michael S. Tsirkin <mst@redhat.com> 5 * 6 * Inspiration, some code, and most witty comments come from 7 * Documentation/virtual/lguest/lguest.c, by Rusty Russell 8 * 9 * This work is licensed under the terms of the GNU GPL, version 2. 10 * 11 * Generic code for virtio server in host kernel. 12 */ 13 14 #include <linux/eventfd.h> 15 #include <linux/vhost.h> 16 #include <linux/uio.h> 17 #include <linux/mm.h> 18 #include <linux/mmu_context.h> 19 #include <linux/miscdevice.h> 20 #include <linux/mutex.h> 21 #include <linux/poll.h> 22 #include <linux/file.h> 23 #include <linux/highmem.h> 24 #include <linux/slab.h> 25 #include <linux/vmalloc.h> 26 #include <linux/kthread.h> 27 #include <linux/cgroup.h> 28 #include <linux/module.h> 29 #include <linux/sort.h> 30 #include <linux/sched/mm.h> 31 #include <linux/sched/signal.h> 32 #include <linux/interval_tree_generic.h> 33 34 #include "vhost.h" 35 36 static ushort max_mem_regions = 64; 37 module_param(max_mem_regions, ushort, 0444); 38 MODULE_PARM_DESC(max_mem_regions, 39 "Maximum number of memory regions in memory map. (default: 64)"); 40 static int max_iotlb_entries = 2048; 41 module_param(max_iotlb_entries, int, 0444); 42 MODULE_PARM_DESC(max_iotlb_entries, 43 "Maximum number of iotlb entries. (default: 2048)"); 44 45 enum { 46 VHOST_MEMORY_F_LOG = 0x1, 47 }; 48 49 #define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num]) 50 #define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num]) 51 52 INTERVAL_TREE_DEFINE(struct vhost_umem_node, 53 rb, __u64, __subtree_last, 54 START, LAST, static inline, vhost_umem_interval_tree); 55 56 #ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY 57 static void vhost_disable_cross_endian(struct vhost_virtqueue *vq) 58 { 59 vq->user_be = !virtio_legacy_is_little_endian(); 60 } 61 62 static void vhost_enable_cross_endian_big(struct vhost_virtqueue *vq) 63 { 64 vq->user_be = true; 65 } 66 67 static void vhost_enable_cross_endian_little(struct vhost_virtqueue *vq) 68 { 69 vq->user_be = false; 70 } 71 72 static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp) 73 { 74 struct vhost_vring_state s; 75 76 if (vq->private_data) 77 return -EBUSY; 78 79 if (copy_from_user(&s, argp, sizeof(s))) 80 return -EFAULT; 81 82 if (s.num != VHOST_VRING_LITTLE_ENDIAN && 83 s.num != VHOST_VRING_BIG_ENDIAN) 84 return -EINVAL; 85 86 if (s.num == VHOST_VRING_BIG_ENDIAN) 87 vhost_enable_cross_endian_big(vq); 88 else 89 vhost_enable_cross_endian_little(vq); 90 91 return 0; 92 } 93 94 static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx, 95 int __user *argp) 96 { 97 struct vhost_vring_state s = { 98 .index = idx, 99 .num = vq->user_be 100 }; 101 102 if (copy_to_user(argp, &s, sizeof(s))) 103 return -EFAULT; 104 105 return 0; 106 } 107 108 static void vhost_init_is_le(struct vhost_virtqueue *vq) 109 { 110 /* Note for legacy virtio: user_be is initialized at reset time 111 * according to the host endianness. If userspace does not set an 112 * explicit endianness, the default behavior is native endian, as 113 * expected by legacy virtio. 
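	 *
	 * Editor's sketch (not part of the original source): userspace that
	 * wants a big-endian legacy ring on a little-endian host issues,
	 * before attaching a backend (vhost_fd and vq_index are hypothetical
	 * names):
	 *
	 *	struct vhost_vring_state s = {
	 *		.index = vq_index,
	 *		.num   = VHOST_VRING_BIG_ENDIAN,
	 *	};
	 *	ioctl(vhost_fd, VHOST_SET_VRING_ENDIAN, &s);
	 *
	 * vhost_set_vring_endian() above rejects the request with -EBUSY once
	 * a backend is set, so the ordering matters.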
114 */ 115 vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1) || !vq->user_be; 116 } 117 #else 118 static void vhost_disable_cross_endian(struct vhost_virtqueue *vq) 119 { 120 } 121 122 static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp) 123 { 124 return -ENOIOCTLCMD; 125 } 126 127 static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx, 128 int __user *argp) 129 { 130 return -ENOIOCTLCMD; 131 } 132 133 static void vhost_init_is_le(struct vhost_virtqueue *vq) 134 { 135 vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1) 136 || virtio_legacy_is_little_endian(); 137 } 138 #endif /* CONFIG_VHOST_CROSS_ENDIAN_LEGACY */ 139 140 static void vhost_reset_is_le(struct vhost_virtqueue *vq) 141 { 142 vhost_init_is_le(vq); 143 } 144 145 struct vhost_flush_struct { 146 struct vhost_work work; 147 struct completion wait_event; 148 }; 149 150 static void vhost_flush_work(struct vhost_work *work) 151 { 152 struct vhost_flush_struct *s; 153 154 s = container_of(work, struct vhost_flush_struct, work); 155 complete(&s->wait_event); 156 } 157 158 static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh, 159 poll_table *pt) 160 { 161 struct vhost_poll *poll; 162 163 poll = container_of(pt, struct vhost_poll, table); 164 poll->wqh = wqh; 165 add_wait_queue(wqh, &poll->wait); 166 } 167 168 static int vhost_poll_wakeup(wait_queue_entry_t *wait, unsigned mode, int sync, 169 void *key) 170 { 171 struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait); 172 173 if (!(key_to_poll(key) & poll->mask)) 174 return 0; 175 176 vhost_poll_queue(poll); 177 return 0; 178 } 179 180 void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn) 181 { 182 clear_bit(VHOST_WORK_QUEUED, &work->flags); 183 work->fn = fn; 184 } 185 EXPORT_SYMBOL_GPL(vhost_work_init); 186 187 /* Init poll structure */ 188 void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn, 189 __poll_t mask, struct vhost_dev *dev) 190 { 191 init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup); 192 init_poll_funcptr(&poll->table, vhost_poll_func); 193 poll->mask = mask; 194 poll->dev = dev; 195 poll->wqh = NULL; 196 197 vhost_work_init(&poll->work, fn); 198 } 199 EXPORT_SYMBOL_GPL(vhost_poll_init); 200 201 /* Start polling a file. We add ourselves to file's wait queue. The caller must 202 * keep a reference to a file until after vhost_poll_stop is called. */ 203 int vhost_poll_start(struct vhost_poll *poll, struct file *file) 204 { 205 __poll_t mask; 206 int ret = 0; 207 208 if (poll->wqh) 209 return 0; 210 211 mask = file->f_op->poll(file, &poll->table); 212 if (mask) 213 vhost_poll_wakeup(&poll->wait, 0, 0, poll_to_key(mask)); 214 if (mask & EPOLLERR) { 215 if (poll->wqh) 216 remove_wait_queue(poll->wqh, &poll->wait); 217 ret = -EINVAL; 218 } 219 220 return ret; 221 } 222 EXPORT_SYMBOL_GPL(vhost_poll_start); 223 224 /* Stop polling a file. After this function returns, it becomes safe to drop the 225 * file reference. You must also flush afterwards. 
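 *
 * Illustrative sketch (editor's addition, mirroring vhost_dev_stop()
 * further down): a backend tearing down a kick handler is expected to do
 *
 *	vhost_poll_stop(&vq->poll);
 *	vhost_poll_flush(&vq->poll);
 *
 * so that a callback already queued on the worker cannot run against a
 * file whose reference has been dropped.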
*/ 226 void vhost_poll_stop(struct vhost_poll *poll) 227 { 228 if (poll->wqh) { 229 remove_wait_queue(poll->wqh, &poll->wait); 230 poll->wqh = NULL; 231 } 232 } 233 EXPORT_SYMBOL_GPL(vhost_poll_stop); 234 235 void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work) 236 { 237 struct vhost_flush_struct flush; 238 239 if (dev->worker) { 240 init_completion(&flush.wait_event); 241 vhost_work_init(&flush.work, vhost_flush_work); 242 243 vhost_work_queue(dev, &flush.work); 244 wait_for_completion(&flush.wait_event); 245 } 246 } 247 EXPORT_SYMBOL_GPL(vhost_work_flush); 248 249 /* Flush any work that has been scheduled. When calling this, don't hold any 250 * locks that are also used by the callback. */ 251 void vhost_poll_flush(struct vhost_poll *poll) 252 { 253 vhost_work_flush(poll->dev, &poll->work); 254 } 255 EXPORT_SYMBOL_GPL(vhost_poll_flush); 256 257 void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work) 258 { 259 if (!dev->worker) 260 return; 261 262 if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) { 263 /* We can only add the work to the list after we're 264 * sure it was not in the list. 265 * test_and_set_bit() implies a memory barrier. 266 */ 267 llist_add(&work->node, &dev->work_list); 268 wake_up_process(dev->worker); 269 } 270 } 271 EXPORT_SYMBOL_GPL(vhost_work_queue); 272 273 /* A lockless hint for busy polling code to exit the loop */ 274 bool vhost_has_work(struct vhost_dev *dev) 275 { 276 return !llist_empty(&dev->work_list); 277 } 278 EXPORT_SYMBOL_GPL(vhost_has_work); 279 280 void vhost_poll_queue(struct vhost_poll *poll) 281 { 282 vhost_work_queue(poll->dev, &poll->work); 283 } 284 EXPORT_SYMBOL_GPL(vhost_poll_queue); 285 286 static void __vhost_vq_meta_reset(struct vhost_virtqueue *vq) 287 { 288 int j; 289 290 for (j = 0; j < VHOST_NUM_ADDRS; j++) 291 vq->meta_iotlb[j] = NULL; 292 } 293 294 static void vhost_vq_meta_reset(struct vhost_dev *d) 295 { 296 int i; 297 298 for (i = 0; i < d->nvqs; ++i) 299 __vhost_vq_meta_reset(d->vqs[i]); 300 } 301 302 static void vhost_vq_reset(struct vhost_dev *dev, 303 struct vhost_virtqueue *vq) 304 { 305 vq->num = 1; 306 vq->desc = NULL; 307 vq->avail = NULL; 308 vq->used = NULL; 309 vq->last_avail_idx = 0; 310 vq->avail_idx = 0; 311 vq->last_used_idx = 0; 312 vq->signalled_used = 0; 313 vq->signalled_used_valid = false; 314 vq->used_flags = 0; 315 vq->log_used = false; 316 vq->log_addr = -1ull; 317 vq->private_data = NULL; 318 vq->acked_features = 0; 319 vq->log_base = NULL; 320 vq->error_ctx = NULL; 321 vq->kick = NULL; 322 vq->call_ctx = NULL; 323 vq->log_ctx = NULL; 324 vhost_reset_is_le(vq); 325 vhost_disable_cross_endian(vq); 326 vq->busyloop_timeout = 0; 327 vq->umem = NULL; 328 vq->iotlb = NULL; 329 __vhost_vq_meta_reset(vq); 330 } 331 332 static int vhost_worker(void *data) 333 { 334 struct vhost_dev *dev = data; 335 struct vhost_work *work, *work_next; 336 struct llist_node *node; 337 mm_segment_t oldfs = get_fs(); 338 339 set_fs(USER_DS); 340 use_mm(dev->mm); 341 342 for (;;) { 343 /* mb paired w/ kthread_stop */ 344 set_current_state(TASK_INTERRUPTIBLE); 345 346 if (kthread_should_stop()) { 347 __set_current_state(TASK_RUNNING); 348 break; 349 } 350 351 node = llist_del_all(&dev->work_list); 352 if (!node) 353 schedule(); 354 355 node = llist_reverse_order(node); 356 /* make sure flag is seen after deletion */ 357 smp_wmb(); 358 llist_for_each_entry_safe(work, work_next, node, node) { 359 clear_bit(VHOST_WORK_QUEUED, &work->flags); 360 __set_current_state(TASK_RUNNING); 361 work->fn(work); 362 
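			/* Editor's note (not in the original source): because
			 * VHOST_WORK_QUEUED was cleared before fn ran, fn may
			 * legitimately re-queue its own vhost_work with
			 * vhost_work_queue().  vhost_work_flush() builds on the
			 * FIFO order seen here: it queues a vhost_flush_struct
			 * whose fn is vhost_flush_work(), e.g.
			 *
			 *	vhost_work_queue(dev, &flush.work);
			 *	wait_for_completion(&flush.wait_event);
			 *
			 * and the completion only fires once every work item
			 * queued before it has already run at this point.
			 */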
if (need_resched()) 363 schedule(); 364 } 365 } 366 unuse_mm(dev->mm); 367 set_fs(oldfs); 368 return 0; 369 } 370 371 static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq) 372 { 373 kfree(vq->indirect); 374 vq->indirect = NULL; 375 kfree(vq->log); 376 vq->log = NULL; 377 kfree(vq->heads); 378 vq->heads = NULL; 379 } 380 381 /* Helper to allocate iovec buffers for all vqs. */ 382 static long vhost_dev_alloc_iovecs(struct vhost_dev *dev) 383 { 384 struct vhost_virtqueue *vq; 385 int i; 386 387 for (i = 0; i < dev->nvqs; ++i) { 388 vq = dev->vqs[i]; 389 vq->indirect = kmalloc(sizeof *vq->indirect * UIO_MAXIOV, 390 GFP_KERNEL); 391 vq->log = kmalloc(sizeof *vq->log * UIO_MAXIOV, GFP_KERNEL); 392 vq->heads = kmalloc(sizeof *vq->heads * UIO_MAXIOV, GFP_KERNEL); 393 if (!vq->indirect || !vq->log || !vq->heads) 394 goto err_nomem; 395 } 396 return 0; 397 398 err_nomem: 399 for (; i >= 0; --i) 400 vhost_vq_free_iovecs(dev->vqs[i]); 401 return -ENOMEM; 402 } 403 404 static void vhost_dev_free_iovecs(struct vhost_dev *dev) 405 { 406 int i; 407 408 for (i = 0; i < dev->nvqs; ++i) 409 vhost_vq_free_iovecs(dev->vqs[i]); 410 } 411 412 void vhost_dev_init(struct vhost_dev *dev, 413 struct vhost_virtqueue **vqs, int nvqs) 414 { 415 struct vhost_virtqueue *vq; 416 int i; 417 418 dev->vqs = vqs; 419 dev->nvqs = nvqs; 420 mutex_init(&dev->mutex); 421 dev->log_ctx = NULL; 422 dev->umem = NULL; 423 dev->iotlb = NULL; 424 dev->mm = NULL; 425 dev->worker = NULL; 426 init_llist_head(&dev->work_list); 427 init_waitqueue_head(&dev->wait); 428 INIT_LIST_HEAD(&dev->read_list); 429 INIT_LIST_HEAD(&dev->pending_list); 430 spin_lock_init(&dev->iotlb_lock); 431 432 433 for (i = 0; i < dev->nvqs; ++i) { 434 vq = dev->vqs[i]; 435 vq->log = NULL; 436 vq->indirect = NULL; 437 vq->heads = NULL; 438 vq->dev = dev; 439 mutex_init(&vq->mutex); 440 vhost_vq_reset(dev, vq); 441 if (vq->handle_kick) 442 vhost_poll_init(&vq->poll, vq->handle_kick, 443 EPOLLIN, dev); 444 } 445 } 446 EXPORT_SYMBOL_GPL(vhost_dev_init); 447 448 /* Caller should have device mutex */ 449 long vhost_dev_check_owner(struct vhost_dev *dev) 450 { 451 /* Are you the owner? If not, I don't think you mean to do that */ 452 return dev->mm == current->mm ? 0 : -EPERM; 453 } 454 EXPORT_SYMBOL_GPL(vhost_dev_check_owner); 455 456 struct vhost_attach_cgroups_struct { 457 struct vhost_work work; 458 struct task_struct *owner; 459 int ret; 460 }; 461 462 static void vhost_attach_cgroups_work(struct vhost_work *work) 463 { 464 struct vhost_attach_cgroups_struct *s; 465 466 s = container_of(work, struct vhost_attach_cgroups_struct, work); 467 s->ret = cgroup_attach_task_all(s->owner, current); 468 } 469 470 static int vhost_attach_cgroups(struct vhost_dev *dev) 471 { 472 struct vhost_attach_cgroups_struct attach; 473 474 attach.owner = current; 475 vhost_work_init(&attach.work, vhost_attach_cgroups_work); 476 vhost_work_queue(dev, &attach.work); 477 vhost_work_flush(dev, &attach.work); 478 return attach.ret; 479 } 480 481 /* Caller should have device mutex */ 482 bool vhost_dev_has_owner(struct vhost_dev *dev) 483 { 484 return dev->mm; 485 } 486 EXPORT_SYMBOL_GPL(vhost_dev_has_owner); 487 488 /* Caller should have device mutex */ 489 long vhost_dev_set_owner(struct vhost_dev *dev) 490 { 491 struct task_struct *worker; 492 int err; 493 494 /* Is there an owner already? 
*/ 495 if (vhost_dev_has_owner(dev)) { 496 err = -EBUSY; 497 goto err_mm; 498 } 499 500 /* No owner, become one */ 501 dev->mm = get_task_mm(current); 502 worker = kthread_create(vhost_worker, dev, "vhost-%d", current->pid); 503 if (IS_ERR(worker)) { 504 err = PTR_ERR(worker); 505 goto err_worker; 506 } 507 508 dev->worker = worker; 509 wake_up_process(worker); /* avoid contributing to loadavg */ 510 511 err = vhost_attach_cgroups(dev); 512 if (err) 513 goto err_cgroup; 514 515 err = vhost_dev_alloc_iovecs(dev); 516 if (err) 517 goto err_cgroup; 518 519 return 0; 520 err_cgroup: 521 kthread_stop(worker); 522 dev->worker = NULL; 523 err_worker: 524 if (dev->mm) 525 mmput(dev->mm); 526 dev->mm = NULL; 527 err_mm: 528 return err; 529 } 530 EXPORT_SYMBOL_GPL(vhost_dev_set_owner); 531 532 struct vhost_umem *vhost_dev_reset_owner_prepare(void) 533 { 534 return kvzalloc(sizeof(struct vhost_umem), GFP_KERNEL); 535 } 536 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare); 537 538 /* Caller should have device mutex */ 539 void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_umem *umem) 540 { 541 int i; 542 543 vhost_dev_cleanup(dev); 544 545 /* Restore memory to default empty mapping. */ 546 INIT_LIST_HEAD(&umem->umem_list); 547 dev->umem = umem; 548 /* We don't need VQ locks below since vhost_dev_cleanup makes sure 549 * VQs aren't running. 550 */ 551 for (i = 0; i < dev->nvqs; ++i) 552 dev->vqs[i]->umem = umem; 553 } 554 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner); 555 556 void vhost_dev_stop(struct vhost_dev *dev) 557 { 558 int i; 559 560 for (i = 0; i < dev->nvqs; ++i) { 561 if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick) { 562 vhost_poll_stop(&dev->vqs[i]->poll); 563 vhost_poll_flush(&dev->vqs[i]->poll); 564 } 565 } 566 } 567 EXPORT_SYMBOL_GPL(vhost_dev_stop); 568 569 static void vhost_umem_free(struct vhost_umem *umem, 570 struct vhost_umem_node *node) 571 { 572 vhost_umem_interval_tree_remove(node, &umem->umem_tree); 573 list_del(&node->link); 574 kfree(node); 575 umem->numem--; 576 } 577 578 static void vhost_umem_clean(struct vhost_umem *umem) 579 { 580 struct vhost_umem_node *node, *tmp; 581 582 if (!umem) 583 return; 584 585 list_for_each_entry_safe(node, tmp, &umem->umem_list, link) 586 vhost_umem_free(umem, node); 587 588 kvfree(umem); 589 } 590 591 static void vhost_clear_msg(struct vhost_dev *dev) 592 { 593 struct vhost_msg_node *node, *n; 594 595 spin_lock(&dev->iotlb_lock); 596 597 list_for_each_entry_safe(node, n, &dev->read_list, node) { 598 list_del(&node->node); 599 kfree(node); 600 } 601 602 list_for_each_entry_safe(node, n, &dev->pending_list, node) { 603 list_del(&node->node); 604 kfree(node); 605 } 606 607 spin_unlock(&dev->iotlb_lock); 608 } 609 610 void vhost_dev_cleanup(struct vhost_dev *dev) 611 { 612 int i; 613 614 for (i = 0; i < dev->nvqs; ++i) { 615 if (dev->vqs[i]->error_ctx) 616 eventfd_ctx_put(dev->vqs[i]->error_ctx); 617 if (dev->vqs[i]->kick) 618 fput(dev->vqs[i]->kick); 619 if (dev->vqs[i]->call_ctx) 620 eventfd_ctx_put(dev->vqs[i]->call_ctx); 621 vhost_vq_reset(dev, dev->vqs[i]); 622 } 623 vhost_dev_free_iovecs(dev); 624 if (dev->log_ctx) 625 eventfd_ctx_put(dev->log_ctx); 626 dev->log_ctx = NULL; 627 /* No one will access memory at this point */ 628 vhost_umem_clean(dev->umem); 629 dev->umem = NULL; 630 vhost_umem_clean(dev->iotlb); 631 dev->iotlb = NULL; 632 vhost_clear_msg(dev); 633 wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM); 634 WARN_ON(!llist_empty(&dev->work_list)); 635 if (dev->worker) { 636 kthread_stop(dev->worker); 637 
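		/* Editor's note (not in the original source): this join is the
		 * counterpart of the kthread_create()/wake_up_process() pair in
		 * vhost_dev_set_owner(); once kthread_stop() returns, the
		 * "vhost-<pid>" worker has exited, and vhost_work_queue()
		 * silently ignores callers after dev->worker is cleared below.
		 */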
dev->worker = NULL; 638 } 639 if (dev->mm) 640 mmput(dev->mm); 641 dev->mm = NULL; 642 } 643 EXPORT_SYMBOL_GPL(vhost_dev_cleanup); 644 645 static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz) 646 { 647 u64 a = addr / VHOST_PAGE_SIZE / 8; 648 649 /* Make sure 64 bit math will not overflow. */ 650 if (a > ULONG_MAX - (unsigned long)log_base || 651 a + (unsigned long)log_base > ULONG_MAX) 652 return 0; 653 654 return access_ok(VERIFY_WRITE, log_base + a, 655 (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8); 656 } 657 658 static bool vhost_overflow(u64 uaddr, u64 size) 659 { 660 /* Make sure 64 bit math will not overflow. */ 661 return uaddr > ULONG_MAX || size > ULONG_MAX || uaddr > ULONG_MAX - size; 662 } 663 664 /* Caller should have vq mutex and device mutex. */ 665 static int vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem, 666 int log_all) 667 { 668 struct vhost_umem_node *node; 669 670 if (!umem) 671 return 0; 672 673 list_for_each_entry(node, &umem->umem_list, link) { 674 unsigned long a = node->userspace_addr; 675 676 if (vhost_overflow(node->userspace_addr, node->size)) 677 return 0; 678 679 680 if (!access_ok(VERIFY_WRITE, (void __user *)a, 681 node->size)) 682 return 0; 683 else if (log_all && !log_access_ok(log_base, 684 node->start, 685 node->size)) 686 return 0; 687 } 688 return 1; 689 } 690 691 static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq, 692 u64 addr, unsigned int size, 693 int type) 694 { 695 const struct vhost_umem_node *node = vq->meta_iotlb[type]; 696 697 if (!node) 698 return NULL; 699 700 return (void *)(uintptr_t)(node->userspace_addr + addr - node->start); 701 } 702 703 /* Can we switch to this memory table? */ 704 /* Caller should have device mutex but not vq mutex */ 705 static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem, 706 int log_all) 707 { 708 int i; 709 710 for (i = 0; i < d->nvqs; ++i) { 711 int ok; 712 bool log; 713 714 mutex_lock(&d->vqs[i]->mutex); 715 log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL); 716 /* If ring is inactive, will check when it's enabled. */ 717 if (d->vqs[i]->private_data) 718 ok = vq_memory_access_ok(d->vqs[i]->log_base, 719 umem, log); 720 else 721 ok = 1; 722 mutex_unlock(&d->vqs[i]->mutex); 723 if (!ok) 724 return 0; 725 } 726 return 1; 727 } 728 729 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len, 730 struct iovec iov[], int iov_size, int access); 731 732 static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to, 733 const void *from, unsigned size) 734 { 735 int ret; 736 737 if (!vq->iotlb) 738 return __copy_to_user(to, from, size); 739 else { 740 /* This function should be called after iotlb 741 * prefetch, which means we're sure that all vq 742 * could be access through iotlb. So -EAGAIN should 743 * not happen in this case. 
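		 *
		 * Editor's sketch (not from the original file): the expected
		 * calling pattern in a backend is roughly
		 *
		 *	if (!vq_iotlb_prefetch(vq))
		 *		return;	// miss queued for userspace, retry later
		 *	...
		 *	vhost_copy_to_user(vq, to, from, size);
		 *
		 * which is why a translation failure here is treated as a hard
		 * error instead of being retried with -EAGAIN.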
		 */
		struct iov_iter t;
		void __user *uaddr = vhost_vq_meta_fetch(vq,
				     (u64)(uintptr_t)to, size,
				     VHOST_ADDR_DESC);

		if (uaddr)
			return __copy_to_user(uaddr, from, size);

		ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
				     ARRAY_SIZE(vq->iotlb_iov),
				     VHOST_ACCESS_WO);
		if (ret < 0)
			goto out;
		iov_iter_init(&t, WRITE, vq->iotlb_iov, ret, size);
		ret = copy_to_iter(from, size, &t);
		if (ret == size)
			ret = 0;
	}
out:
	return ret;
}

static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
				void __user *from, unsigned size)
{
	int ret;

	if (!vq->iotlb)
		return __copy_from_user(to, from, size);
	else {
		/* This function should be called after iotlb
		 * prefetch, which means we're sure that vq
		 * could be accessed through iotlb. So -EAGAIN should
		 * not happen in this case.
		 */
		void __user *uaddr = vhost_vq_meta_fetch(vq,
				     (u64)(uintptr_t)from, size,
				     VHOST_ADDR_DESC);
		struct iov_iter f;

		if (uaddr)
			return __copy_from_user(to, uaddr, size);

		ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov,
				     ARRAY_SIZE(vq->iotlb_iov),
				     VHOST_ACCESS_RO);
		if (ret < 0) {
			vq_err(vq, "IOTLB translation failure: uaddr "
			       "%p size 0x%llx\n", from,
			       (unsigned long long) size);
			goto out;
		}
		iov_iter_init(&f, READ, vq->iotlb_iov, ret, size);
		ret = copy_from_iter(to, size, &f);
		if (ret == size)
			ret = 0;
	}

out:
	return ret;
}

static void __user *__vhost_get_user_slow(struct vhost_virtqueue *vq,
					  void __user *addr, unsigned int size,
					  int type)
{
	int ret;

	ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
			     ARRAY_SIZE(vq->iotlb_iov),
			     VHOST_ACCESS_RO);
	if (ret < 0) {
		vq_err(vq, "IOTLB translation failure: uaddr "
		       "%p size 0x%llx\n", addr,
		       (unsigned long long) size);
		return NULL;
	}

	if (ret != 1 || vq->iotlb_iov[0].iov_len != size) {
		vq_err(vq, "Non atomic userspace memory access: uaddr "
		       "%p size 0x%llx\n", addr,
		       (unsigned long long) size);
		return NULL;
	}

	return vq->iotlb_iov[0].iov_base;
}

/* This function should be called after iotlb
 * prefetch, which means we're sure that vq
 * could be accessed through iotlb. So -EAGAIN should
 * not happen in this case.
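 *
 * Editor's note (not in the original source): vhost_vq_meta_fetch() is the
 * fast path here; it hits the per-ring meta_iotlb cache that
 * vhost_vq_meta_update() fills in during iotlb_access_ok() and
 * vq_iotlb_prefetch().  Only on a cache miss do we fall back to
 * __vhost_get_user_slow() and a full translate_desc() walk, and the result
 * must then cover the access in a single contiguous iovec.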
837 */ 838 static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq, 839 void *addr, unsigned int size, 840 int type) 841 { 842 void __user *uaddr = vhost_vq_meta_fetch(vq, 843 (u64)(uintptr_t)addr, size, type); 844 if (uaddr) 845 return uaddr; 846 847 return __vhost_get_user_slow(vq, addr, size, type); 848 } 849 850 #define vhost_put_user(vq, x, ptr) \ 851 ({ \ 852 int ret = -EFAULT; \ 853 if (!vq->iotlb) { \ 854 ret = __put_user(x, ptr); \ 855 } else { \ 856 __typeof__(ptr) to = \ 857 (__typeof__(ptr)) __vhost_get_user(vq, ptr, \ 858 sizeof(*ptr), VHOST_ADDR_USED); \ 859 if (to != NULL) \ 860 ret = __put_user(x, to); \ 861 else \ 862 ret = -EFAULT; \ 863 } \ 864 ret; \ 865 }) 866 867 #define vhost_get_user(vq, x, ptr, type) \ 868 ({ \ 869 int ret; \ 870 if (!vq->iotlb) { \ 871 ret = __get_user(x, ptr); \ 872 } else { \ 873 __typeof__(ptr) from = \ 874 (__typeof__(ptr)) __vhost_get_user(vq, ptr, \ 875 sizeof(*ptr), \ 876 type); \ 877 if (from != NULL) \ 878 ret = __get_user(x, from); \ 879 else \ 880 ret = -EFAULT; \ 881 } \ 882 ret; \ 883 }) 884 885 #define vhost_get_avail(vq, x, ptr) \ 886 vhost_get_user(vq, x, ptr, VHOST_ADDR_AVAIL) 887 888 #define vhost_get_used(vq, x, ptr) \ 889 vhost_get_user(vq, x, ptr, VHOST_ADDR_USED) 890 891 static void vhost_dev_lock_vqs(struct vhost_dev *d) 892 { 893 int i = 0; 894 for (i = 0; i < d->nvqs; ++i) 895 mutex_lock_nested(&d->vqs[i]->mutex, i); 896 } 897 898 static void vhost_dev_unlock_vqs(struct vhost_dev *d) 899 { 900 int i = 0; 901 for (i = 0; i < d->nvqs; ++i) 902 mutex_unlock(&d->vqs[i]->mutex); 903 } 904 905 static int vhost_new_umem_range(struct vhost_umem *umem, 906 u64 start, u64 size, u64 end, 907 u64 userspace_addr, int perm) 908 { 909 struct vhost_umem_node *tmp, *node = kmalloc(sizeof(*node), GFP_ATOMIC); 910 911 if (!node) 912 return -ENOMEM; 913 914 if (umem->numem == max_iotlb_entries) { 915 tmp = list_first_entry(&umem->umem_list, typeof(*tmp), link); 916 vhost_umem_free(umem, tmp); 917 } 918 919 node->start = start; 920 node->size = size; 921 node->last = end; 922 node->userspace_addr = userspace_addr; 923 node->perm = perm; 924 INIT_LIST_HEAD(&node->link); 925 list_add_tail(&node->link, &umem->umem_list); 926 vhost_umem_interval_tree_insert(node, &umem->umem_tree); 927 umem->numem++; 928 929 return 0; 930 } 931 932 static void vhost_del_umem_range(struct vhost_umem *umem, 933 u64 start, u64 end) 934 { 935 struct vhost_umem_node *node; 936 937 while ((node = vhost_umem_interval_tree_iter_first(&umem->umem_tree, 938 start, end))) 939 vhost_umem_free(umem, node); 940 } 941 942 static void vhost_iotlb_notify_vq(struct vhost_dev *d, 943 struct vhost_iotlb_msg *msg) 944 { 945 struct vhost_msg_node *node, *n; 946 947 spin_lock(&d->iotlb_lock); 948 949 list_for_each_entry_safe(node, n, &d->pending_list, node) { 950 struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb; 951 if (msg->iova <= vq_msg->iova && 952 msg->iova + msg->size - 1 > vq_msg->iova && 953 vq_msg->type == VHOST_IOTLB_MISS) { 954 vhost_poll_queue(&node->vq->poll); 955 list_del(&node->node); 956 kfree(node); 957 } 958 } 959 960 spin_unlock(&d->iotlb_lock); 961 } 962 963 static int umem_access_ok(u64 uaddr, u64 size, int access) 964 { 965 unsigned long a = uaddr; 966 967 /* Make sure 64 bit math will not overflow. 
*/ 968 if (vhost_overflow(uaddr, size)) 969 return -EFAULT; 970 971 if ((access & VHOST_ACCESS_RO) && 972 !access_ok(VERIFY_READ, (void __user *)a, size)) 973 return -EFAULT; 974 if ((access & VHOST_ACCESS_WO) && 975 !access_ok(VERIFY_WRITE, (void __user *)a, size)) 976 return -EFAULT; 977 return 0; 978 } 979 980 static int vhost_process_iotlb_msg(struct vhost_dev *dev, 981 struct vhost_iotlb_msg *msg) 982 { 983 int ret = 0; 984 985 vhost_dev_lock_vqs(dev); 986 switch (msg->type) { 987 case VHOST_IOTLB_UPDATE: 988 if (!dev->iotlb) { 989 ret = -EFAULT; 990 break; 991 } 992 if (umem_access_ok(msg->uaddr, msg->size, msg->perm)) { 993 ret = -EFAULT; 994 break; 995 } 996 vhost_vq_meta_reset(dev); 997 if (vhost_new_umem_range(dev->iotlb, msg->iova, msg->size, 998 msg->iova + msg->size - 1, 999 msg->uaddr, msg->perm)) { 1000 ret = -ENOMEM; 1001 break; 1002 } 1003 vhost_iotlb_notify_vq(dev, msg); 1004 break; 1005 case VHOST_IOTLB_INVALIDATE: 1006 if (!dev->iotlb) { 1007 ret = -EFAULT; 1008 break; 1009 } 1010 vhost_vq_meta_reset(dev); 1011 vhost_del_umem_range(dev->iotlb, msg->iova, 1012 msg->iova + msg->size - 1); 1013 break; 1014 default: 1015 ret = -EINVAL; 1016 break; 1017 } 1018 1019 vhost_dev_unlock_vqs(dev); 1020 return ret; 1021 } 1022 ssize_t vhost_chr_write_iter(struct vhost_dev *dev, 1023 struct iov_iter *from) 1024 { 1025 struct vhost_msg_node node; 1026 unsigned size = sizeof(struct vhost_msg); 1027 size_t ret; 1028 int err; 1029 1030 if (iov_iter_count(from) < size) 1031 return 0; 1032 ret = copy_from_iter(&node.msg, size, from); 1033 if (ret != size) 1034 goto done; 1035 1036 switch (node.msg.type) { 1037 case VHOST_IOTLB_MSG: 1038 err = vhost_process_iotlb_msg(dev, &node.msg.iotlb); 1039 if (err) 1040 ret = err; 1041 break; 1042 default: 1043 ret = -EINVAL; 1044 break; 1045 } 1046 1047 done: 1048 return ret; 1049 } 1050 EXPORT_SYMBOL(vhost_chr_write_iter); 1051 1052 __poll_t vhost_chr_poll(struct file *file, struct vhost_dev *dev, 1053 poll_table *wait) 1054 { 1055 __poll_t mask = 0; 1056 1057 poll_wait(file, &dev->wait, wait); 1058 1059 if (!list_empty(&dev->read_list)) 1060 mask |= EPOLLIN | EPOLLRDNORM; 1061 1062 return mask; 1063 } 1064 EXPORT_SYMBOL(vhost_chr_poll); 1065 1066 ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to, 1067 int noblock) 1068 { 1069 DEFINE_WAIT(wait); 1070 struct vhost_msg_node *node; 1071 ssize_t ret = 0; 1072 unsigned size = sizeof(struct vhost_msg); 1073 1074 if (iov_iter_count(to) < size) 1075 return 0; 1076 1077 while (1) { 1078 if (!noblock) 1079 prepare_to_wait(&dev->wait, &wait, 1080 TASK_INTERRUPTIBLE); 1081 1082 node = vhost_dequeue_msg(dev, &dev->read_list); 1083 if (node) 1084 break; 1085 if (noblock) { 1086 ret = -EAGAIN; 1087 break; 1088 } 1089 if (signal_pending(current)) { 1090 ret = -ERESTARTSYS; 1091 break; 1092 } 1093 if (!dev->iotlb) { 1094 ret = -EBADFD; 1095 break; 1096 } 1097 1098 schedule(); 1099 } 1100 1101 if (!noblock) 1102 finish_wait(&dev->wait, &wait); 1103 1104 if (node) { 1105 ret = copy_to_iter(&node->msg, size, to); 1106 1107 if (ret != size || node->msg.type != VHOST_IOTLB_MISS) { 1108 kfree(node); 1109 return ret; 1110 } 1111 1112 vhost_enqueue_msg(dev, &dev->pending_list, node); 1113 } 1114 1115 return ret; 1116 } 1117 EXPORT_SYMBOL_GPL(vhost_chr_read_iter); 1118 1119 static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access) 1120 { 1121 struct vhost_dev *dev = vq->dev; 1122 struct vhost_msg_node *node; 1123 struct vhost_iotlb_msg *msg; 1124 1125 node = vhost_new_msg(vq, 
VHOST_IOTLB_MISS); 1126 if (!node) 1127 return -ENOMEM; 1128 1129 msg = &node->msg.iotlb; 1130 msg->type = VHOST_IOTLB_MISS; 1131 msg->iova = iova; 1132 msg->perm = access; 1133 1134 vhost_enqueue_msg(dev, &dev->read_list, node); 1135 1136 return 0; 1137 } 1138 1139 static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, 1140 struct vring_desc __user *desc, 1141 struct vring_avail __user *avail, 1142 struct vring_used __user *used) 1143 1144 { 1145 size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; 1146 1147 return access_ok(VERIFY_READ, desc, num * sizeof *desc) && 1148 access_ok(VERIFY_READ, avail, 1149 sizeof *avail + num * sizeof *avail->ring + s) && 1150 access_ok(VERIFY_WRITE, used, 1151 sizeof *used + num * sizeof *used->ring + s); 1152 } 1153 1154 static void vhost_vq_meta_update(struct vhost_virtqueue *vq, 1155 const struct vhost_umem_node *node, 1156 int type) 1157 { 1158 int access = (type == VHOST_ADDR_USED) ? 1159 VHOST_ACCESS_WO : VHOST_ACCESS_RO; 1160 1161 if (likely(node->perm & access)) 1162 vq->meta_iotlb[type] = node; 1163 } 1164 1165 static int iotlb_access_ok(struct vhost_virtqueue *vq, 1166 int access, u64 addr, u64 len, int type) 1167 { 1168 const struct vhost_umem_node *node; 1169 struct vhost_umem *umem = vq->iotlb; 1170 u64 s = 0, size, orig_addr = addr, last = addr + len - 1; 1171 1172 if (vhost_vq_meta_fetch(vq, addr, len, type)) 1173 return true; 1174 1175 while (len > s) { 1176 node = vhost_umem_interval_tree_iter_first(&umem->umem_tree, 1177 addr, 1178 last); 1179 if (node == NULL || node->start > addr) { 1180 vhost_iotlb_miss(vq, addr, access); 1181 return false; 1182 } else if (!(node->perm & access)) { 1183 /* Report the possible access violation by 1184 * request another translation from userspace. 1185 */ 1186 return false; 1187 } 1188 1189 size = node->size - addr + node->start; 1190 1191 if (orig_addr == addr && size >= len) 1192 vhost_vq_meta_update(vq, node, type); 1193 1194 s += size; 1195 addr += size; 1196 } 1197 1198 return true; 1199 } 1200 1201 int vq_iotlb_prefetch(struct vhost_virtqueue *vq) 1202 { 1203 size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; 1204 unsigned int num = vq->num; 1205 1206 if (!vq->iotlb) 1207 return 1; 1208 1209 return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc, 1210 num * sizeof(*vq->desc), VHOST_ADDR_DESC) && 1211 iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail, 1212 sizeof *vq->avail + 1213 num * sizeof(*vq->avail->ring) + s, 1214 VHOST_ADDR_AVAIL) && 1215 iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used, 1216 sizeof *vq->used + 1217 num * sizeof(*vq->used->ring) + s, 1218 VHOST_ADDR_USED); 1219 } 1220 EXPORT_SYMBOL_GPL(vq_iotlb_prefetch); 1221 1222 /* Can we log writes? */ 1223 /* Caller should have device mutex but not vq mutex */ 1224 int vhost_log_access_ok(struct vhost_dev *dev) 1225 { 1226 return memory_access_ok(dev, dev->umem, 1); 1227 } 1228 EXPORT_SYMBOL_GPL(vhost_log_access_ok); 1229 1230 /* Verify access for write logging. */ 1231 /* Caller should have vq mutex and device mutex */ 1232 static int vq_log_access_ok(struct vhost_virtqueue *vq, 1233 void __user *log_base) 1234 { 1235 size_t s = vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; 1236 1237 return vq_memory_access_ok(log_base, vq->umem, 1238 vhost_has_feature(vq, VHOST_F_LOG_ALL)) && 1239 (!vq->log_used || log_access_ok(log_base, vq->log_addr, 1240 sizeof *vq->used + 1241 vq->num * sizeof *vq->used->ring + s)); 1242 } 1243 1244 /* Can we start vq? 
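 *
 * Editor's note (not in the original source): with an IOTLB attached the
 * real gate is userspace keeping translations fresh.  Assuming the device
 * wires its file_operations to vhost_chr_read_iter()/vhost_chr_write_iter()
 * above, a user daemon services misses roughly like this (vhost_fd and
 * lookup() are hypothetical):
 *
 *	struct vhost_msg msg;
 *	read(vhost_fd, &msg, sizeof(msg));	// blocks for VHOST_IOTLB_MISS
 *	msg.type = VHOST_IOTLB_MSG;
 *	msg.iotlb.type = VHOST_IOTLB_UPDATE;
 *	msg.iotlb.uaddr = lookup(msg.iotlb.iova, &msg.iotlb.size);
 *	msg.iotlb.perm = VHOST_ACCESS_RW;
 *	write(vhost_fd, &msg, sizeof(msg));	// vhost_process_iotlb_msg()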
*/ 1245 /* Caller should have vq mutex and device mutex */ 1246 int vhost_vq_access_ok(struct vhost_virtqueue *vq) 1247 { 1248 if (vq->iotlb) { 1249 /* When device IOTLB was used, the access validation 1250 * will be validated during prefetching. 1251 */ 1252 return 1; 1253 } 1254 return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) && 1255 vq_log_access_ok(vq, vq->log_base); 1256 } 1257 EXPORT_SYMBOL_GPL(vhost_vq_access_ok); 1258 1259 static struct vhost_umem *vhost_umem_alloc(void) 1260 { 1261 struct vhost_umem *umem = kvzalloc(sizeof(*umem), GFP_KERNEL); 1262 1263 if (!umem) 1264 return NULL; 1265 1266 umem->umem_tree = RB_ROOT_CACHED; 1267 umem->numem = 0; 1268 INIT_LIST_HEAD(&umem->umem_list); 1269 1270 return umem; 1271 } 1272 1273 static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) 1274 { 1275 struct vhost_memory mem, *newmem; 1276 struct vhost_memory_region *region; 1277 struct vhost_umem *newumem, *oldumem; 1278 unsigned long size = offsetof(struct vhost_memory, regions); 1279 int i; 1280 1281 if (copy_from_user(&mem, m, size)) 1282 return -EFAULT; 1283 if (mem.padding) 1284 return -EOPNOTSUPP; 1285 if (mem.nregions > max_mem_regions) 1286 return -E2BIG; 1287 newmem = kvzalloc(size + mem.nregions * sizeof(*m->regions), GFP_KERNEL); 1288 if (!newmem) 1289 return -ENOMEM; 1290 1291 memcpy(newmem, &mem, size); 1292 if (copy_from_user(newmem->regions, m->regions, 1293 mem.nregions * sizeof *m->regions)) { 1294 kvfree(newmem); 1295 return -EFAULT; 1296 } 1297 1298 newumem = vhost_umem_alloc(); 1299 if (!newumem) { 1300 kvfree(newmem); 1301 return -ENOMEM; 1302 } 1303 1304 for (region = newmem->regions; 1305 region < newmem->regions + mem.nregions; 1306 region++) { 1307 if (vhost_new_umem_range(newumem, 1308 region->guest_phys_addr, 1309 region->memory_size, 1310 region->guest_phys_addr + 1311 region->memory_size - 1, 1312 region->userspace_addr, 1313 VHOST_ACCESS_RW)) 1314 goto err; 1315 } 1316 1317 if (!memory_access_ok(d, newumem, 0)) 1318 goto err; 1319 1320 oldumem = d->umem; 1321 d->umem = newumem; 1322 1323 /* All memory accesses are done under some VQ mutex. */ 1324 for (i = 0; i < d->nvqs; ++i) { 1325 mutex_lock(&d->vqs[i]->mutex); 1326 d->vqs[i]->umem = newumem; 1327 mutex_unlock(&d->vqs[i]->mutex); 1328 } 1329 1330 kvfree(newmem); 1331 vhost_umem_clean(oldumem); 1332 return 0; 1333 1334 err: 1335 vhost_umem_clean(newumem); 1336 kvfree(newmem); 1337 return -EFAULT; 1338 } 1339 1340 long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp) 1341 { 1342 struct file *eventfp, *filep = NULL; 1343 bool pollstart = false, pollstop = false; 1344 struct eventfd_ctx *ctx = NULL; 1345 u32 __user *idxp = argp; 1346 struct vhost_virtqueue *vq; 1347 struct vhost_vring_state s; 1348 struct vhost_vring_file f; 1349 struct vhost_vring_addr a; 1350 u32 idx; 1351 long r; 1352 1353 r = get_user(idx, idxp); 1354 if (r < 0) 1355 return r; 1356 if (idx >= d->nvqs) 1357 return -ENOBUFS; 1358 1359 vq = d->vqs[idx]; 1360 1361 mutex_lock(&vq->mutex); 1362 1363 switch (ioctl) { 1364 case VHOST_SET_VRING_NUM: 1365 /* Resizing ring with an active backend? 1366 * You don't want to do that. */ 1367 if (vq->private_data) { 1368 r = -EBUSY; 1369 break; 1370 } 1371 if (copy_from_user(&s, argp, sizeof s)) { 1372 r = -EFAULT; 1373 break; 1374 } 1375 if (!s.num || s.num > 0xffff || (s.num & (s.num - 1))) { 1376 r = -EINVAL; 1377 break; 1378 } 1379 vq->num = s.num; 1380 break; 1381 case VHOST_SET_VRING_BASE: 1382 /* Moving base with an active backend? 
1383 * You don't want to do that. */ 1384 if (vq->private_data) { 1385 r = -EBUSY; 1386 break; 1387 } 1388 if (copy_from_user(&s, argp, sizeof s)) { 1389 r = -EFAULT; 1390 break; 1391 } 1392 if (s.num > 0xffff) { 1393 r = -EINVAL; 1394 break; 1395 } 1396 vq->last_avail_idx = s.num; 1397 /* Forget the cached index value. */ 1398 vq->avail_idx = vq->last_avail_idx; 1399 break; 1400 case VHOST_GET_VRING_BASE: 1401 s.index = idx; 1402 s.num = vq->last_avail_idx; 1403 if (copy_to_user(argp, &s, sizeof s)) 1404 r = -EFAULT; 1405 break; 1406 case VHOST_SET_VRING_ADDR: 1407 if (copy_from_user(&a, argp, sizeof a)) { 1408 r = -EFAULT; 1409 break; 1410 } 1411 if (a.flags & ~(0x1 << VHOST_VRING_F_LOG)) { 1412 r = -EOPNOTSUPP; 1413 break; 1414 } 1415 /* For 32bit, verify that the top 32bits of the user 1416 data are set to zero. */ 1417 if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr || 1418 (u64)(unsigned long)a.used_user_addr != a.used_user_addr || 1419 (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr) { 1420 r = -EFAULT; 1421 break; 1422 } 1423 1424 /* Make sure it's safe to cast pointers to vring types. */ 1425 BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE); 1426 BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE); 1427 if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) || 1428 (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) || 1429 (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1))) { 1430 r = -EINVAL; 1431 break; 1432 } 1433 1434 /* We only verify access here if backend is configured. 1435 * If it is not, we don't as size might not have been setup. 1436 * We will verify when backend is configured. */ 1437 if (vq->private_data) { 1438 if (!vq_access_ok(vq, vq->num, 1439 (void __user *)(unsigned long)a.desc_user_addr, 1440 (void __user *)(unsigned long)a.avail_user_addr, 1441 (void __user *)(unsigned long)a.used_user_addr)) { 1442 r = -EINVAL; 1443 break; 1444 } 1445 1446 /* Also validate log access for used ring if enabled. */ 1447 if ((a.flags & (0x1 << VHOST_VRING_F_LOG)) && 1448 !log_access_ok(vq->log_base, a.log_guest_addr, 1449 sizeof *vq->used + 1450 vq->num * sizeof *vq->used->ring)) { 1451 r = -EINVAL; 1452 break; 1453 } 1454 } 1455 1456 vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG)); 1457 vq->desc = (void __user *)(unsigned long)a.desc_user_addr; 1458 vq->avail = (void __user *)(unsigned long)a.avail_user_addr; 1459 vq->log_addr = a.log_guest_addr; 1460 vq->used = (void __user *)(unsigned long)a.used_user_addr; 1461 break; 1462 case VHOST_SET_VRING_KICK: 1463 if (copy_from_user(&f, argp, sizeof f)) { 1464 r = -EFAULT; 1465 break; 1466 } 1467 eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd); 1468 if (IS_ERR(eventfp)) { 1469 r = PTR_ERR(eventfp); 1470 break; 1471 } 1472 if (eventfp != vq->kick) { 1473 pollstop = (filep = vq->kick) != NULL; 1474 pollstart = (vq->kick = eventfp) != NULL; 1475 } else 1476 filep = eventfp; 1477 break; 1478 case VHOST_SET_VRING_CALL: 1479 if (copy_from_user(&f, argp, sizeof f)) { 1480 r = -EFAULT; 1481 break; 1482 } 1483 ctx = f.fd == -1 ? NULL : eventfd_ctx_fdget(f.fd); 1484 if (IS_ERR(ctx)) { 1485 r = PTR_ERR(ctx); 1486 break; 1487 } 1488 swap(ctx, vq->call_ctx); 1489 break; 1490 case VHOST_SET_VRING_ERR: 1491 if (copy_from_user(&f, argp, sizeof f)) { 1492 r = -EFAULT; 1493 break; 1494 } 1495 ctx = f.fd == -1 ? 
NULL : eventfd_ctx_fdget(f.fd); 1496 if (IS_ERR(ctx)) { 1497 r = PTR_ERR(ctx); 1498 break; 1499 } 1500 swap(ctx, vq->error_ctx); 1501 break; 1502 case VHOST_SET_VRING_ENDIAN: 1503 r = vhost_set_vring_endian(vq, argp); 1504 break; 1505 case VHOST_GET_VRING_ENDIAN: 1506 r = vhost_get_vring_endian(vq, idx, argp); 1507 break; 1508 case VHOST_SET_VRING_BUSYLOOP_TIMEOUT: 1509 if (copy_from_user(&s, argp, sizeof(s))) { 1510 r = -EFAULT; 1511 break; 1512 } 1513 vq->busyloop_timeout = s.num; 1514 break; 1515 case VHOST_GET_VRING_BUSYLOOP_TIMEOUT: 1516 s.index = idx; 1517 s.num = vq->busyloop_timeout; 1518 if (copy_to_user(argp, &s, sizeof(s))) 1519 r = -EFAULT; 1520 break; 1521 default: 1522 r = -ENOIOCTLCMD; 1523 } 1524 1525 if (pollstop && vq->handle_kick) 1526 vhost_poll_stop(&vq->poll); 1527 1528 if (!IS_ERR_OR_NULL(ctx)) 1529 eventfd_ctx_put(ctx); 1530 if (filep) 1531 fput(filep); 1532 1533 if (pollstart && vq->handle_kick) 1534 r = vhost_poll_start(&vq->poll, vq->kick); 1535 1536 mutex_unlock(&vq->mutex); 1537 1538 if (pollstop && vq->handle_kick) 1539 vhost_poll_flush(&vq->poll); 1540 return r; 1541 } 1542 EXPORT_SYMBOL_GPL(vhost_vring_ioctl); 1543 1544 int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled) 1545 { 1546 struct vhost_umem *niotlb, *oiotlb; 1547 int i; 1548 1549 niotlb = vhost_umem_alloc(); 1550 if (!niotlb) 1551 return -ENOMEM; 1552 1553 oiotlb = d->iotlb; 1554 d->iotlb = niotlb; 1555 1556 for (i = 0; i < d->nvqs; ++i) { 1557 mutex_lock(&d->vqs[i]->mutex); 1558 d->vqs[i]->iotlb = niotlb; 1559 mutex_unlock(&d->vqs[i]->mutex); 1560 } 1561 1562 vhost_umem_clean(oiotlb); 1563 1564 return 0; 1565 } 1566 EXPORT_SYMBOL_GPL(vhost_init_device_iotlb); 1567 1568 /* Caller must have device mutex */ 1569 long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp) 1570 { 1571 struct eventfd_ctx *ctx; 1572 u64 p; 1573 long r; 1574 int i, fd; 1575 1576 /* If you are not the owner, you can become one */ 1577 if (ioctl == VHOST_SET_OWNER) { 1578 r = vhost_dev_set_owner(d); 1579 goto done; 1580 } 1581 1582 /* You must be the owner to do anything else */ 1583 r = vhost_dev_check_owner(d); 1584 if (r) 1585 goto done; 1586 1587 switch (ioctl) { 1588 case VHOST_SET_MEM_TABLE: 1589 r = vhost_set_memory(d, argp); 1590 break; 1591 case VHOST_SET_LOG_BASE: 1592 if (copy_from_user(&p, argp, sizeof p)) { 1593 r = -EFAULT; 1594 break; 1595 } 1596 if ((u64)(unsigned long)p != p) { 1597 r = -EFAULT; 1598 break; 1599 } 1600 for (i = 0; i < d->nvqs; ++i) { 1601 struct vhost_virtqueue *vq; 1602 void __user *base = (void __user *)(unsigned long)p; 1603 vq = d->vqs[i]; 1604 mutex_lock(&vq->mutex); 1605 /* If ring is inactive, will check when it's enabled. */ 1606 if (vq->private_data && !vq_log_access_ok(vq, base)) 1607 r = -EFAULT; 1608 else 1609 vq->log_base = base; 1610 mutex_unlock(&vq->mutex); 1611 } 1612 break; 1613 case VHOST_SET_LOG_FD: 1614 r = get_user(fd, (int __user *)argp); 1615 if (r < 0) 1616 break; 1617 ctx = fd == -1 ? NULL : eventfd_ctx_fdget(fd); 1618 if (IS_ERR(ctx)) { 1619 r = PTR_ERR(ctx); 1620 break; 1621 } 1622 swap(ctx, d->log_ctx); 1623 for (i = 0; i < d->nvqs; ++i) { 1624 mutex_lock(&d->vqs[i]->mutex); 1625 d->vqs[i]->log_ctx = d->log_ctx; 1626 mutex_unlock(&d->vqs[i]->mutex); 1627 } 1628 if (ctx) 1629 eventfd_ctx_put(ctx); 1630 break; 1631 default: 1632 r = -ENOIOCTLCMD; 1633 break; 1634 } 1635 done: 1636 return r; 1637 } 1638 EXPORT_SYMBOL_GPL(vhost_dev_ioctl); 1639 1640 /* TODO: This is really inefficient. 
We need something like get_user() 1641 * (instruction directly accesses the data, with an exception table entry 1642 * returning -EFAULT). See Documentation/x86/exception-tables.txt. 1643 */ 1644 static int set_bit_to_user(int nr, void __user *addr) 1645 { 1646 unsigned long log = (unsigned long)addr; 1647 struct page *page; 1648 void *base; 1649 int bit = nr + (log % PAGE_SIZE) * 8; 1650 int r; 1651 1652 r = get_user_pages_fast(log, 1, 1, &page); 1653 if (r < 0) 1654 return r; 1655 BUG_ON(r != 1); 1656 base = kmap_atomic(page); 1657 set_bit(bit, base); 1658 kunmap_atomic(base); 1659 set_page_dirty_lock(page); 1660 put_page(page); 1661 return 0; 1662 } 1663 1664 static int log_write(void __user *log_base, 1665 u64 write_address, u64 write_length) 1666 { 1667 u64 write_page = write_address / VHOST_PAGE_SIZE; 1668 int r; 1669 1670 if (!write_length) 1671 return 0; 1672 write_length += write_address % VHOST_PAGE_SIZE; 1673 for (;;) { 1674 u64 base = (u64)(unsigned long)log_base; 1675 u64 log = base + write_page / 8; 1676 int bit = write_page % 8; 1677 if ((u64)(unsigned long)log != log) 1678 return -EFAULT; 1679 r = set_bit_to_user(bit, (void __user *)(unsigned long)log); 1680 if (r < 0) 1681 return r; 1682 if (write_length <= VHOST_PAGE_SIZE) 1683 break; 1684 write_length -= VHOST_PAGE_SIZE; 1685 write_page += 1; 1686 } 1687 return r; 1688 } 1689 1690 int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log, 1691 unsigned int log_num, u64 len) 1692 { 1693 int i, r; 1694 1695 /* Make sure data written is seen before log. */ 1696 smp_wmb(); 1697 for (i = 0; i < log_num; ++i) { 1698 u64 l = min(log[i].len, len); 1699 r = log_write(vq->log_base, log[i].addr, l); 1700 if (r < 0) 1701 return r; 1702 len -= l; 1703 if (!len) { 1704 if (vq->log_ctx) 1705 eventfd_signal(vq->log_ctx, 1); 1706 return 0; 1707 } 1708 } 1709 /* Length written exceeds what we have stored. This is a bug. */ 1710 BUG(); 1711 return 0; 1712 } 1713 EXPORT_SYMBOL_GPL(vhost_log_write); 1714 1715 static int vhost_update_used_flags(struct vhost_virtqueue *vq) 1716 { 1717 void __user *used; 1718 if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags), 1719 &vq->used->flags) < 0) 1720 return -EFAULT; 1721 if (unlikely(vq->log_used)) { 1722 /* Make sure the flag is seen before log. */ 1723 smp_wmb(); 1724 /* Log used flag write. */ 1725 used = &vq->used->flags; 1726 log_write(vq->log_base, vq->log_addr + 1727 (used - (void __user *)vq->used), 1728 sizeof vq->used->flags); 1729 if (vq->log_ctx) 1730 eventfd_signal(vq->log_ctx, 1); 1731 } 1732 return 0; 1733 } 1734 1735 static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event) 1736 { 1737 if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx), 1738 vhost_avail_event(vq))) 1739 return -EFAULT; 1740 if (unlikely(vq->log_used)) { 1741 void __user *used; 1742 /* Make sure the event is seen before log. 
*/ 1743 smp_wmb(); 1744 /* Log avail event write */ 1745 used = vhost_avail_event(vq); 1746 log_write(vq->log_base, vq->log_addr + 1747 (used - (void __user *)vq->used), 1748 sizeof *vhost_avail_event(vq)); 1749 if (vq->log_ctx) 1750 eventfd_signal(vq->log_ctx, 1); 1751 } 1752 return 0; 1753 } 1754 1755 int vhost_vq_init_access(struct vhost_virtqueue *vq) 1756 { 1757 __virtio16 last_used_idx; 1758 int r; 1759 bool is_le = vq->is_le; 1760 1761 if (!vq->private_data) 1762 return 0; 1763 1764 vhost_init_is_le(vq); 1765 1766 r = vhost_update_used_flags(vq); 1767 if (r) 1768 goto err; 1769 vq->signalled_used_valid = false; 1770 if (!vq->iotlb && 1771 !access_ok(VERIFY_READ, &vq->used->idx, sizeof vq->used->idx)) { 1772 r = -EFAULT; 1773 goto err; 1774 } 1775 r = vhost_get_used(vq, last_used_idx, &vq->used->idx); 1776 if (r) { 1777 vq_err(vq, "Can't access used idx at %p\n", 1778 &vq->used->idx); 1779 goto err; 1780 } 1781 vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx); 1782 return 0; 1783 1784 err: 1785 vq->is_le = is_le; 1786 return r; 1787 } 1788 EXPORT_SYMBOL_GPL(vhost_vq_init_access); 1789 1790 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len, 1791 struct iovec iov[], int iov_size, int access) 1792 { 1793 const struct vhost_umem_node *node; 1794 struct vhost_dev *dev = vq->dev; 1795 struct vhost_umem *umem = dev->iotlb ? dev->iotlb : dev->umem; 1796 struct iovec *_iov; 1797 u64 s = 0; 1798 int ret = 0; 1799 1800 while ((u64)len > s) { 1801 u64 size; 1802 if (unlikely(ret >= iov_size)) { 1803 ret = -ENOBUFS; 1804 break; 1805 } 1806 1807 node = vhost_umem_interval_tree_iter_first(&umem->umem_tree, 1808 addr, addr + len - 1); 1809 if (node == NULL || node->start > addr) { 1810 if (umem != dev->iotlb) { 1811 ret = -EFAULT; 1812 break; 1813 } 1814 ret = -EAGAIN; 1815 break; 1816 } else if (!(node->perm & access)) { 1817 ret = -EPERM; 1818 break; 1819 } 1820 1821 _iov = iov + ret; 1822 size = node->size - addr + node->start; 1823 _iov->iov_len = min((u64)len - s, size); 1824 _iov->iov_base = (void __user *)(unsigned long) 1825 (node->userspace_addr + addr - node->start); 1826 s += size; 1827 addr += size; 1828 ++ret; 1829 } 1830 1831 if (ret == -EAGAIN) 1832 vhost_iotlb_miss(vq, addr, access); 1833 return ret; 1834 } 1835 1836 /* Each buffer in the virtqueues is actually a chain of descriptors. This 1837 * function returns the next descriptor in the chain, 1838 * or -1U if we're at the end. */ 1839 static unsigned next_desc(struct vhost_virtqueue *vq, struct vring_desc *desc) 1840 { 1841 unsigned int next; 1842 1843 /* If this descriptor says it doesn't chain, we're done. */ 1844 if (!(desc->flags & cpu_to_vhost16(vq, VRING_DESC_F_NEXT))) 1845 return -1U; 1846 1847 /* Check they're not leading us off end of descriptors. 
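	 *
	 * (Editor's note: the range check itself lives in the callers;
	 * vhost_get_vq_desc() rejects i >= vq->num and get_indirect() bounds
	 * the walk with its found > count test.  This helper only reads the
	 * next index once, via READ_ONCE(), so the value cannot change
	 * between check and use.)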
*/ 1848 next = vhost16_to_cpu(vq, READ_ONCE(desc->next)); 1849 return next; 1850 } 1851 1852 static int get_indirect(struct vhost_virtqueue *vq, 1853 struct iovec iov[], unsigned int iov_size, 1854 unsigned int *out_num, unsigned int *in_num, 1855 struct vhost_log *log, unsigned int *log_num, 1856 struct vring_desc *indirect) 1857 { 1858 struct vring_desc desc; 1859 unsigned int i = 0, count, found = 0; 1860 u32 len = vhost32_to_cpu(vq, indirect->len); 1861 struct iov_iter from; 1862 int ret, access; 1863 1864 /* Sanity check */ 1865 if (unlikely(len % sizeof desc)) { 1866 vq_err(vq, "Invalid length in indirect descriptor: " 1867 "len 0x%llx not multiple of 0x%zx\n", 1868 (unsigned long long)len, 1869 sizeof desc); 1870 return -EINVAL; 1871 } 1872 1873 ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect, 1874 UIO_MAXIOV, VHOST_ACCESS_RO); 1875 if (unlikely(ret < 0)) { 1876 if (ret != -EAGAIN) 1877 vq_err(vq, "Translation failure %d in indirect.\n", ret); 1878 return ret; 1879 } 1880 iov_iter_init(&from, READ, vq->indirect, ret, len); 1881 1882 /* We will use the result as an address to read from, so most 1883 * architectures only need a compiler barrier here. */ 1884 read_barrier_depends(); 1885 1886 count = len / sizeof desc; 1887 /* Buffers are chained via a 16 bit next field, so 1888 * we can have at most 2^16 of these. */ 1889 if (unlikely(count > USHRT_MAX + 1)) { 1890 vq_err(vq, "Indirect buffer length too big: %d\n", 1891 indirect->len); 1892 return -E2BIG; 1893 } 1894 1895 do { 1896 unsigned iov_count = *in_num + *out_num; 1897 if (unlikely(++found > count)) { 1898 vq_err(vq, "Loop detected: last one at %u " 1899 "indirect size %u\n", 1900 i, count); 1901 return -EINVAL; 1902 } 1903 if (unlikely(!copy_from_iter_full(&desc, sizeof(desc), &from))) { 1904 vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n", 1905 i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc); 1906 return -EINVAL; 1907 } 1908 if (unlikely(desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT))) { 1909 vq_err(vq, "Nested indirect descriptor: idx %d, %zx\n", 1910 i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc); 1911 return -EINVAL; 1912 } 1913 1914 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) 1915 access = VHOST_ACCESS_WO; 1916 else 1917 access = VHOST_ACCESS_RO; 1918 1919 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr), 1920 vhost32_to_cpu(vq, desc.len), iov + iov_count, 1921 iov_size - iov_count, access); 1922 if (unlikely(ret < 0)) { 1923 if (ret != -EAGAIN) 1924 vq_err(vq, "Translation failure %d indirect idx %d\n", 1925 ret, i); 1926 return ret; 1927 } 1928 /* If this is an input descriptor, increment that count. */ 1929 if (access == VHOST_ACCESS_WO) { 1930 *in_num += ret; 1931 if (unlikely(log)) { 1932 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr); 1933 log[*log_num].len = vhost32_to_cpu(vq, desc.len); 1934 ++*log_num; 1935 } 1936 } else { 1937 /* If it's an output descriptor, they're all supposed 1938 * to come before any input descriptors. */ 1939 if (unlikely(*in_num)) { 1940 vq_err(vq, "Indirect descriptor " 1941 "has out after in: idx %d\n", i); 1942 return -EINVAL; 1943 } 1944 *out_num += ret; 1945 } 1946 } while ((i = next_desc(vq, &desc)) != -1); 1947 return 0; 1948 } 1949 1950 /* This looks in the virtqueue and for the first available buffer, and converts 1951 * it to an iovec for convenient access. 
Since descriptors consist of some 1952 * number of output then some number of input descriptors, it's actually two 1953 * iovecs, but we pack them into one and note how many of each there were. 1954 * 1955 * This function returns the descriptor number found, or vq->num (which is 1956 * never a valid descriptor number) if none was found. A negative code is 1957 * returned on error. */ 1958 int vhost_get_vq_desc(struct vhost_virtqueue *vq, 1959 struct iovec iov[], unsigned int iov_size, 1960 unsigned int *out_num, unsigned int *in_num, 1961 struct vhost_log *log, unsigned int *log_num) 1962 { 1963 struct vring_desc desc; 1964 unsigned int i, head, found = 0; 1965 u16 last_avail_idx; 1966 __virtio16 avail_idx; 1967 __virtio16 ring_head; 1968 int ret, access; 1969 1970 /* Check it isn't doing very strange things with descriptor numbers. */ 1971 last_avail_idx = vq->last_avail_idx; 1972 1973 if (vq->avail_idx == vq->last_avail_idx) { 1974 if (unlikely(vhost_get_avail(vq, avail_idx, &vq->avail->idx))) { 1975 vq_err(vq, "Failed to access avail idx at %p\n", 1976 &vq->avail->idx); 1977 return -EFAULT; 1978 } 1979 vq->avail_idx = vhost16_to_cpu(vq, avail_idx); 1980 1981 if (unlikely((u16)(vq->avail_idx - last_avail_idx) > vq->num)) { 1982 vq_err(vq, "Guest moved used index from %u to %u", 1983 last_avail_idx, vq->avail_idx); 1984 return -EFAULT; 1985 } 1986 1987 /* If there's nothing new since last we looked, return 1988 * invalid. 1989 */ 1990 if (vq->avail_idx == last_avail_idx) 1991 return vq->num; 1992 1993 /* Only get avail ring entries after they have been 1994 * exposed by guest. 1995 */ 1996 smp_rmb(); 1997 } 1998 1999 /* Grab the next descriptor number they're advertising, and increment 2000 * the index we've seen. */ 2001 if (unlikely(vhost_get_avail(vq, ring_head, 2002 &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) { 2003 vq_err(vq, "Failed to read head: idx %d address %p\n", 2004 last_avail_idx, 2005 &vq->avail->ring[last_avail_idx % vq->num]); 2006 return -EFAULT; 2007 } 2008 2009 head = vhost16_to_cpu(vq, ring_head); 2010 2011 /* If their number is silly, that's an error. */ 2012 if (unlikely(head >= vq->num)) { 2013 vq_err(vq, "Guest says index %u > %u is available", 2014 head, vq->num); 2015 return -EINVAL; 2016 } 2017 2018 /* When we start there are none of either input nor output. 
*/ 2019 *out_num = *in_num = 0; 2020 if (unlikely(log)) 2021 *log_num = 0; 2022 2023 i = head; 2024 do { 2025 unsigned iov_count = *in_num + *out_num; 2026 if (unlikely(i >= vq->num)) { 2027 vq_err(vq, "Desc index is %u > %u, head = %u", 2028 i, vq->num, head); 2029 return -EINVAL; 2030 } 2031 if (unlikely(++found > vq->num)) { 2032 vq_err(vq, "Loop detected: last one at %u " 2033 "vq size %u head %u\n", 2034 i, vq->num, head); 2035 return -EINVAL; 2036 } 2037 ret = vhost_copy_from_user(vq, &desc, vq->desc + i, 2038 sizeof desc); 2039 if (unlikely(ret)) { 2040 vq_err(vq, "Failed to get descriptor: idx %d addr %p\n", 2041 i, vq->desc + i); 2042 return -EFAULT; 2043 } 2044 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT)) { 2045 ret = get_indirect(vq, iov, iov_size, 2046 out_num, in_num, 2047 log, log_num, &desc); 2048 if (unlikely(ret < 0)) { 2049 if (ret != -EAGAIN) 2050 vq_err(vq, "Failure detected " 2051 "in indirect descriptor at idx %d\n", i); 2052 return ret; 2053 } 2054 continue; 2055 } 2056 2057 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) 2058 access = VHOST_ACCESS_WO; 2059 else 2060 access = VHOST_ACCESS_RO; 2061 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr), 2062 vhost32_to_cpu(vq, desc.len), iov + iov_count, 2063 iov_size - iov_count, access); 2064 if (unlikely(ret < 0)) { 2065 if (ret != -EAGAIN) 2066 vq_err(vq, "Translation failure %d descriptor idx %d\n", 2067 ret, i); 2068 return ret; 2069 } 2070 if (access == VHOST_ACCESS_WO) { 2071 /* If this is an input descriptor, 2072 * increment that count. */ 2073 *in_num += ret; 2074 if (unlikely(log)) { 2075 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr); 2076 log[*log_num].len = vhost32_to_cpu(vq, desc.len); 2077 ++*log_num; 2078 } 2079 } else { 2080 /* If it's an output descriptor, they're all supposed 2081 * to come before any input descriptors. */ 2082 if (unlikely(*in_num)) { 2083 vq_err(vq, "Descriptor has out after in: " 2084 "idx %d\n", i); 2085 return -EINVAL; 2086 } 2087 *out_num += ret; 2088 } 2089 } while ((i = next_desc(vq, &desc)) != -1); 2090 2091 /* On success, increment avail index. */ 2092 vq->last_avail_idx++; 2093 2094 /* Assume notifications from guest are disabled at this point, 2095 * if they aren't we would need to update avail_event index. */ 2096 BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY)); 2097 return head; 2098 } 2099 EXPORT_SYMBOL_GPL(vhost_get_vq_desc); 2100 2101 /* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */ 2102 void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n) 2103 { 2104 vq->last_avail_idx -= n; 2105 } 2106 EXPORT_SYMBOL_GPL(vhost_discard_vq_desc); 2107 2108 /* After we've used one of their buffers, we tell them about it. We'll then 2109 * want to notify the guest, using eventfd. 
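 *
 * Illustrative sketch (editor's addition, not from the original file) of how
 * a backend's handle_kick work typically drives this API; consume() is a
 * hypothetical stand-in for the device-specific processing:
 *
 *	for (;;) {
 *		head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
 *					 &out, &in, NULL, NULL);
 *		if (unlikely(head < 0))
 *			break;			// error already logged
 *		if (head == vq->num) {		// ring drained
 *			if (unlikely(vhost_enable_notify(dev, vq))) {
 *				vhost_disable_notify(dev, vq);
 *				continue;	// buffers slipped in meanwhile
 *			}
 *			break;
 *		}
 *		len = consume(vq->iov, out, in);
 *		vhost_add_used_and_signal(dev, vq, head, len);
 *	}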
/* After we've used one of their buffers, we tell them about it. We'll then
 * want to notify the guest, using eventfd. */
int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
{
	struct vring_used_elem heads = {
		cpu_to_vhost32(vq, head),
		cpu_to_vhost32(vq, len)
	};

	return vhost_add_used_n(vq, &heads, 1);
}
EXPORT_SYMBOL_GPL(vhost_add_used);

static int __vhost_add_used_n(struct vhost_virtqueue *vq,
			      struct vring_used_elem *heads,
			      unsigned count)
{
	struct vring_used_elem __user *used;
	u16 old, new;
	int start;

	start = vq->last_used_idx & (vq->num - 1);
	used = vq->used->ring + start;
	if (count == 1) {
		if (vhost_put_user(vq, heads[0].id, &used->id)) {
			vq_err(vq, "Failed to write used id");
			return -EFAULT;
		}
		if (vhost_put_user(vq, heads[0].len, &used->len)) {
			vq_err(vq, "Failed to write used len");
			return -EFAULT;
		}
	} else if (vhost_copy_to_user(vq, used, heads, count * sizeof *used)) {
		vq_err(vq, "Failed to write used");
		return -EFAULT;
	}
	if (unlikely(vq->log_used)) {
		/* Make sure data is seen before log. */
		smp_wmb();
		/* Log used ring entry write. */
		log_write(vq->log_base,
			  vq->log_addr +
			   ((void __user *)used - (void __user *)vq->used),
			  count * sizeof *used);
	}
	old = vq->last_used_idx;
	new = (vq->last_used_idx += count);
	/* If the driver never bothers to signal in a very long while,
	 * used index might wrap around. If that happens, invalidate
	 * signalled_used index we stored. TODO: make sure driver
	 * signals at least once in 2^16 and remove this. */
	if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old)))
		vq->signalled_used_valid = false;
	return 0;
}

/* After we've used one of their buffers, we tell them about it. We'll then
 * want to notify the guest, using eventfd. */
int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
		     unsigned count)
{
	int start, n, r;

	start = vq->last_used_idx & (vq->num - 1);
	n = vq->num - start;
	if (n < count) {
		r = __vhost_add_used_n(vq, heads, n);
		if (r < 0)
			return r;
		heads += n;
		count -= n;
	}
	r = __vhost_add_used_n(vq, heads, count);

	/* Make sure buffer is written before we update index. */
	smp_wmb();
	if (vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
			   &vq->used->idx)) {
		vq_err(vq, "Failed to increment used idx");
		return -EFAULT;
	}
	if (unlikely(vq->log_used)) {
		/* Log used index update. */
		log_write(vq->log_base,
			  vq->log_addr + offsetof(struct vring_used, idx),
			  sizeof vq->used->idx);
		if (vq->log_ctx)
			eventfd_signal(vq->log_ctx, 1);
	}
	return r;
}
EXPORT_SYMBOL_GPL(vhost_add_used_n);
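/*
 * Illustrative sketch (not part of this file): batching completions.
 * A backend can stash several finished heads in a vring_used_elem array and
 * publish them with one used-index update (and at most one interrupt) via
 * vhost_add_used_and_signal_n(), defined below, instead of calling
 * vhost_add_used() per buffer. The array size here is arbitrary and the
 * bookkeeping is hypothetical.
 *
 *	struct vring_used_elem done[64];
 *	unsigned nheads = 0;
 *
 *	done[nheads].id = cpu_to_vhost32(vq, head);
 *	done[nheads].len = cpu_to_vhost32(vq, len);
 *	if (++nheads == ARRAY_SIZE(done)) {
 *		vhost_add_used_and_signal_n(vq->dev, vq, done, nheads);
 *		nheads = 0;
 *	}
 */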
static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
	__u16 old, new;
	__virtio16 event;
	bool v;
	/* Flush out used index updates. This is paired
	 * with the barrier that the Guest executes when enabling
	 * interrupts. */
	smp_mb();

	if (vhost_has_feature(vq, VIRTIO_F_NOTIFY_ON_EMPTY) &&
	    unlikely(vq->avail_idx == vq->last_avail_idx))
		return true;

	if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
		__virtio16 flags;
		if (vhost_get_avail(vq, flags, &vq->avail->flags)) {
			vq_err(vq, "Failed to get flags");
			return true;
		}
		return !(flags & cpu_to_vhost16(vq, VRING_AVAIL_F_NO_INTERRUPT));
	}
	old = vq->signalled_used;
	v = vq->signalled_used_valid;
	new = vq->signalled_used = vq->last_used_idx;
	vq->signalled_used_valid = true;

	if (unlikely(!v))
		return true;

	if (vhost_get_avail(vq, event, vhost_used_event(vq))) {
		vq_err(vq, "Failed to get used event idx");
		return true;
	}
	return vring_need_event(vhost16_to_cpu(vq, event), new, old);
}

/* This actually signals the guest, using eventfd. */
void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
	/* Signal the Guest to tell them we used something up. */
	if (vq->call_ctx && vhost_notify(dev, vq))
		eventfd_signal(vq->call_ctx, 1);
}
EXPORT_SYMBOL_GPL(vhost_signal);

/* And here's the combo meal deal. Supersize me! */
void vhost_add_used_and_signal(struct vhost_dev *dev,
			       struct vhost_virtqueue *vq,
			       unsigned int head, int len)
{
	vhost_add_used(vq, head, len);
	vhost_signal(dev, vq);
}
EXPORT_SYMBOL_GPL(vhost_add_used_and_signal);

/* multi-buffer version of vhost_add_used_and_signal */
void vhost_add_used_and_signal_n(struct vhost_dev *dev,
				 struct vhost_virtqueue *vq,
				 struct vring_used_elem *heads, unsigned count)
{
	vhost_add_used_n(vq, heads, count);
	vhost_signal(dev, vq);
}
EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n);

/* return true if we're sure that the available ring is empty */
bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
	__virtio16 avail_idx;
	int r;

	if (vq->avail_idx != vq->last_avail_idx)
		return false;

	r = vhost_get_avail(vq, avail_idx, &vq->avail->idx);
	if (unlikely(r))
		return false;
	vq->avail_idx = vhost16_to_cpu(vq, avail_idx);

	return vq->avail_idx == vq->last_avail_idx;
}
EXPORT_SYMBOL_GPL(vhost_vq_avail_empty);

/* OK, now we need to know about added descriptors. */
bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
	__virtio16 avail_idx;
	int r;

	if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
		return false;
	vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
	if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
		r = vhost_update_used_flags(vq);
		if (r) {
			vq_err(vq, "Failed to enable notification at %p: %d\n",
			       &vq->used->flags, r);
			return false;
		}
	} else {
		r = vhost_update_avail_event(vq, vq->avail_idx);
		if (r) {
			vq_err(vq, "Failed to update avail event index at %p: %d\n",
			       vhost_avail_event(vq), r);
			return false;
		}
	}
	/* They could have slipped one in as we were doing that: make
	 * sure it's written, then check again. */
	smp_mb();
	r = vhost_get_avail(vq, avail_idx, &vq->avail->idx);
	if (r) {
		vq_err(vq, "Failed to check avail idx at %p: %d\n",
		       &vq->avail->idx, r);
		return false;
	}

	return vhost16_to_cpu(vq, avail_idx) != vq->avail_idx;
}
EXPORT_SYMBOL_GPL(vhost_enable_notify);

/* We don't need to be notified again. */
void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
	int r;

	if (vq->used_flags & VRING_USED_F_NO_NOTIFY)
		return;
	vq->used_flags |= VRING_USED_F_NO_NOTIFY;
	if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
		r = vhost_update_used_flags(vq);
		if (r)
			vq_err(vq, "Failed to disable notification at %p: %d\n",
			       &vq->used->flags, r);
	}
}
EXPORT_SYMBOL_GPL(vhost_disable_notify);
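/*
 * Illustrative sketch (not part of this file): the race-free way for a
 * backend to go idle. Notifications stay disabled while the handler drains
 * the ring; when it finds the ring empty it re-enables them and re-checks,
 * because the guest may have added a buffer in the window before the
 * avail-event/flags update became visible. The loop body is hypothetical;
 * backends such as vhost-net follow roughly this shape.
 *
 *	vhost_disable_notify(vq->dev, vq);
 *	for (;;) {
 *		head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
 *					 &out, &in, NULL, NULL);
 *		if (head < 0)
 *			break;				// access error
 *		if (head == vq->num) {
 *			if (unlikely(vhost_enable_notify(vq->dev, vq))) {
 *				// A buffer slipped in: keep draining.
 *				vhost_disable_notify(vq->dev, vq);
 *				continue;
 *			}
 *			break;				// truly empty; wait for a kick
 *		}
 *		// ... process the buffer, then vhost_add_used_and_signal() ...
 *	}
 */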
/* Create a new message. */
struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
{
	struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL);
	if (!node)
		return NULL;
	node->vq = vq;
	node->msg.type = type;
	return node;
}
EXPORT_SYMBOL_GPL(vhost_new_msg);

void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head,
		       struct vhost_msg_node *node)
{
	spin_lock(&dev->iotlb_lock);
	list_add_tail(&node->node, head);
	spin_unlock(&dev->iotlb_lock);

	wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
}
EXPORT_SYMBOL_GPL(vhost_enqueue_msg);

struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev,
					 struct list_head *head)
{
	struct vhost_msg_node *node = NULL;

	spin_lock(&dev->iotlb_lock);
	if (!list_empty(head)) {
		node = list_first_entry(head, struct vhost_msg_node,
					node);
		list_del(&node->node);
	}
	spin_unlock(&dev->iotlb_lock);

	return node;
}
EXPORT_SYMBOL_GPL(vhost_dequeue_msg);

static int __init vhost_init(void)
{
	return 0;
}

static void __exit vhost_exit(void)
{
}

module_init(vhost_init);
module_exit(vhost_exit);

MODULE_VERSION("0.0.1");
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Michael S. Tsirkin");
MODULE_DESCRIPTION("Host kernel accelerator for virtio");
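/*
 * Illustrative sketch (not part of this file's code): how a producer might
 * use the message queue above. The vhost IOTLB-miss path builds a node,
 * queues it on the device's read list, and the character-device reader hands
 * it to userspace, which replies with an IOTLB update. Names such as
 * dev->read_list and the iotlb fields follow the vhost uapi, but this is a
 * sketch under those assumptions, not a verbatim copy of the miss handler.
 *
 *	node = vhost_new_msg(vq, VHOST_IOTLB_MISS);
 *	if (!node)
 *		return -ENOMEM;
 *	node->msg.iotlb.type = VHOST_IOTLB_MISS;
 *	node->msg.iotlb.iova = iova;
 *	node->msg.iotlb.perm = access;
 *	vhost_enqueue_msg(dev, &dev->read_list, node);
 */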