// SPDX-License-Identifier: GPL-2.0-only
/* Copyright (C) 2009 Red Hat, Inc.
 * Copyright (C) 2006 Rusty Russell IBM Corporation
 *
 * Author: Michael S. Tsirkin <mst@redhat.com>
 *
 * Inspiration, some code, and most witty comments come from
 * Documentation/virtual/lguest/lguest.c, by Rusty Russell
 *
 * Generic code for virtio server in host kernel.
 */

#include <linux/eventfd.h>
#include <linux/vhost.h>
#include <linux/uio.h>
#include <linux/mm.h>
#include <linux/mmu_context.h>
#include <linux/miscdevice.h>
#include <linux/mutex.h>
#include <linux/poll.h>
#include <linux/file.h>
#include <linux/highmem.h>
#include <linux/slab.h>
#include <linux/vmalloc.h>
#include <linux/kthread.h>
#include <linux/cgroup.h>
#include <linux/module.h>
#include <linux/sort.h>
#include <linux/sched/mm.h>
#include <linux/sched/signal.h>
#include <linux/interval_tree_generic.h>
#include <linux/nospec.h>

#include "vhost.h"

static ushort max_mem_regions = 64;
module_param(max_mem_regions, ushort, 0444);
MODULE_PARM_DESC(max_mem_regions,
	"Maximum number of memory regions in memory map. (default: 64)");
static int max_iotlb_entries = 2048;
module_param(max_iotlb_entries, int, 0444);
MODULE_PARM_DESC(max_iotlb_entries,
	"Maximum number of iotlb entries. (default: 2048)");

enum {
	VHOST_MEMORY_F_LOG = 0x1,
};

#define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num])
#define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num])

INTERVAL_TREE_DEFINE(struct vhost_umem_node,
		     rb, __u64, __subtree_last,
		     START, LAST, static inline, vhost_umem_interval_tree);

#ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY
static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
{
	vq->user_be = !virtio_legacy_is_little_endian();
}

static void vhost_enable_cross_endian_big(struct vhost_virtqueue *vq)
{
	vq->user_be = true;
}

static void vhost_enable_cross_endian_little(struct vhost_virtqueue *vq)
{
	vq->user_be = false;
}

static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
{
	struct vhost_vring_state s;

	if (vq->private_data)
		return -EBUSY;

	if (copy_from_user(&s, argp, sizeof(s)))
		return -EFAULT;

	if (s.num != VHOST_VRING_LITTLE_ENDIAN &&
	    s.num != VHOST_VRING_BIG_ENDIAN)
		return -EINVAL;

	if (s.num == VHOST_VRING_BIG_ENDIAN)
		vhost_enable_cross_endian_big(vq);
	else
		vhost_enable_cross_endian_little(vq);

	return 0;
}

static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
				   int __user *argp)
{
	struct vhost_vring_state s = {
		.index = idx,
		.num = vq->user_be
	};

	if (copy_to_user(argp, &s, sizeof(s)))
		return -EFAULT;

	return 0;
}

static void vhost_init_is_le(struct vhost_virtqueue *vq)
{
	/* Note for legacy virtio: user_be is initialized at reset time
	 * according to the host endianness. If userspace does not set an
	 * explicit endianness, the default behavior is native endian, as
	 * expected by legacy virtio.
	 */
	vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1) || !vq->user_be;
}
#else
static void vhost_disable_cross_endian(struct vhost_virtqueue *vq)
{
}

static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp)
{
	return -ENOIOCTLCMD;
}

static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx,
				   int __user *argp)
{
	return -ENOIOCTLCMD;
}

static void vhost_init_is_le(struct vhost_virtqueue *vq)
{
	vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1)
		|| virtio_legacy_is_little_endian();
}
#endif /* CONFIG_VHOST_CROSS_ENDIAN_LEGACY */

static void vhost_reset_is_le(struct vhost_virtqueue *vq)
{
	vhost_init_is_le(vq);
}

struct vhost_flush_struct {
	struct vhost_work work;
	struct completion wait_event;
};

static void vhost_flush_work(struct vhost_work *work)
{
	struct vhost_flush_struct *s;

	s = container_of(work, struct vhost_flush_struct, work);
	complete(&s->wait_event);
}

static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
			    poll_table *pt)
{
	struct vhost_poll *poll;

	poll = container_of(pt, struct vhost_poll, table);
	poll->wqh = wqh;
	add_wait_queue(wqh, &poll->wait);
}

static int vhost_poll_wakeup(wait_queue_entry_t *wait, unsigned mode, int sync,
			     void *key)
{
	struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait);

	if (!(key_to_poll(key) & poll->mask))
		return 0;

	vhost_poll_queue(poll);
	return 0;
}

void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
{
	clear_bit(VHOST_WORK_QUEUED, &work->flags);
	work->fn = fn;
}
EXPORT_SYMBOL_GPL(vhost_work_init);

/* Init poll structure */
void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
		     __poll_t mask, struct vhost_dev *dev)
{
	init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup);
	init_poll_funcptr(&poll->table, vhost_poll_func);
	poll->mask = mask;
	poll->dev = dev;
	poll->wqh = NULL;

	vhost_work_init(&poll->work, fn);
}
EXPORT_SYMBOL_GPL(vhost_poll_init);

/* Start polling a file. We add ourselves to file's wait queue. The caller must
 * keep a reference to a file until after vhost_poll_stop is called. */
int vhost_poll_start(struct vhost_poll *poll, struct file *file)
{
	__poll_t mask;
	int ret = 0;

	if (poll->wqh)
		return 0;

	mask = vfs_poll(file, &poll->table);
	if (mask)
		vhost_poll_wakeup(&poll->wait, 0, 0, poll_to_key(mask));
	if (mask & EPOLLERR) {
		vhost_poll_stop(poll);
		ret = -EINVAL;
	}

	return ret;
}
EXPORT_SYMBOL_GPL(vhost_poll_start);

/* Stop polling a file. After this function returns, it becomes safe to drop the
 * file reference. You must also flush afterwards. */
void vhost_poll_stop(struct vhost_poll *poll)
{
	if (poll->wqh) {
		remove_wait_queue(poll->wqh, &poll->wait);
		poll->wqh = NULL;
	}
}
EXPORT_SYMBOL_GPL(vhost_poll_stop);

void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work)
{
	struct vhost_flush_struct flush;

	if (dev->worker) {
		init_completion(&flush.wait_event);
		vhost_work_init(&flush.work, vhost_flush_work);

		vhost_work_queue(dev, &flush.work);
		wait_for_completion(&flush.wait_event);
	}
}
EXPORT_SYMBOL_GPL(vhost_work_flush);

/* Flush any work that has been scheduled. When calling this, don't hold any
 * locks that are also used by the callback. */
void vhost_poll_flush(struct vhost_poll *poll)
{
	vhost_work_flush(poll->dev, &poll->work);
}
EXPORT_SYMBOL_GPL(vhost_poll_flush);

void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
{
	if (!dev->worker)
		return;

	if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) {
		/* We can only add the work to the list after we're
		 * sure it was not in the list.
		 * test_and_set_bit() implies a memory barrier.
		 */
		llist_add(&work->node, &dev->work_list);
		wake_up_process(dev->worker);
	}
}
EXPORT_SYMBOL_GPL(vhost_work_queue);

/* A lockless hint for busy polling code to exit the loop */
bool vhost_has_work(struct vhost_dev *dev)
{
	return !llist_empty(&dev->work_list);
}
EXPORT_SYMBOL_GPL(vhost_has_work);

void vhost_poll_queue(struct vhost_poll *poll)
{
	vhost_work_queue(poll->dev, &poll->work);
}
EXPORT_SYMBOL_GPL(vhost_poll_queue);

static void __vhost_vq_meta_reset(struct vhost_virtqueue *vq)
{
	int j;

	for (j = 0; j < VHOST_NUM_ADDRS; j++)
		vq->meta_iotlb[j] = NULL;
}

static void vhost_vq_meta_reset(struct vhost_dev *d)
{
	int i;

	for (i = 0; i < d->nvqs; ++i)
		__vhost_vq_meta_reset(d->vqs[i]);
}

#if VHOST_ARCH_CAN_ACCEL_UACCESS
static void vhost_map_unprefetch(struct vhost_map *map)
{
	kfree(map->pages);
	map->pages = NULL;
	map->npages = 0;
	map->addr = NULL;
}

static void vhost_uninit_vq_maps(struct vhost_virtqueue *vq)
{
	struct vhost_map *map[VHOST_NUM_ADDRS];
	int i;

	spin_lock(&vq->mmu_lock);
	for (i = 0; i < VHOST_NUM_ADDRS; i++) {
		map[i] = rcu_dereference_protected(vq->maps[i],
				lockdep_is_held(&vq->mmu_lock));
		if (map[i])
			rcu_assign_pointer(vq->maps[i], NULL);
	}
	spin_unlock(&vq->mmu_lock);

	synchronize_rcu();

	for (i = 0; i < VHOST_NUM_ADDRS; i++)
		if (map[i])
			vhost_map_unprefetch(map[i]);
}

static void vhost_reset_vq_maps(struct vhost_virtqueue *vq)
{
	int i;

	vhost_uninit_vq_maps(vq);
	for (i = 0; i < VHOST_NUM_ADDRS; i++)
		vq->uaddrs[i].size = 0;
}

static bool vhost_map_range_overlap(struct vhost_uaddr *uaddr,
				    unsigned long start,
				    unsigned long end)
{
	if (unlikely(!uaddr->size))
		return false;

	return !(end < uaddr->uaddr || start > uaddr->uaddr - 1 + uaddr->size);
}

static void vhost_invalidate_vq_start(struct vhost_virtqueue *vq,
				      int index,
				      unsigned long start,
				      unsigned long end)
{
	struct vhost_uaddr *uaddr = &vq->uaddrs[index];
	struct vhost_map *map;
	int i;

	if (!vhost_map_range_overlap(uaddr, start, end))
		return;

	spin_lock(&vq->mmu_lock);
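	/* Descriptive note (derived from the surrounding code): bumping
	 * invalidate_count under mmu_lock makes vhost_map_prefetch(), which
	 * checks this counter under the same lock, refuse to set up a new
	 * mapping until the paired _end() notifier drops the count again.
	 */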
	++vq->invalidate_count;

	map = rcu_dereference_protected(vq->maps[index],
			lockdep_is_held(&vq->mmu_lock));
	if (map) {
		if (uaddr->write) {
			for (i = 0; i < map->npages; i++)
				set_page_dirty(map->pages[i]);
		}
		rcu_assign_pointer(vq->maps[index], NULL);
	}
	spin_unlock(&vq->mmu_lock);

	if (map) {
		synchronize_rcu();
		vhost_map_unprefetch(map);
	}
}

static void vhost_invalidate_vq_end(struct vhost_virtqueue *vq,
				    int index,
				    unsigned long start,
				    unsigned long end)
{
	if (!vhost_map_range_overlap(&vq->uaddrs[index], start, end))
		return;

	spin_lock(&vq->mmu_lock);
	--vq->invalidate_count;
	spin_unlock(&vq->mmu_lock);
}

static int vhost_invalidate_range_start(struct mmu_notifier *mn,
					const struct mmu_notifier_range *range)
{
	struct vhost_dev *dev = container_of(mn, struct vhost_dev,
					     mmu_notifier);
	int i, j;

	if (!mmu_notifier_range_blockable(range))
		return -EAGAIN;

	for (i = 0; i < dev->nvqs; i++) {
		struct vhost_virtqueue *vq = dev->vqs[i];

		for (j = 0; j < VHOST_NUM_ADDRS; j++)
			vhost_invalidate_vq_start(vq, j,
						  range->start,
						  range->end);
	}

	return 0;
}

static void vhost_invalidate_range_end(struct mmu_notifier *mn,
				       const struct mmu_notifier_range *range)
{
	struct vhost_dev *dev = container_of(mn, struct vhost_dev,
					     mmu_notifier);
	int i, j;

	for (i = 0; i < dev->nvqs; i++) {
		struct vhost_virtqueue *vq = dev->vqs[i];

		for (j = 0; j < VHOST_NUM_ADDRS; j++)
			vhost_invalidate_vq_end(vq, j,
						range->start,
						range->end);
	}
}

static const struct mmu_notifier_ops vhost_mmu_notifier_ops = {
	.invalidate_range_start = vhost_invalidate_range_start,
	.invalidate_range_end = vhost_invalidate_range_end,
};

static void vhost_init_maps(struct vhost_dev *dev)
{
	struct vhost_virtqueue *vq;
	int i, j;

	dev->mmu_notifier.ops = &vhost_mmu_notifier_ops;

	for (i = 0; i < dev->nvqs; ++i) {
		vq = dev->vqs[i];
		for (j = 0; j < VHOST_NUM_ADDRS; j++)
			RCU_INIT_POINTER(vq->maps[j], NULL);
	}
}
#endif

static void vhost_vq_reset(struct vhost_dev *dev,
			   struct vhost_virtqueue *vq)
{
	vq->num = 1;
	vq->desc = NULL;
	vq->avail = NULL;
	vq->used = NULL;
	vq->last_avail_idx = 0;
	vq->avail_idx = 0;
	vq->last_used_idx = 0;
	vq->signalled_used = 0;
	vq->signalled_used_valid = false;
	vq->used_flags = 0;
	vq->log_used = false;
	vq->log_addr = -1ull;
	vq->private_data = NULL;
	vq->acked_features = 0;
	vq->acked_backend_features = 0;
	vq->log_base = NULL;
	vq->error_ctx = NULL;
	vq->kick = NULL;
	vq->call_ctx = NULL;
	vq->log_ctx = NULL;
	vhost_reset_is_le(vq);
	vhost_disable_cross_endian(vq);
	vq->busyloop_timeout = 0;
	vq->umem = NULL;
	vq->iotlb = NULL;
	vq->invalidate_count = 0;
	__vhost_vq_meta_reset(vq);
#if VHOST_ARCH_CAN_ACCEL_UACCESS
	vhost_reset_vq_maps(vq);
#endif
}

static int vhost_worker(void *data)
{
	struct vhost_dev *dev = data;
	struct vhost_work *work, *work_next;
	struct llist_node *node;
	mm_segment_t oldfs = get_fs();

	set_fs(USER_DS);
	use_mm(dev->mm);

	for (;;) {
		/* mb paired w/ kthread_stop */
		set_current_state(TASK_INTERRUPTIBLE);

		if (kthread_should_stop()) {
			__set_current_state(TASK_RUNNING);
			break;
		}

		node = llist_del_all(&dev->work_list);
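		/* Descriptive note: llist_del_all() detaches the whole queue
		 * atomically, but entries come back in LIFO order; the
		 * llist_reverse_order() below restores submission order
		 * before the work items are run.
		 */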
		if (!node)
			schedule();

		node = llist_reverse_order(node);
		/* make sure flag is seen after deletion */
		smp_wmb();
		llist_for_each_entry_safe(work, work_next, node, node) {
			clear_bit(VHOST_WORK_QUEUED, &work->flags);
			__set_current_state(TASK_RUNNING);
			work->fn(work);
			if (need_resched())
				schedule();
		}
	}
	unuse_mm(dev->mm);
	set_fs(oldfs);
	return 0;
}

static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
{
	kfree(vq->indirect);
	vq->indirect = NULL;
	kfree(vq->log);
	vq->log = NULL;
	kfree(vq->heads);
	vq->heads = NULL;
}

/* Helper to allocate iovec buffers for all vqs. */
static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
{
	struct vhost_virtqueue *vq;
	int i;

	for (i = 0; i < dev->nvqs; ++i) {
		vq = dev->vqs[i];
		vq->indirect = kmalloc_array(UIO_MAXIOV,
					     sizeof(*vq->indirect),
					     GFP_KERNEL);
		vq->log = kmalloc_array(dev->iov_limit, sizeof(*vq->log),
					GFP_KERNEL);
		vq->heads = kmalloc_array(dev->iov_limit, sizeof(*vq->heads),
					  GFP_KERNEL);
		if (!vq->indirect || !vq->log || !vq->heads)
			goto err_nomem;
	}
	return 0;

err_nomem:
	for (; i >= 0; --i)
		vhost_vq_free_iovecs(dev->vqs[i]);
	return -ENOMEM;
}

static void vhost_dev_free_iovecs(struct vhost_dev *dev)
{
	int i;

	for (i = 0; i < dev->nvqs; ++i)
		vhost_vq_free_iovecs(dev->vqs[i]);
}

bool vhost_exceeds_weight(struct vhost_virtqueue *vq,
			  int pkts, int total_len)
{
	struct vhost_dev *dev = vq->dev;

	if ((dev->byte_weight && total_len >= dev->byte_weight) ||
	    pkts >= dev->weight) {
		vhost_poll_queue(&vq->poll);
		return true;
	}

	return false;
}
EXPORT_SYMBOL_GPL(vhost_exceeds_weight);

static size_t vhost_get_avail_size(struct vhost_virtqueue *vq,
				   unsigned int num)
{
	size_t event __maybe_unused =
	       vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;

	return sizeof(*vq->avail) +
	       sizeof(*vq->avail->ring) * num + event;
}

static size_t vhost_get_used_size(struct vhost_virtqueue *vq,
				  unsigned int num)
{
	size_t event __maybe_unused =
	       vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;

	return sizeof(*vq->used) +
	       sizeof(*vq->used->ring) * num + event;
}

static size_t vhost_get_desc_size(struct vhost_virtqueue *vq,
				  unsigned int num)
{
	return sizeof(*vq->desc) * num;
}

void vhost_dev_init(struct vhost_dev *dev,
		    struct vhost_virtqueue **vqs, int nvqs,
		    int iov_limit, int weight, int byte_weight)
{
	struct vhost_virtqueue *vq;
	int i;

	dev->vqs = vqs;
	dev->nvqs = nvqs;
	mutex_init(&dev->mutex);
	dev->log_ctx = NULL;
	dev->umem = NULL;
	dev->iotlb = NULL;
	dev->mm = NULL;
	dev->worker = NULL;
	dev->iov_limit = iov_limit;
	dev->weight = weight;
	dev->byte_weight = byte_weight;
	init_llist_head(&dev->work_list);
	init_waitqueue_head(&dev->wait);
	INIT_LIST_HEAD(&dev->read_list);
	INIT_LIST_HEAD(&dev->pending_list);
	spin_lock_init(&dev->iotlb_lock);
#if VHOST_ARCH_CAN_ACCEL_UACCESS
	vhost_init_maps(dev);
#endif

	for (i = 0; i < dev->nvqs; ++i) {
		vq = dev->vqs[i];
		vq->log = NULL;
		vq->indirect = NULL;
		vq->heads = NULL;
		vq->dev = dev;
		mutex_init(&vq->mutex);
		spin_lock_init(&vq->mmu_lock);
		vhost_vq_reset(dev, vq);
		if (vq->handle_kick)
			vhost_poll_init(&vq->poll, vq->handle_kick,
					EPOLLIN, dev);
	}
}
EXPORT_SYMBOL_GPL(vhost_dev_init);

/* Caller should have device mutex */
long vhost_dev_check_owner(struct vhost_dev *dev)
{
	/* Are you the owner? If not, I don't think you mean to do that */
	return dev->mm == current->mm ? 0 : -EPERM;
}
EXPORT_SYMBOL_GPL(vhost_dev_check_owner);

struct vhost_attach_cgroups_struct {
	struct vhost_work work;
	struct task_struct *owner;
	int ret;
};

static void vhost_attach_cgroups_work(struct vhost_work *work)
{
	struct vhost_attach_cgroups_struct *s;

	s = container_of(work, struct vhost_attach_cgroups_struct, work);
	s->ret = cgroup_attach_task_all(s->owner, current);
}

static int vhost_attach_cgroups(struct vhost_dev *dev)
{
	struct vhost_attach_cgroups_struct attach;

	attach.owner = current;
	vhost_work_init(&attach.work, vhost_attach_cgroups_work);
	vhost_work_queue(dev, &attach.work);
	vhost_work_flush(dev, &attach.work);
	return attach.ret;
}

/* Caller should have device mutex */
bool vhost_dev_has_owner(struct vhost_dev *dev)
{
	return dev->mm;
}
EXPORT_SYMBOL_GPL(vhost_dev_has_owner);

/* Caller should have device mutex */
long vhost_dev_set_owner(struct vhost_dev *dev)
{
	struct task_struct *worker;
	int err;

	/* Is there an owner already? */
	if (vhost_dev_has_owner(dev)) {
		err = -EBUSY;
		goto err_mm;
	}

	/* No owner, become one */
	dev->mm = get_task_mm(current);
	worker = kthread_create(vhost_worker, dev, "vhost-%d", current->pid);
	if (IS_ERR(worker)) {
		err = PTR_ERR(worker);
		goto err_worker;
	}

	dev->worker = worker;
	wake_up_process(worker);	/* avoid contributing to loadavg */

	err = vhost_attach_cgroups(dev);
	if (err)
		goto err_cgroup;

	err = vhost_dev_alloc_iovecs(dev);
	if (err)
		goto err_cgroup;

#if VHOST_ARCH_CAN_ACCEL_UACCESS
	err = mmu_notifier_register(&dev->mmu_notifier, dev->mm);
	if (err)
		goto err_mmu_notifier;
#endif

	return 0;

#if VHOST_ARCH_CAN_ACCEL_UACCESS
err_mmu_notifier:
	vhost_dev_free_iovecs(dev);
#endif
err_cgroup:
	kthread_stop(worker);
	dev->worker = NULL;
err_worker:
	if (dev->mm)
		mmput(dev->mm);
	dev->mm = NULL;
err_mm:
	return err;
}
EXPORT_SYMBOL_GPL(vhost_dev_set_owner);

struct vhost_umem *vhost_dev_reset_owner_prepare(void)
{
	return kvzalloc(sizeof(struct vhost_umem), GFP_KERNEL);
}
EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare);

/* Caller should have device mutex */
void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_umem *umem)
{
	int i;

	vhost_dev_cleanup(dev);

	/* Restore memory to default empty mapping. */
	INIT_LIST_HEAD(&umem->umem_list);
	dev->umem = umem;
	/* We don't need VQ locks below since vhost_dev_cleanup makes sure
	 * VQs aren't running.
	 */
	for (i = 0; i < dev->nvqs; ++i)
		dev->vqs[i]->umem = umem;
}
EXPORT_SYMBOL_GPL(vhost_dev_reset_owner);

void vhost_dev_stop(struct vhost_dev *dev)
{
	int i;

	for (i = 0; i < dev->nvqs; ++i) {
		if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick) {
			vhost_poll_stop(&dev->vqs[i]->poll);
			vhost_poll_flush(&dev->vqs[i]->poll);
		}
	}
}
EXPORT_SYMBOL_GPL(vhost_dev_stop);

static void vhost_umem_free(struct vhost_umem *umem,
			    struct vhost_umem_node *node)
{
	vhost_umem_interval_tree_remove(node, &umem->umem_tree);
	list_del(&node->link);
	kfree(node);
	umem->numem--;
}

static void vhost_umem_clean(struct vhost_umem *umem)
{
	struct vhost_umem_node *node, *tmp;

	if (!umem)
		return;

	list_for_each_entry_safe(node, tmp, &umem->umem_list, link)
		vhost_umem_free(umem, node);

	kvfree(umem);
}

static void vhost_clear_msg(struct vhost_dev *dev)
{
	struct vhost_msg_node *node, *n;

	spin_lock(&dev->iotlb_lock);

	list_for_each_entry_safe(node, n, &dev->read_list, node) {
		list_del(&node->node);
		kfree(node);
	}

	list_for_each_entry_safe(node, n, &dev->pending_list, node) {
		list_del(&node->node);
		kfree(node);
	}

	spin_unlock(&dev->iotlb_lock);
}

#if VHOST_ARCH_CAN_ACCEL_UACCESS
static void vhost_setup_uaddr(struct vhost_virtqueue *vq,
			      int index, unsigned long uaddr,
			      size_t size, bool write)
{
	struct vhost_uaddr *addr = &vq->uaddrs[index];

	addr->uaddr = uaddr;
	addr->size = size;
	addr->write = write;
}

static void vhost_setup_vq_uaddr(struct vhost_virtqueue *vq)
{
	vhost_setup_uaddr(vq, VHOST_ADDR_DESC,
			  (unsigned long)vq->desc,
			  vhost_get_desc_size(vq, vq->num),
			  false);
	vhost_setup_uaddr(vq, VHOST_ADDR_AVAIL,
			  (unsigned long)vq->avail,
			  vhost_get_avail_size(vq, vq->num),
			  false);
	vhost_setup_uaddr(vq, VHOST_ADDR_USED,
			  (unsigned long)vq->used,
			  vhost_get_used_size(vq, vq->num),
			  true);
}

static int vhost_map_prefetch(struct vhost_virtqueue *vq,
			      int index)
{
	struct vhost_map *map;
	struct vhost_uaddr *uaddr = &vq->uaddrs[index];
	struct page **pages;
	int npages = DIV_ROUND_UP(uaddr->size, PAGE_SIZE);
	int npinned;
	void *vaddr, *v;
	int err;
	int i;

	spin_lock(&vq->mmu_lock);

	err = -EFAULT;
	if (vq->invalidate_count)
		goto err;

	err = -ENOMEM;
	map = kmalloc(sizeof(*map), GFP_ATOMIC);
	if (!map)
		goto err;

	pages = kmalloc_array(npages, sizeof(struct page *), GFP_ATOMIC);
	if (!pages)
		goto err_pages;

	err = -EFAULT;
	npinned = __get_user_pages_fast(uaddr->uaddr, npages,
					uaddr->write, pages);
	if (npinned > 0)
		release_pages(pages, npinned);
	if (npinned != npages)
		goto err_gup;

	for (i = 0; i < npinned; i++)
		if (PageHighMem(pages[i]))
			goto err_gup;

	vaddr = v = page_address(pages[0]);

	/* For simplicity, fall back to the userspace address if the VA is
	 * not contiguous.
	 */
	for (i = 1; i < npinned; i++) {
		v += PAGE_SIZE;
		if (v != page_address(pages[i]))
			goto err_gup;
	}

	map->addr = vaddr + (uaddr->uaddr & (PAGE_SIZE - 1));
	map->npages = npages;
	map->pages = pages;

	rcu_assign_pointer(vq->maps[index], map);
	/* No need for a synchronize_rcu(). This function should be
	 * called by dev->worker so we are serialized with all
	 * readers.
	 */
	spin_unlock(&vq->mmu_lock);

	return 0;

err_gup:
	kfree(pages);
err_pages:
	kfree(map);
err:
	spin_unlock(&vq->mmu_lock);
	return err;
}
#endif

void vhost_dev_cleanup(struct vhost_dev *dev)
{
	int i;

	for (i = 0; i < dev->nvqs; ++i) {
		if (dev->vqs[i]->error_ctx)
			eventfd_ctx_put(dev->vqs[i]->error_ctx);
		if (dev->vqs[i]->kick)
			fput(dev->vqs[i]->kick);
		if (dev->vqs[i]->call_ctx)
			eventfd_ctx_put(dev->vqs[i]->call_ctx);
		vhost_vq_reset(dev, dev->vqs[i]);
	}
	vhost_dev_free_iovecs(dev);
	if (dev->log_ctx)
		eventfd_ctx_put(dev->log_ctx);
	dev->log_ctx = NULL;
	/* No one will access memory at this point */
	vhost_umem_clean(dev->umem);
	dev->umem = NULL;
	vhost_umem_clean(dev->iotlb);
	dev->iotlb = NULL;
	vhost_clear_msg(dev);
	wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM);
	WARN_ON(!llist_empty(&dev->work_list));
	if (dev->worker) {
		kthread_stop(dev->worker);
		dev->worker = NULL;
	}
	if (dev->mm) {
#if VHOST_ARCH_CAN_ACCEL_UACCESS
		mmu_notifier_unregister(&dev->mmu_notifier, dev->mm);
#endif
		mmput(dev->mm);
	}
#if VHOST_ARCH_CAN_ACCEL_UACCESS
	for (i = 0; i < dev->nvqs; i++)
		vhost_uninit_vq_maps(dev->vqs[i]);
#endif
	dev->mm = NULL;
}
EXPORT_SYMBOL_GPL(vhost_dev_cleanup);

static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
{
	u64 a = addr / VHOST_PAGE_SIZE / 8;

	/* Make sure 64 bit math will not overflow. */
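	/* Descriptive note: the log is a bitmap with one bit per
	 * VHOST_PAGE_SIZE page, so a is the byte offset of addr's bit
	 * within that bitmap.
	 */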
	if (a > ULONG_MAX - (unsigned long)log_base ||
	    a + (unsigned long)log_base > ULONG_MAX)
		return false;

	return access_ok(log_base + a,
			 (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
}

static bool vhost_overflow(u64 uaddr, u64 size)
{
	/* Make sure 64 bit math will not overflow. */
	return uaddr > ULONG_MAX || size > ULONG_MAX || uaddr > ULONG_MAX - size;
}

/* Caller should have vq mutex and device mutex. */
static bool vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
				int log_all)
{
	struct vhost_umem_node *node;

	if (!umem)
		return false;

	list_for_each_entry(node, &umem->umem_list, link) {
		unsigned long a = node->userspace_addr;

		if (vhost_overflow(node->userspace_addr, node->size))
			return false;

		if (!access_ok((void __user *)a,
			       node->size))
			return false;
		else if (log_all && !log_access_ok(log_base,
						   node->start,
						   node->size))
			return false;
	}
	return true;
}

static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq,
					       u64 addr, unsigned int size,
					       int type)
{
	const struct vhost_umem_node *node = vq->meta_iotlb[type];

	if (!node)
		return NULL;

	return (void *)(uintptr_t)(node->userspace_addr + addr - node->start);
}

/* Can we switch to this memory table? */
/* Caller should have device mutex but not vq mutex */
static bool memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
			     int log_all)
{
	int i;

	for (i = 0; i < d->nvqs; ++i) {
		bool ok;
		bool log;

		mutex_lock(&d->vqs[i]->mutex);
		log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
		/* If ring is inactive, will check when it's enabled. */
		if (d->vqs[i]->private_data)
			ok = vq_memory_access_ok(d->vqs[i]->log_base,
						 umem, log);
		else
			ok = true;
		mutex_unlock(&d->vqs[i]->mutex);
		if (!ok)
			return false;
	}
	return true;
}

static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
			  struct iovec iov[], int iov_size, int access);

static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to,
			      const void *from, unsigned size)
{
	int ret;

	if (!vq->iotlb)
		return __copy_to_user(to, from, size);
	else {
		/* This function should be called after iotlb
		 * prefetch, which means we're sure that all of the
		 * vq memory can be accessed through the iotlb, so
		 * -EAGAIN should not happen in this case.
		 */
		struct iov_iter t;
		void __user *uaddr = vhost_vq_meta_fetch(vq,
				     (u64)(uintptr_t)to, size,
				     VHOST_ADDR_USED);

		if (uaddr)
			return __copy_to_user(uaddr, from, size);

		ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
				     ARRAY_SIZE(vq->iotlb_iov),
				     VHOST_ACCESS_WO);
		if (ret < 0)
			goto out;
		iov_iter_init(&t, WRITE, vq->iotlb_iov, ret, size);
		ret = copy_to_iter(from, size, &t);
		if (ret == size)
			ret = 0;
	}
out:
	return ret;
}

static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
				void __user *from, unsigned size)
{
	int ret;

	if (!vq->iotlb)
		return __copy_from_user(to, from, size);
	else {
		/* This function should be called after iotlb
		 * prefetch, which means we're sure that the vq
		 * can be accessed through the iotlb, so -EAGAIN
		 * should not happen in this case.
		 */
		void __user *uaddr = vhost_vq_meta_fetch(vq,
				     (u64)(uintptr_t)from, size,
				     VHOST_ADDR_DESC);
		struct iov_iter f;

		if (uaddr)
			return __copy_from_user(to, uaddr, size);

		ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov,
				     ARRAY_SIZE(vq->iotlb_iov),
				     VHOST_ACCESS_RO);
		if (ret < 0) {
			vq_err(vq, "IOTLB translation failure: uaddr "
			       "%p size 0x%llx\n", from,
			       (unsigned long long) size);
			goto out;
		}
		iov_iter_init(&f, READ, vq->iotlb_iov, ret, size);
		ret = copy_from_iter(to, size, &f);
		if (ret == size)
			ret = 0;
	}

out:
	return ret;
}

static void __user *__vhost_get_user_slow(struct vhost_virtqueue *vq,
					  void __user *addr, unsigned int size,
					  int type)
{
	int ret;

	ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
			     ARRAY_SIZE(vq->iotlb_iov),
			     VHOST_ACCESS_RO);
	if (ret < 0) {
		vq_err(vq, "IOTLB translation failure: uaddr "
		       "%p size 0x%llx\n", addr,
		       (unsigned long long) size);
		return NULL;
	}

	if (ret != 1 || vq->iotlb_iov[0].iov_len != size) {
		vq_err(vq, "Non atomic userspace memory access: uaddr "
		       "%p size 0x%llx\n", addr,
		       (unsigned long long) size);
		return NULL;
	}

	return vq->iotlb_iov[0].iov_base;
}

/* This function should be called after iotlb
 * prefetch, which means we're sure that the vq
 * can be accessed through the iotlb, so -EAGAIN
 * should not happen in this case.
 */
static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq,
					    void *addr, unsigned int size,
					    int type)
{
	void __user *uaddr = vhost_vq_meta_fetch(vq,
			     (u64)(uintptr_t)addr, size, type);
	if (uaddr)
		return uaddr;

	return __vhost_get_user_slow(vq, addr, size, type);
}

#define vhost_put_user(vq, x, ptr)		\
({ \
	int ret = -EFAULT; \
	if (!vq->iotlb) { \
		ret = __put_user(x, ptr); \
	} else { \
		__typeof__(ptr) to = \
			(__typeof__(ptr)) __vhost_get_user(vq, ptr, \
					  sizeof(*ptr), VHOST_ADDR_USED); \
		if (to != NULL) \
			ret = __put_user(x, to); \
		else \
			ret = -EFAULT; \
	} \
	ret; \
})

static inline int vhost_put_avail_event(struct vhost_virtqueue *vq)
{
#if VHOST_ARCH_CAN_ACCEL_UACCESS
	struct vhost_map *map;
	struct vring_used *used;

	if (!vq->iotlb) {
		rcu_read_lock();

		map = rcu_dereference(vq->maps[VHOST_ADDR_USED]);
		if (likely(map)) {
			used = map->addr;
			*((__virtio16 *)&used->ring[vq->num]) =
				cpu_to_vhost16(vq, vq->avail_idx);
			rcu_read_unlock();
			return 0;
		}

		rcu_read_unlock();
	}
#endif

	return vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx),
			      vhost_avail_event(vq));
}

static inline int vhost_put_used(struct vhost_virtqueue *vq,
				 struct vring_used_elem *head, int idx,
				 int count)
{
#if VHOST_ARCH_CAN_ACCEL_UACCESS
	struct vhost_map *map;
	struct vring_used *used;
	size_t size;

	if (!vq->iotlb) {
		rcu_read_lock();

		map = rcu_dereference(vq->maps[VHOST_ADDR_USED]);
		if (likely(map)) {
			used = map->addr;
			size = count * sizeof(*head);
			memcpy(used->ring + idx, head, size);
			rcu_read_unlock();
			return 0;
		}

		rcu_read_unlock();
	}
#endif

	return vhost_copy_to_user(vq, vq->used->ring + idx, head,
				  count * sizeof(*head));
}

static inline int vhost_put_used_flags(struct vhost_virtqueue *vq)
{
#if VHOST_ARCH_CAN_ACCEL_UACCESS
	struct vhost_map *map;
	struct vring_used *used;

	if (!vq->iotlb) {
		rcu_read_lock();

		map = rcu_dereference(vq->maps[VHOST_ADDR_USED]);
		if (likely(map)) {
			used = map->addr;
			used->flags = cpu_to_vhost16(vq, vq->used_flags);
			rcu_read_unlock();
			return 0;
		}

		rcu_read_unlock();
	}
#endif

	return vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags),
			      &vq->used->flags);
}

static inline int vhost_put_used_idx(struct vhost_virtqueue *vq)
{
#if VHOST_ARCH_CAN_ACCEL_UACCESS
	struct vhost_map *map;
	struct vring_used *used;

	if (!vq->iotlb) {
		rcu_read_lock();

		map = rcu_dereference(vq->maps[VHOST_ADDR_USED]);
		if (likely(map)) {
			used = map->addr;
			used->idx = cpu_to_vhost16(vq, vq->last_used_idx);
			rcu_read_unlock();
			return 0;
		}

		rcu_read_unlock();
	}
#endif

	return vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx),
			      &vq->used->idx);
}

#define vhost_get_user(vq, x, ptr, type)		\
({ \
	int ret; \
	if (!vq->iotlb) { \
		ret = __get_user(x, ptr); \
	} else { \
		__typeof__(ptr) from = \
			(__typeof__(ptr)) __vhost_get_user(vq, ptr, \
							   sizeof(*ptr), \
							   type); \
		if (from != NULL) \
			ret = __get_user(x, from); \
		else \
			ret = -EFAULT; \
	} \
	ret; \
})

#define vhost_get_avail(vq, x, ptr) \
	vhost_get_user(vq, x, ptr, VHOST_ADDR_AVAIL)

#define vhost_get_used(vq, x, ptr) \
	vhost_get_user(vq, x, ptr, VHOST_ADDR_USED)

static void vhost_dev_lock_vqs(struct vhost_dev *d)
{
	int i = 0;
	for (i = 0; i < d->nvqs; ++i)
		mutex_lock_nested(&d->vqs[i]->mutex, i);
}

static void vhost_dev_unlock_vqs(struct vhost_dev *d)
{
	int i = 0;
	for (i = 0; i < d->nvqs; ++i)
		mutex_unlock(&d->vqs[i]->mutex);
}

static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq,
				      __virtio16 *idx)
{
#if VHOST_ARCH_CAN_ACCEL_UACCESS
	struct vhost_map *map;
	struct vring_avail *avail;

	if (!vq->iotlb) {
		rcu_read_lock();

		map = rcu_dereference(vq->maps[VHOST_ADDR_AVAIL]);
		if (likely(map)) {
			avail = map->addr;
			*idx = avail->idx;
			rcu_read_unlock();
			return 0;
		}

		rcu_read_unlock();
	}
#endif

	return vhost_get_avail(vq, *idx, &vq->avail->idx);
}

static inline int vhost_get_avail_head(struct vhost_virtqueue *vq,
				       __virtio16 *head, int idx)
{
#if VHOST_ARCH_CAN_ACCEL_UACCESS
	struct vhost_map *map;
	struct vring_avail *avail;

	if (!vq->iotlb) {
		rcu_read_lock();

		map = rcu_dereference(vq->maps[VHOST_ADDR_AVAIL]);
		if (likely(map)) {
			avail = map->addr;
			*head = avail->ring[idx & (vq->num - 1)];
			rcu_read_unlock();
			return 0;
		}

		rcu_read_unlock();
	}
#endif

	return vhost_get_avail(vq, *head,
			       &vq->avail->ring[idx & (vq->num - 1)]);
}

static inline int vhost_get_avail_flags(struct vhost_virtqueue *vq,
					__virtio16 *flags)
{
#if VHOST_ARCH_CAN_ACCEL_UACCESS
	struct vhost_map *map;
	struct vring_avail *avail;

	if (!vq->iotlb) {
		rcu_read_lock();

		map = rcu_dereference(vq->maps[VHOST_ADDR_AVAIL]);
		if (likely(map)) {
			avail = map->addr;
			*flags = avail->flags;
			rcu_read_unlock();
			return 0;
		}

		rcu_read_unlock();
	}
#endif

	return vhost_get_avail(vq, *flags, &vq->avail->flags);
}

static inline int vhost_get_used_event(struct vhost_virtqueue *vq,
				       __virtio16 *event)
{
#if VHOST_ARCH_CAN_ACCEL_UACCESS
	struct vhost_map *map;
	struct vring_avail *avail;

	if (!vq->iotlb) {
		rcu_read_lock();
		map = rcu_dereference(vq->maps[VHOST_ADDR_AVAIL]);
		if (likely(map)) {
			avail = map->addr;
			*event = (__virtio16)avail->ring[vq->num];
			rcu_read_unlock();
			return 0;
		}
		rcu_read_unlock();
	}
#endif

	return vhost_get_avail(vq, *event, vhost_used_event(vq));
}

static inline int vhost_get_used_idx(struct vhost_virtqueue *vq,
				     __virtio16 *idx)
{
#if VHOST_ARCH_CAN_ACCEL_UACCESS
	struct vhost_map *map;
	struct vring_used *used;

	if (!vq->iotlb) {
		rcu_read_lock();

		map = rcu_dereference(vq->maps[VHOST_ADDR_USED]);
		if (likely(map)) {
			used = map->addr;
			*idx = used->idx;
			rcu_read_unlock();
			return 0;
		}

		rcu_read_unlock();
	}
#endif

	return vhost_get_used(vq, *idx, &vq->used->idx);
}

static inline int vhost_get_desc(struct vhost_virtqueue *vq,
				 struct vring_desc *desc, int idx)
{
#if VHOST_ARCH_CAN_ACCEL_UACCESS
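	/* Descriptive note: fast path reads the descriptor through the
	 * kernel mapping set up by vhost_map_prefetch(), if one is still
	 * valid; otherwise it falls back to the uaccess/IOTLB path below.
	 */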
	struct vhost_map *map;
	struct vring_desc *d;

	if (!vq->iotlb) {
		rcu_read_lock();

		map = rcu_dereference(vq->maps[VHOST_ADDR_DESC]);
		if (likely(map)) {
			d = map->addr;
			*desc = *(d + idx);
			rcu_read_unlock();
			return 0;
		}

		rcu_read_unlock();
	}
#endif

	return vhost_copy_from_user(vq, desc, vq->desc + idx, sizeof(*desc));
}

static int vhost_new_umem_range(struct vhost_umem *umem,
				u64 start, u64 size, u64 end,
				u64 userspace_addr, int perm)
{
	struct vhost_umem_node *tmp, *node;

	if (!size)
		return -EFAULT;

	node = kmalloc(sizeof(*node), GFP_ATOMIC);
	if (!node)
		return -ENOMEM;

	if (umem->numem == max_iotlb_entries) {
		tmp = list_first_entry(&umem->umem_list, typeof(*tmp), link);
		vhost_umem_free(umem, tmp);
	}

	node->start = start;
	node->size = size;
	node->last = end;
	node->userspace_addr = userspace_addr;
	node->perm = perm;
	INIT_LIST_HEAD(&node->link);
	list_add_tail(&node->link, &umem->umem_list);
	vhost_umem_interval_tree_insert(node, &umem->umem_tree);
	umem->numem++;

	return 0;
}

static void vhost_del_umem_range(struct vhost_umem *umem,
				 u64 start, u64 end)
{
	struct vhost_umem_node *node;

	while ((node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
							   start, end)))
		vhost_umem_free(umem, node);
}

static void vhost_iotlb_notify_vq(struct vhost_dev *d,
				  struct vhost_iotlb_msg *msg)
{
	struct vhost_msg_node *node, *n;

	spin_lock(&d->iotlb_lock);

	list_for_each_entry_safe(node, n, &d->pending_list, node) {
		struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;
		if (msg->iova <= vq_msg->iova &&
		    msg->iova + msg->size - 1 >= vq_msg->iova &&
		    vq_msg->type == VHOST_IOTLB_MISS) {
			vhost_poll_queue(&node->vq->poll);
			list_del(&node->node);
			kfree(node);
		}
	}

	spin_unlock(&d->iotlb_lock);
}

static bool umem_access_ok(u64 uaddr, u64 size, int access)
{
	unsigned long a = uaddr;

	/* Make sure 64 bit math will not overflow. */
	if (vhost_overflow(uaddr, size))
		return false;

	if ((access & VHOST_ACCESS_RO) &&
	    !access_ok((void __user *)a, size))
		return false;
	if ((access & VHOST_ACCESS_WO) &&
	    !access_ok((void __user *)a, size))
		return false;
	return true;
}

static int vhost_process_iotlb_msg(struct vhost_dev *dev,
				   struct vhost_iotlb_msg *msg)
{
	int ret = 0;

	mutex_lock(&dev->mutex);
	vhost_dev_lock_vqs(dev);
	switch (msg->type) {
	case VHOST_IOTLB_UPDATE:
		if (!dev->iotlb) {
			ret = -EFAULT;
			break;
		}
		if (!umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
			ret = -EFAULT;
			break;
		}
		vhost_vq_meta_reset(dev);
		if (vhost_new_umem_range(dev->iotlb, msg->iova, msg->size,
					 msg->iova + msg->size - 1,
					 msg->uaddr, msg->perm)) {
			ret = -ENOMEM;
			break;
		}
		vhost_iotlb_notify_vq(dev, msg);
		break;
	case VHOST_IOTLB_INVALIDATE:
		if (!dev->iotlb) {
			ret = -EFAULT;
			break;
		}
		vhost_vq_meta_reset(dev);
		vhost_del_umem_range(dev->iotlb, msg->iova,
				     msg->iova + msg->size - 1);
		break;
	default:
		ret = -EINVAL;
		break;
	}

	vhost_dev_unlock_vqs(dev);
	mutex_unlock(&dev->mutex);

	return ret;
}

ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
			     struct iov_iter *from)
{
	struct vhost_iotlb_msg msg;
	size_t offset;
	int type, ret;

	ret = copy_from_iter(&type, sizeof(type), from);
	if (ret != sizeof(type)) {
		ret = -EINVAL;
		goto done;
	}

	switch (type) {
	case VHOST_IOTLB_MSG:
		/* There may be a hole after type for V1 message type,
		 * so skip it here.
		 */
		offset = offsetof(struct vhost_msg, iotlb) - sizeof(int);
		break;
	case VHOST_IOTLB_MSG_V2:
		offset = sizeof(__u32);
		break;
	default:
		ret = -EINVAL;
		goto done;
	}

	iov_iter_advance(from, offset);
	ret = copy_from_iter(&msg, sizeof(msg), from);
	if (ret != sizeof(msg)) {
		ret = -EINVAL;
		goto done;
	}
	if (vhost_process_iotlb_msg(dev, &msg)) {
		ret = -EFAULT;
		goto done;
	}

	ret = (type == VHOST_IOTLB_MSG) ? sizeof(struct vhost_msg) :
	      sizeof(struct vhost_msg_v2);
done:
	return ret;
}
EXPORT_SYMBOL(vhost_chr_write_iter);

__poll_t vhost_chr_poll(struct file *file, struct vhost_dev *dev,
			poll_table *wait)
{
	__poll_t mask = 0;

	poll_wait(file, &dev->wait, wait);

	if (!list_empty(&dev->read_list))
		mask |= EPOLLIN | EPOLLRDNORM;

	return mask;
}
EXPORT_SYMBOL(vhost_chr_poll);

ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to,
			    int noblock)
{
	DEFINE_WAIT(wait);
	struct vhost_msg_node *node;
	ssize_t ret = 0;
	unsigned size = sizeof(struct vhost_msg);

	if (iov_iter_count(to) < size)
		return 0;

	while (1) {
		if (!noblock)
			prepare_to_wait(&dev->wait, &wait,
					TASK_INTERRUPTIBLE);

		node = vhost_dequeue_msg(dev, &dev->read_list);
		if (node)
			break;
		if (noblock) {
			ret = -EAGAIN;
			break;
		}
		if (signal_pending(current)) {
			ret = -ERESTARTSYS;
			break;
		}
		if (!dev->iotlb) {
			ret = -EBADFD;
			break;
		}

		schedule();
	}

	if (!noblock)
		finish_wait(&dev->wait, &wait);

	if (node) {
		struct vhost_iotlb_msg *msg;
		void *start = &node->msg;

		switch (node->msg.type) {
		case VHOST_IOTLB_MSG:
			size = sizeof(node->msg);
			msg = &node->msg.iotlb;
			break;
		case VHOST_IOTLB_MSG_V2:
			size = sizeof(node->msg_v2);
			msg = &node->msg_v2.iotlb;
			break;
		default:
			BUG();
			break;
		}

		ret = copy_to_iter(start, size, to);
		if (ret != size || msg->type != VHOST_IOTLB_MISS) {
			kfree(node);
			return ret;
		}
		vhost_enqueue_msg(dev, &dev->pending_list, node);
	}

	return ret;
}
EXPORT_SYMBOL_GPL(vhost_chr_read_iter);

static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access)
{
	struct vhost_dev *dev = vq->dev;
	struct vhost_msg_node *node;
	struct vhost_iotlb_msg *msg;
	bool v2 = vhost_backend_has_feature(vq, VHOST_BACKEND_F_IOTLB_MSG_V2);

	node = vhost_new_msg(vq, v2 ? VHOST_IOTLB_MSG_V2 : VHOST_IOTLB_MSG);
	if (!node)
		return -ENOMEM;

	if (v2) {
		node->msg_v2.type = VHOST_IOTLB_MSG_V2;
		msg = &node->msg_v2.iotlb;
	} else {
		msg = &node->msg.iotlb;
	}

	msg->type = VHOST_IOTLB_MISS;
	msg->iova = iova;
	msg->perm = access;

	vhost_enqueue_msg(dev, &dev->read_list, node);

	return 0;
}

static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
			 struct vring_desc __user *desc,
			 struct vring_avail __user *avail,
			 struct vring_used __user *used)
{
	return access_ok(desc, vhost_get_desc_size(vq, num)) &&
	       access_ok(avail, vhost_get_avail_size(vq, num)) &&
	       access_ok(used, vhost_get_used_size(vq, num));
}

static void vhost_vq_meta_update(struct vhost_virtqueue *vq,
				 const struct vhost_umem_node *node,
				 int type)
{
	int access = (type == VHOST_ADDR_USED) ?
		     VHOST_ACCESS_WO : VHOST_ACCESS_RO;
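	/* Descriptive note: cache the IOTLB node covering this ring so the
	 * hot path can skip the interval-tree lookup. The used ring is
	 * written by vhost and needs WO permission; desc and avail are only
	 * read and need RO.
	 */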

	if (likely(node->perm & access))
		vq->meta_iotlb[type] = node;
}

static bool iotlb_access_ok(struct vhost_virtqueue *vq,
			    int access, u64 addr, u64 len, int type)
{
	const struct vhost_umem_node *node;
	struct vhost_umem *umem = vq->iotlb;
	u64 s = 0, size, orig_addr = addr, last = addr + len - 1;

	if (vhost_vq_meta_fetch(vq, addr, len, type))
		return true;

	while (len > s) {
		node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
							   addr,
							   last);
		if (node == NULL || node->start > addr) {
			vhost_iotlb_miss(vq, addr, access);
			return false;
		} else if (!(node->perm & access)) {
			/* Report the possible access violation by
			 * requesting another translation from userspace.
			 */
			return false;
		}

		size = node->size - addr + node->start;

		if (orig_addr == addr && size >= len)
			vhost_vq_meta_update(vq, node, type);

		s += size;
		addr += size;
	}

	return true;
}

#if VHOST_ARCH_CAN_ACCEL_UACCESS
static void vhost_vq_map_prefetch(struct vhost_virtqueue *vq)
{
	struct vhost_map __rcu *map;
	int i;

	for (i = 0; i < VHOST_NUM_ADDRS; i++) {
		rcu_read_lock();
		map = rcu_dereference(vq->maps[i]);
		rcu_read_unlock();
		if (unlikely(!map))
			vhost_map_prefetch(vq, i);
	}
}
#endif

int vq_meta_prefetch(struct vhost_virtqueue *vq)
{
	unsigned int num = vq->num;

	if (!vq->iotlb) {
#if VHOST_ARCH_CAN_ACCEL_UACCESS
		vhost_vq_map_prefetch(vq);
#endif
		return 1;
	}

	return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc,
			       vhost_get_desc_size(vq, num), VHOST_ADDR_DESC) &&
	       iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail,
			       vhost_get_avail_size(vq, num),
			       VHOST_ADDR_AVAIL) &&
	       iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used,
			       vhost_get_used_size(vq, num), VHOST_ADDR_USED);
}
EXPORT_SYMBOL_GPL(vq_meta_prefetch);

/* Can we log writes? */
/* Caller should have device mutex but not vq mutex */
bool vhost_log_access_ok(struct vhost_dev *dev)
{
	return memory_access_ok(dev, dev->umem, 1);
}
EXPORT_SYMBOL_GPL(vhost_log_access_ok);

/* Verify access for write logging. */
/* Caller should have vq mutex and device mutex */
static bool vq_log_access_ok(struct vhost_virtqueue *vq,
			     void __user *log_base)
{
	return vq_memory_access_ok(log_base, vq->umem,
				   vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
		(!vq->log_used || log_access_ok(log_base, vq->log_addr,
					vhost_get_used_size(vq, vq->num)));
}

/* Can we start vq? */
/* Caller should have vq mutex and device mutex */
bool vhost_vq_access_ok(struct vhost_virtqueue *vq)
{
	if (!vq_log_access_ok(vq, vq->log_base))
		return false;

	/* Access validation occurs at prefetch time with IOTLB */
	if (vq->iotlb)
		return true;

	return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used);
}
EXPORT_SYMBOL_GPL(vhost_vq_access_ok);

static struct vhost_umem *vhost_umem_alloc(void)
{
	struct vhost_umem *umem = kvzalloc(sizeof(*umem), GFP_KERNEL);

	if (!umem)
		return NULL;

	umem->umem_tree = RB_ROOT_CACHED;
	umem->numem = 0;
	INIT_LIST_HEAD(&umem->umem_list);

	return umem;
}

static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
{
	struct vhost_memory mem, *newmem;
	struct vhost_memory_region *region;
	struct vhost_umem *newumem, *oldumem;
	unsigned long size = offsetof(struct vhost_memory, regions);
	int i;

	if (copy_from_user(&mem, m, size))
		return -EFAULT;
	if (mem.padding)
		return -EOPNOTSUPP;
	if (mem.nregions > max_mem_regions)
		return -E2BIG;
	newmem = kvzalloc(struct_size(newmem, regions, mem.nregions),
			  GFP_KERNEL);
	if (!newmem)
		return -ENOMEM;

	memcpy(newmem, &mem, size);
	if (copy_from_user(newmem->regions, m->regions,
			   mem.nregions * sizeof *m->regions)) {
		kvfree(newmem);
		return -EFAULT;
	}

	newumem = vhost_umem_alloc();
	if (!newumem) {
		kvfree(newmem);
		return -ENOMEM;
	}

	for (region = newmem->regions;
	     region < newmem->regions + mem.nregions;
	     region++) {
		if (vhost_new_umem_range(newumem,
					 region->guest_phys_addr,
					 region->memory_size,
					 region->guest_phys_addr +
					 region->memory_size - 1,
					 region->userspace_addr,
					 VHOST_ACCESS_RW))
			goto err;
	}

	if (!memory_access_ok(d, newumem, 0))
		goto err;

	oldumem = d->umem;
	d->umem = newumem;

	/* All memory accesses are done under some VQ mutex. */
	for (i = 0; i < d->nvqs; ++i) {
		mutex_lock(&d->vqs[i]->mutex);
		d->vqs[i]->umem = newumem;
		mutex_unlock(&d->vqs[i]->mutex);
	}

	kvfree(newmem);
	vhost_umem_clean(oldumem);
	return 0;

err:
	vhost_umem_clean(newumem);
	kvfree(newmem);
	return -EFAULT;
}

static long vhost_vring_set_num(struct vhost_dev *d,
				struct vhost_virtqueue *vq,
				void __user *argp)
{
	struct vhost_vring_state s;

	/* Resizing ring with an active backend?
	 * You don't want to do that. */
	if (vq->private_data)
		return -EBUSY;

	if (copy_from_user(&s, argp, sizeof s))
		return -EFAULT;

	if (!s.num || s.num > 0xffff || (s.num & (s.num - 1)))
		return -EINVAL;
	vq->num = s.num;

	return 0;
}

static long vhost_vring_set_addr(struct vhost_dev *d,
				 struct vhost_virtqueue *vq,
				 void __user *argp)
{
	struct vhost_vring_addr a;

	if (copy_from_user(&a, argp, sizeof a))
		return -EFAULT;
	if (a.flags & ~(0x1 << VHOST_VRING_F_LOG))
		return -EOPNOTSUPP;

	/* For 32bit, verify that the top 32bits of the user
	   data are set to zero. */
	if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr ||
	    (u64)(unsigned long)a.used_user_addr != a.used_user_addr ||
	    (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr)
		return -EFAULT;

	/* Make sure it's safe to cast pointers to vring types. */
	BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE);
	BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE);
	if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) ||
	    (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) ||
	    (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1)))
		return -EINVAL;

	/* We only verify access here if backend is configured.
	 * If it is not, we don't as size might not have been setup.
	 * We will verify when backend is configured. */
	if (vq->private_data) {
		if (!vq_access_ok(vq, vq->num,
			(void __user *)(unsigned long)a.desc_user_addr,
			(void __user *)(unsigned long)a.avail_user_addr,
			(void __user *)(unsigned long)a.used_user_addr))
			return -EINVAL;

		/* Also validate log access for used ring if enabled. */
		if ((a.flags & (0x1 << VHOST_VRING_F_LOG)) &&
		    !log_access_ok(vq->log_base, a.log_guest_addr,
				   sizeof *vq->used +
				   vq->num * sizeof *vq->used->ring))
			return -EINVAL;
	}

	vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
	vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
	vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
	vq->log_addr = a.log_guest_addr;
	vq->used = (void __user *)(unsigned long)a.used_user_addr;

	return 0;
}

static long vhost_vring_set_num_addr(struct vhost_dev *d,
				     struct vhost_virtqueue *vq,
				     unsigned int ioctl,
				     void __user *argp)
{
	long r;

	mutex_lock(&vq->mutex);

#if VHOST_ARCH_CAN_ACCEL_UACCESS
	/* Unregister the MMU notifier so that the invalidation callback
	 * can access vq->uaddrs[] without holding a lock.
	 */
	if (d->mm)
		mmu_notifier_unregister(&d->mmu_notifier, d->mm);

	vhost_uninit_vq_maps(vq);
#endif

	switch (ioctl) {
	case VHOST_SET_VRING_NUM:
		r = vhost_vring_set_num(d, vq, argp);
		break;
	case VHOST_SET_VRING_ADDR:
		r = vhost_vring_set_addr(d, vq, argp);
		break;
	default:
		BUG();
	}

#if VHOST_ARCH_CAN_ACCEL_UACCESS
	vhost_setup_vq_uaddr(vq);

	if (d->mm)
		mmu_notifier_register(&d->mmu_notifier, d->mm);
#endif

	mutex_unlock(&vq->mutex);

	return r;
}

long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
{
	struct file *eventfp, *filep = NULL;
	bool pollstart = false, pollstop = false;
	struct eventfd_ctx *ctx = NULL;
	u32 __user *idxp = argp;
	struct vhost_virtqueue *vq;
	struct vhost_vring_state s;
	struct vhost_vring_file f;
	u32 idx;
	long r;

	r = get_user(idx, idxp);
	if (r < 0)
		return r;
	if (idx >= d->nvqs)
		return -ENOBUFS;

	idx = array_index_nospec(idx, d->nvqs);
	vq = d->vqs[idx];

	if (ioctl == VHOST_SET_VRING_NUM ||
	    ioctl == VHOST_SET_VRING_ADDR) {
		return vhost_vring_set_num_addr(d, vq, ioctl, argp);
	}

	mutex_lock(&vq->mutex);

	switch (ioctl) {
	case VHOST_SET_VRING_BASE:
		/* Moving base with an active backend?
		 * You don't want to do that. */
		if (vq->private_data) {
			r = -EBUSY;
			break;
		}
		if (copy_from_user(&s, argp, sizeof s)) {
			r = -EFAULT;
			break;
		}
		if (s.num > 0xffff) {
			r = -EINVAL;
			break;
		}
		vq->last_avail_idx = s.num;
		/* Forget the cached index value. */
		vq->avail_idx = vq->last_avail_idx;
		break;
	case VHOST_GET_VRING_BASE:
		s.index = idx;
		s.num = vq->last_avail_idx;
		if (copy_to_user(argp, &s, sizeof s))
			r = -EFAULT;
		break;
	case VHOST_SET_VRING_KICK:
		if (copy_from_user(&f, argp, sizeof f)) {
			r = -EFAULT;
			break;
		}
		eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
		if (IS_ERR(eventfp)) {
			r = PTR_ERR(eventfp);
			break;
		}
		if (eventfp != vq->kick) {
			pollstop = (filep = vq->kick) != NULL;
			pollstart = (vq->kick = eventfp) != NULL;
		} else
			filep = eventfp;
		break;
	case VHOST_SET_VRING_CALL:
		if (copy_from_user(&f, argp, sizeof f)) {
			r = -EFAULT;
			break;
		}
		ctx = f.fd == -1 ? NULL : eventfd_ctx_fdget(f.fd);
		if (IS_ERR(ctx)) {
			r = PTR_ERR(ctx);
			break;
		}
		swap(ctx, vq->call_ctx);
		break;
	case VHOST_SET_VRING_ERR:
		if (copy_from_user(&f, argp, sizeof f)) {
			r = -EFAULT;
			break;
		}
		ctx = f.fd == -1 ? NULL : eventfd_ctx_fdget(f.fd);
		if (IS_ERR(ctx)) {
			r = PTR_ERR(ctx);
			break;
		}
		swap(ctx, vq->error_ctx);
		break;
	case VHOST_SET_VRING_ENDIAN:
		r = vhost_set_vring_endian(vq, argp);
		break;
	case VHOST_GET_VRING_ENDIAN:
		r = vhost_get_vring_endian(vq, idx, argp);
		break;
	case VHOST_SET_VRING_BUSYLOOP_TIMEOUT:
		if (copy_from_user(&s, argp, sizeof(s))) {
			r = -EFAULT;
			break;
		}
		vq->busyloop_timeout = s.num;
		break;
	case VHOST_GET_VRING_BUSYLOOP_TIMEOUT:
		s.index = idx;
		s.num = vq->busyloop_timeout;
		if (copy_to_user(argp, &s, sizeof(s)))
			r = -EFAULT;
		break;
	default:
		r = -ENOIOCTLCMD;
	}

	if (pollstop && vq->handle_kick)
		vhost_poll_stop(&vq->poll);

	if (!IS_ERR_OR_NULL(ctx))
		eventfd_ctx_put(ctx);
	if (filep)
		fput(filep);

	if (pollstart && vq->handle_kick)
		r = vhost_poll_start(&vq->poll, vq->kick);

	mutex_unlock(&vq->mutex);

	if (pollstop && vq->handle_kick)
		vhost_poll_flush(&vq->poll);
	return r;
}
EXPORT_SYMBOL_GPL(vhost_vring_ioctl);

int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
{
	struct vhost_umem *niotlb, *oiotlb;
	int i;

	niotlb = vhost_umem_alloc();
	if (!niotlb)
		return -ENOMEM;

	oiotlb = d->iotlb;
	d->iotlb = niotlb;

	for (i = 0; i < d->nvqs; ++i) {
		struct vhost_virtqueue *vq = d->vqs[i];

		mutex_lock(&vq->mutex);
		vq->iotlb = niotlb;
		__vhost_vq_meta_reset(vq);
		mutex_unlock(&vq->mutex);
	}

	vhost_umem_clean(oiotlb);

	return 0;
}
EXPORT_SYMBOL_GPL(vhost_init_device_iotlb);

/* Caller must have device mutex */
long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp)
{
	struct eventfd_ctx *ctx;
	u64 p;
	long r;
	int i, fd;

	/* If you are not the owner, you can become one */
	if (ioctl == VHOST_SET_OWNER) {
		r = vhost_dev_set_owner(d);
		goto done;
	}

	/* You must be the owner to do anything else */
	r = vhost_dev_check_owner(d);
	if (r)
(r) 2275 goto done; 2276 2277 switch (ioctl) { 2278 case VHOST_SET_MEM_TABLE: 2279 r = vhost_set_memory(d, argp); 2280 break; 2281 case VHOST_SET_LOG_BASE: 2282 if (copy_from_user(&p, argp, sizeof p)) { 2283 r = -EFAULT; 2284 break; 2285 } 2286 if ((u64)(unsigned long)p != p) { 2287 r = -EFAULT; 2288 break; 2289 } 2290 for (i = 0; i < d->nvqs; ++i) { 2291 struct vhost_virtqueue *vq; 2292 void __user *base = (void __user *)(unsigned long)p; 2293 vq = d->vqs[i]; 2294 mutex_lock(&vq->mutex); 2295 /* If the ring is inactive, we will check when it's enabled. */ 2296 if (vq->private_data && !vq_log_access_ok(vq, base)) 2297 r = -EFAULT; 2298 else 2299 vq->log_base = base; 2300 mutex_unlock(&vq->mutex); 2301 } 2302 break; 2303 case VHOST_SET_LOG_FD: 2304 r = get_user(fd, (int __user *)argp); 2305 if (r < 0) 2306 break; 2307 ctx = fd == -1 ? NULL : eventfd_ctx_fdget(fd); 2308 if (IS_ERR(ctx)) { 2309 r = PTR_ERR(ctx); 2310 break; 2311 } 2312 swap(ctx, d->log_ctx); 2313 for (i = 0; i < d->nvqs; ++i) { 2314 mutex_lock(&d->vqs[i]->mutex); 2315 d->vqs[i]->log_ctx = d->log_ctx; 2316 mutex_unlock(&d->vqs[i]->mutex); 2317 } 2318 if (ctx) 2319 eventfd_ctx_put(ctx); 2320 break; 2321 default: 2322 r = -ENOIOCTLCMD; 2323 break; 2324 } 2325 done: 2326 return r; 2327 } 2328 EXPORT_SYMBOL_GPL(vhost_dev_ioctl); 2329 2330 /* TODO: This is really inefficient. We need something like get_user() 2331 * (instruction directly accesses the data, with an exception table entry 2332 * returning -EFAULT). See Documentation/x86/exception-tables.rst. 2333 */ 2334 static int set_bit_to_user(int nr, void __user *addr) 2335 { 2336 unsigned long log = (unsigned long)addr; 2337 struct page *page; 2338 void *base; 2339 int bit = nr + (log % PAGE_SIZE) * 8; 2340 int r; 2341 2342 r = get_user_pages_fast(log, 1, FOLL_WRITE, &page); 2343 if (r < 0) 2344 return r; 2345 BUG_ON(r != 1); 2346 base = kmap_atomic(page); 2347 set_bit(bit, base); 2348 kunmap_atomic(base); 2349 set_page_dirty_lock(page); 2350 put_page(page); 2351 return 0; 2352 } 2353 2354 static int log_write(void __user *log_base, 2355 u64 write_address, u64 write_length) 2356 { 2357 u64 write_page = write_address / VHOST_PAGE_SIZE; 2358 int r; 2359 2360 if (!write_length) 2361 return 0; 2362 write_length += write_address % VHOST_PAGE_SIZE; 2363 for (;;) { 2364 u64 base = (u64)(unsigned long)log_base; 2365 u64 log = base + write_page / 8; 2366 int bit = write_page % 8; 2367 if ((u64)(unsigned long)log != log) 2368 return -EFAULT; 2369 r = set_bit_to_user(bit, (void __user *)(unsigned long)log); 2370 if (r < 0) 2371 return r; 2372 if (write_length <= VHOST_PAGE_SIZE) 2373 break; 2374 write_length -= VHOST_PAGE_SIZE; 2375 write_page += 1; 2376 } 2377 return r; 2378 } 2379 2380 static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len) 2381 { 2382 struct vhost_umem *umem = vq->umem; 2383 struct vhost_umem_node *u; 2384 u64 start, end, l, min; 2385 int r; 2386 bool hit = false; 2387 2388 while (len) { 2389 min = len; 2390 /* More than one GPA can be mapped into a single HVA, so 2391 * iterate over all possible memory regions here to be safe.
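 *
 * (Worked example, added for illustration: with a region where
 * u->userspace_addr = 0x7f0000001000, u->size = 0x1000 and
 * u->start = 0x40000, logging hva = 0x7f0000001800, len = 0x100
 * gives start = 0x7f0000001800, end = 0x7f00000018ff, l = 0x100,
 * and log_write() is called for guest address
 * u->start + start - u->userspace_addr = 0x40800.)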
2392 */ 2393 list_for_each_entry(u, &umem->umem_list, link) { 2394 if (u->userspace_addr > hva - 1 + len || 2395 u->userspace_addr - 1 + u->size < hva) 2396 continue; 2397 start = max(u->userspace_addr, hva); 2398 end = min(u->userspace_addr - 1 + u->size, 2399 hva - 1 + len); 2400 l = end - start + 1; 2401 r = log_write(vq->log_base, 2402 u->start + start - u->userspace_addr, 2403 l); 2404 if (r < 0) 2405 return r; 2406 hit = true; 2407 min = min(l, min); 2408 } 2409 2410 if (!hit) 2411 return -EFAULT; 2412 2413 len -= min; 2414 hva += min; 2415 } 2416 2417 return 0; 2418 } 2419 2420 static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len) 2421 { 2422 struct iovec iov[64]; 2423 int i, ret; 2424 2425 if (!vq->iotlb) 2426 return log_write(vq->log_base, vq->log_addr + used_offset, len); 2427 2428 ret = translate_desc(vq, (uintptr_t)vq->used + used_offset, 2429 len, iov, 64, VHOST_ACCESS_WO); 2430 if (ret < 0) 2431 return ret; 2432 2433 for (i = 0; i < ret; i++) { 2434 ret = log_write_hva(vq, (uintptr_t)iov[i].iov_base, 2435 iov[i].iov_len); 2436 if (ret) 2437 return ret; 2438 } 2439 2440 return 0; 2441 } 2442 2443 int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log, 2444 unsigned int log_num, u64 len, struct iovec *iov, int count) 2445 { 2446 int i, r; 2447 2448 /* Make sure data written is seen before log. */ 2449 smp_wmb(); 2450 2451 if (vq->iotlb) { 2452 for (i = 0; i < count; i++) { 2453 r = log_write_hva(vq, (uintptr_t)iov[i].iov_base, 2454 iov[i].iov_len); 2455 if (r < 0) 2456 return r; 2457 } 2458 return 0; 2459 } 2460 2461 for (i = 0; i < log_num; ++i) { 2462 u64 l = min(log[i].len, len); 2463 r = log_write(vq->log_base, log[i].addr, l); 2464 if (r < 0) 2465 return r; 2466 len -= l; 2467 if (!len) { 2468 if (vq->log_ctx) 2469 eventfd_signal(vq->log_ctx, 1); 2470 return 0; 2471 } 2472 } 2473 /* Length written exceeds what we have stored. This is a bug. */ 2474 BUG(); 2475 return 0; 2476 } 2477 EXPORT_SYMBOL_GPL(vhost_log_write); 2478 2479 static int vhost_update_used_flags(struct vhost_virtqueue *vq) 2480 { 2481 void __user *used; 2482 if (vhost_put_used_flags(vq)) 2483 return -EFAULT; 2484 if (unlikely(vq->log_used)) { 2485 /* Make sure the flag is seen before log. */ 2486 smp_wmb(); 2487 /* Log used flag write. */ 2488 used = &vq->used->flags; 2489 log_used(vq, (used - (void __user *)vq->used), 2490 sizeof vq->used->flags); 2491 if (vq->log_ctx) 2492 eventfd_signal(vq->log_ctx, 1); 2493 } 2494 return 0; 2495 } 2496 2497 static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event) 2498 { 2499 if (vhost_put_avail_event(vq)) 2500 return -EFAULT; 2501 if (unlikely(vq->log_used)) { 2502 void __user *used; 2503 /* Make sure the event is seen before log. 
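 *
 * (Editor's note, for illustration: the "log" here is the dirty bitmap
 * filled in by log_write() above, one bit per VHOST_PAGE_SIZE page of
 * guest memory. With VHOST_PAGE_SIZE = 0x1000, a write touching guest
 * address 0x12345 lands in page 0x12, i.e. bit 0x12 % 8 = 2 of byte
 * 0x12 / 8 = 2 of the bitmap.)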
*/ 2504 smp_wmb(); 2505 /* Log avail event write */ 2506 used = vhost_avail_event(vq); 2507 log_used(vq, (used - (void __user *)vq->used), 2508 sizeof *vhost_avail_event(vq)); 2509 if (vq->log_ctx) 2510 eventfd_signal(vq->log_ctx, 1); 2511 } 2512 return 0; 2513 } 2514 2515 int vhost_vq_init_access(struct vhost_virtqueue *vq) 2516 { 2517 __virtio16 last_used_idx; 2518 int r; 2519 bool is_le = vq->is_le; 2520 2521 if (!vq->private_data) 2522 return 0; 2523 2524 vhost_init_is_le(vq); 2525 2526 r = vhost_update_used_flags(vq); 2527 if (r) 2528 goto err; 2529 vq->signalled_used_valid = false; 2530 if (!vq->iotlb && 2531 !access_ok(&vq->used->idx, sizeof vq->used->idx)) { 2532 r = -EFAULT; 2533 goto err; 2534 } 2535 r = vhost_get_used_idx(vq, &last_used_idx); 2536 if (r) { 2537 vq_err(vq, "Can't access used idx at %p\n", 2538 &vq->used->idx); 2539 goto err; 2540 } 2541 vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx); 2542 return 0; 2543 2544 err: 2545 vq->is_le = is_le; 2546 return r; 2547 } 2548 EXPORT_SYMBOL_GPL(vhost_vq_init_access); 2549 2550 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len, 2551 struct iovec iov[], int iov_size, int access) 2552 { 2553 const struct vhost_umem_node *node; 2554 struct vhost_dev *dev = vq->dev; 2555 struct vhost_umem *umem = dev->iotlb ? dev->iotlb : dev->umem; 2556 struct iovec *_iov; 2557 u64 s = 0; 2558 int ret = 0; 2559 2560 while ((u64)len > s) { 2561 u64 size; 2562 if (unlikely(ret >= iov_size)) { 2563 ret = -ENOBUFS; 2564 break; 2565 } 2566 2567 node = vhost_umem_interval_tree_iter_first(&umem->umem_tree, 2568 addr, addr + len - 1); 2569 if (node == NULL || node->start > addr) { 2570 if (umem != dev->iotlb) { 2571 ret = -EFAULT; 2572 break; 2573 } 2574 ret = -EAGAIN; 2575 break; 2576 } else if (!(node->perm & access)) { 2577 ret = -EPERM; 2578 break; 2579 } 2580 2581 _iov = iov + ret; 2582 size = node->size - addr + node->start; 2583 _iov->iov_len = min((u64)len - s, size); 2584 _iov->iov_base = (void __user *)(unsigned long) 2585 (node->userspace_addr + addr - node->start); 2586 s += size; 2587 addr += size; 2588 ++ret; 2589 } 2590 2591 if (ret == -EAGAIN) 2592 vhost_iotlb_miss(vq, addr, access); 2593 return ret; 2594 } 2595 2596 /* Each buffer in the virtqueues is actually a chain of descriptors. This 2597 * function returns the next descriptor in the chain, 2598 * or -1U if we're at the end. */ 2599 static unsigned next_desc(struct vhost_virtqueue *vq, struct vring_desc *desc) 2600 { 2601 unsigned int next; 2602 2603 /* If this descriptor says it doesn't chain, we're done. */ 2604 if (!(desc->flags & cpu_to_vhost16(vq, VRING_DESC_F_NEXT))) 2605 return -1U; 2606 2607 /* Check they're not leading us off end of descriptors. 
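 *
 * (Illustrative example, not part of the original comment: a chained
 * request might look like
 *
 *	desc[5] = { .addr = hdr,  .len = 16,   .flags = NEXT,         .next = 9 }
 *	desc[9] = { .addr = buf0, .len = 4096, .flags = NEXT | WRITE, .next = 2 }
 *	desc[2] = { .addr = buf1, .len = 4096, .flags = WRITE }
 *
 * so successive next_desc() calls return 9, then 2, then -1U once the
 * NEXT flag is clear; flag names are shortened from VRING_DESC_F_*.)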
*/ 2608 next = vhost16_to_cpu(vq, READ_ONCE(desc->next)); 2609 return next; 2610 } 2611 2612 static int get_indirect(struct vhost_virtqueue *vq, 2613 struct iovec iov[], unsigned int iov_size, 2614 unsigned int *out_num, unsigned int *in_num, 2615 struct vhost_log *log, unsigned int *log_num, 2616 struct vring_desc *indirect) 2617 { 2618 struct vring_desc desc; 2619 unsigned int i = 0, count, found = 0; 2620 u32 len = vhost32_to_cpu(vq, indirect->len); 2621 struct iov_iter from; 2622 int ret, access; 2623 2624 /* Sanity check */ 2625 if (unlikely(len % sizeof desc)) { 2626 vq_err(vq, "Invalid length in indirect descriptor: " 2627 "len 0x%llx not multiple of 0x%zx\n", 2628 (unsigned long long)len, 2629 sizeof desc); 2630 return -EINVAL; 2631 } 2632 2633 ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect, 2634 UIO_MAXIOV, VHOST_ACCESS_RO); 2635 if (unlikely(ret < 0)) { 2636 if (ret != -EAGAIN) 2637 vq_err(vq, "Translation failure %d in indirect.\n", ret); 2638 return ret; 2639 } 2640 iov_iter_init(&from, READ, vq->indirect, ret, len); 2641 2642 /* We will use the result as an address to read from, so most 2643 * architectures only need a compiler barrier here. */ 2644 read_barrier_depends(); 2645 2646 count = len / sizeof desc; 2647 /* Buffers are chained via a 16 bit next field, so 2648 * we can have at most 2^16 of these. */ 2649 if (unlikely(count > USHRT_MAX + 1)) { 2650 vq_err(vq, "Indirect buffer length too big: %d\n", 2651 indirect->len); 2652 return -E2BIG; 2653 } 2654 2655 do { 2656 unsigned iov_count = *in_num + *out_num; 2657 if (unlikely(++found > count)) { 2658 vq_err(vq, "Loop detected: last one at %u " 2659 "indirect size %u\n", 2660 i, count); 2661 return -EINVAL; 2662 } 2663 if (unlikely(!copy_from_iter_full(&desc, sizeof(desc), &from))) { 2664 vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n", 2665 i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc); 2666 return -EINVAL; 2667 } 2668 if (unlikely(desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT))) { 2669 vq_err(vq, "Nested indirect descriptor: idx %d, %zx\n", 2670 i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc); 2671 return -EINVAL; 2672 } 2673 2674 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) 2675 access = VHOST_ACCESS_WO; 2676 else 2677 access = VHOST_ACCESS_RO; 2678 2679 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr), 2680 vhost32_to_cpu(vq, desc.len), iov + iov_count, 2681 iov_size - iov_count, access); 2682 if (unlikely(ret < 0)) { 2683 if (ret != -EAGAIN) 2684 vq_err(vq, "Translation failure %d indirect idx %d\n", 2685 ret, i); 2686 return ret; 2687 } 2688 /* If this is an input descriptor, increment that count. */ 2689 if (access == VHOST_ACCESS_WO) { 2690 *in_num += ret; 2691 if (unlikely(log)) { 2692 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr); 2693 log[*log_num].len = vhost32_to_cpu(vq, desc.len); 2694 ++*log_num; 2695 } 2696 } else { 2697 /* If it's an output descriptor, they're all supposed 2698 * to come before any input descriptors. */ 2699 if (unlikely(*in_num)) { 2700 vq_err(vq, "Indirect descriptor " 2701 "has out after in: idx %d\n", i); 2702 return -EINVAL; 2703 } 2704 *out_num += ret; 2705 } 2706 } while ((i = next_desc(vq, &desc)) != -1); 2707 return 0; 2708 } 2709 2710 /* This looks in the virtqueue and for the first available buffer, and converts 2711 * it to an iovec for convenient access. 
Since descriptors consist of some 2712 * number of output then some number of input descriptors, it's actually two 2713 * iovecs, but we pack them into one and note how many of each there were. 2714 * 2715 * This function returns the descriptor number found, or vq->num (which is 2716 * never a valid descriptor number) if none was found. A negative code is 2717 * returned on error. */ 2718 int vhost_get_vq_desc(struct vhost_virtqueue *vq, 2719 struct iovec iov[], unsigned int iov_size, 2720 unsigned int *out_num, unsigned int *in_num, 2721 struct vhost_log *log, unsigned int *log_num) 2722 { 2723 struct vring_desc desc; 2724 unsigned int i, head, found = 0; 2725 u16 last_avail_idx; 2726 __virtio16 avail_idx; 2727 __virtio16 ring_head; 2728 int ret, access; 2729 2730 /* Check it isn't doing very strange things with descriptor numbers. */ 2731 last_avail_idx = vq->last_avail_idx; 2732 2733 if (vq->avail_idx == vq->last_avail_idx) { 2734 if (unlikely(vhost_get_avail_idx(vq, &avail_idx))) { 2735 vq_err(vq, "Failed to access avail idx at %p\n", 2736 &vq->avail->idx); 2737 return -EFAULT; 2738 } 2739 vq->avail_idx = vhost16_to_cpu(vq, avail_idx); 2740 2741 if (unlikely((u16)(vq->avail_idx - last_avail_idx) > vq->num)) { 2742 vq_err(vq, "Guest moved used index from %u to %u", 2743 last_avail_idx, vq->avail_idx); 2744 return -EFAULT; 2745 } 2746 2747 /* If there's nothing new since we last looked, return 2748 * invalid. 2749 */ 2750 if (vq->avail_idx == last_avail_idx) 2751 return vq->num; 2752 2753 /* Only get avail ring entries after they have been 2754 * exposed by guest. 2755 */ 2756 smp_rmb(); 2757 } 2758 2759 /* Grab the next descriptor number they're advertising, and increment 2760 * the index we've seen. */ 2761 if (unlikely(vhost_get_avail_head(vq, &ring_head, last_avail_idx))) { 2762 vq_err(vq, "Failed to read head: idx %d address %p\n", 2763 last_avail_idx, 2764 &vq->avail->ring[last_avail_idx % vq->num]); 2765 return -EFAULT; 2766 } 2767 2768 head = vhost16_to_cpu(vq, ring_head); 2769 2770 /* If their number is silly, that's an error. */ 2771 if (unlikely(head >= vq->num)) { 2772 vq_err(vq, "Guest says index %u > %u is available", 2773 head, vq->num); 2774 return -EINVAL; 2775 } 2776 2777 /* When we start there are neither input nor output descriptors.
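 *
 * (For illustration, editor's addition: a chain of one readable "out"
 * descriptor followed by two writable "in" descriptors, each translating
 * to a single iovec, ends up with *out_num == 1 and *in_num == 2, with
 * iov[0] holding the readable buffer and iov[1] and iov[2] the writable
 * ones.)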
*/ 2778 *out_num = *in_num = 0; 2779 if (unlikely(log)) 2780 *log_num = 0; 2781 2782 i = head; 2783 do { 2784 unsigned iov_count = *in_num + *out_num; 2785 if (unlikely(i >= vq->num)) { 2786 vq_err(vq, "Desc index is %u > %u, head = %u", 2787 i, vq->num, head); 2788 return -EINVAL; 2789 } 2790 if (unlikely(++found > vq->num)) { 2791 vq_err(vq, "Loop detected: last one at %u " 2792 "vq size %u head %u\n", 2793 i, vq->num, head); 2794 return -EINVAL; 2795 } 2796 ret = vhost_get_desc(vq, &desc, i); 2797 if (unlikely(ret)) { 2798 vq_err(vq, "Failed to get descriptor: idx %d addr %p\n", 2799 i, vq->desc + i); 2800 return -EFAULT; 2801 } 2802 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT)) { 2803 ret = get_indirect(vq, iov, iov_size, 2804 out_num, in_num, 2805 log, log_num, &desc); 2806 if (unlikely(ret < 0)) { 2807 if (ret != -EAGAIN) 2808 vq_err(vq, "Failure detected " 2809 "in indirect descriptor at idx %d\n", i); 2810 return ret; 2811 } 2812 continue; 2813 } 2814 2815 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) 2816 access = VHOST_ACCESS_WO; 2817 else 2818 access = VHOST_ACCESS_RO; 2819 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr), 2820 vhost32_to_cpu(vq, desc.len), iov + iov_count, 2821 iov_size - iov_count, access); 2822 if (unlikely(ret < 0)) { 2823 if (ret != -EAGAIN) 2824 vq_err(vq, "Translation failure %d descriptor idx %d\n", 2825 ret, i); 2826 return ret; 2827 } 2828 if (access == VHOST_ACCESS_WO) { 2829 /* If this is an input descriptor, 2830 * increment that count. */ 2831 *in_num += ret; 2832 if (unlikely(log)) { 2833 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr); 2834 log[*log_num].len = vhost32_to_cpu(vq, desc.len); 2835 ++*log_num; 2836 } 2837 } else { 2838 /* If it's an output descriptor, they're all supposed 2839 * to come before any input descriptors. */ 2840 if (unlikely(*in_num)) { 2841 vq_err(vq, "Descriptor has out after in: " 2842 "idx %d\n", i); 2843 return -EINVAL; 2844 } 2845 *out_num += ret; 2846 } 2847 } while ((i = next_desc(vq, &desc)) != -1); 2848 2849 /* On success, increment avail index. */ 2850 vq->last_avail_idx++; 2851 2852 /* Assume notifications from guest are disabled at this point, 2853 * if they aren't we would need to update avail_event index. */ 2854 BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY)); 2855 return head; 2856 } 2857 EXPORT_SYMBOL_GPL(vhost_get_vq_desc); 2858 2859 /* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */ 2860 void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n) 2861 { 2862 vq->last_avail_idx -= n; 2863 } 2864 EXPORT_SYMBOL_GPL(vhost_discard_vq_desc); 2865 2866 /* After we've used one of their buffers, we tell them about it. We'll then 2867 * want to notify the guest, using eventfd. */ 2868 int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len) 2869 { 2870 struct vring_used_elem heads = { 2871 cpu_to_vhost32(vq, head), 2872 cpu_to_vhost32(vq, len) 2873 }; 2874 2875 return vhost_add_used_n(vq, &heads, 1); 2876 } 2877 EXPORT_SYMBOL_GPL(vhost_add_used); 2878 2879 static int __vhost_add_used_n(struct vhost_virtqueue *vq, 2880 struct vring_used_elem *heads, 2881 unsigned count) 2882 { 2883 struct vring_used_elem __user *used; 2884 u16 old, new; 2885 int start; 2886 2887 start = vq->last_used_idx & (vq->num - 1); 2888 used = vq->used->ring + start; 2889 if (vhost_put_used(vq, heads, start, count)) { 2890 vq_err(vq, "Failed to write used"); 2891 return -EFAULT; 2892 } 2893 if (unlikely(vq->log_used)) { 2894 /* Make sure data is seen before log. 
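 *
 * (Worked example, editor's addition: the log_used() call below logs the
 * used ring bytes just written. With start == 5 and count == 2, used
 * points at &vq->used->ring[5], so the offset passed to log_used() is
 * 4 + 5 * 8 = 44 and the length is 2 * 8 = 16 bytes, assuming the usual
 * vring_used layout of two 16-bit fields followed by 8-byte entries.)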
*/ 2895 smp_wmb(); 2896 /* Log used ring entry write. */ 2897 log_used(vq, ((void __user *)used - (void __user *)vq->used), 2898 count * sizeof *used); 2899 } 2900 old = vq->last_used_idx; 2901 new = (vq->last_used_idx += count); 2902 /* If the driver never bothers to signal in a very long while, 2903 * the used index might wrap around. If that happens, invalidate 2904 * the signalled_used index we stored. TODO: make sure the driver 2905 * signals at least once in 2^16 and remove this. */ 2906 if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old))) 2907 vq->signalled_used_valid = false; 2908 return 0; 2909 } 2910 2911 /* After we've used one of their buffers, we tell them about it. We'll then 2912 * want to notify the guest, using eventfd. */ 2913 int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads, 2914 unsigned count) 2915 { 2916 int start, n, r; 2917 2918 start = vq->last_used_idx & (vq->num - 1); 2919 n = vq->num - start; 2920 if (n < count) { 2921 r = __vhost_add_used_n(vq, heads, n); 2922 if (r < 0) 2923 return r; 2924 heads += n; 2925 count -= n; 2926 } 2927 r = __vhost_add_used_n(vq, heads, count); 2928 2929 /* Make sure buffer is written before we update index. */ 2930 smp_wmb(); 2931 if (vhost_put_used_idx(vq)) { 2932 vq_err(vq, "Failed to increment used idx"); 2933 return -EFAULT; 2934 } 2935 if (unlikely(vq->log_used)) { 2936 /* Make sure used idx is seen before log. */ 2937 smp_wmb(); 2938 /* Log used index update. */ 2939 log_used(vq, offsetof(struct vring_used, idx), 2940 sizeof vq->used->idx); 2941 if (vq->log_ctx) 2942 eventfd_signal(vq->log_ctx, 1); 2943 } 2944 return r; 2945 } 2946 EXPORT_SYMBOL_GPL(vhost_add_used_n); 2947 2948 static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) 2949 { 2950 __u16 old, new; 2951 __virtio16 event; 2952 bool v; 2953 /* Flush out used index updates. This is paired 2954 * with the barrier that the Guest executes when enabling 2955 * interrupts. */ 2956 smp_mb(); 2957 2958 if (vhost_has_feature(vq, VIRTIO_F_NOTIFY_ON_EMPTY) && 2959 unlikely(vq->avail_idx == vq->last_avail_idx)) 2960 return true; 2961 2962 if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) { 2963 __virtio16 flags; 2964 if (vhost_get_avail_flags(vq, &flags)) { 2965 vq_err(vq, "Failed to get flags"); 2966 return true; 2967 } 2968 return !(flags & cpu_to_vhost16(vq, VRING_AVAIL_F_NO_INTERRUPT)); 2969 } 2970 old = vq->signalled_used; 2971 v = vq->signalled_used_valid; 2972 new = vq->signalled_used = vq->last_used_idx; 2973 vq->signalled_used_valid = true; 2974 2975 if (unlikely(!v)) 2976 return true; 2977 2978 if (vhost_get_used_event(vq, &event)) { 2979 vq_err(vq, "Failed to get used event idx"); 2980 return true; 2981 } 2982 return vring_need_event(vhost16_to_cpu(vq, event), new, old); 2983 } 2984 2985 /* This actually signals the guest, using eventfd. */ 2986 void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq) 2987 { 2988 /* Signal the Guest to tell them we used something up. */ 2989 if (vq->call_ctx && vhost_notify(dev, vq)) 2990 eventfd_signal(vq->call_ctx, 1); 2991 } 2992 EXPORT_SYMBOL_GPL(vhost_signal); 2993 2994 /* And here's the combo meal deal. Supersize me!
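 *
 * (Editor's sketch of the typical caller loop, using only functions from
 * this file; the loop itself and the variable names are hypothetical:
 *
 *	head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
 *				 &out, &in, NULL, NULL);
 *	if (head == vq->num) {
 *		if (unlikely(vhost_enable_notify(dev, vq))) {
 *			vhost_disable_notify(dev, vq);
 *			continue;	// new buffers raced in, retry
 *		}
 *		break;			// truly empty, wait for a kick
 *	}
 *	// ... consume iov[0 .. out + in - 1] ...
 *	vhost_add_used_and_signal(dev, vq, head, len);
 *
 * Negative return values from vhost_get_vq_desc() are not shown.)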
*/ 2995 void vhost_add_used_and_signal(struct vhost_dev *dev, 2996 struct vhost_virtqueue *vq, 2997 unsigned int head, int len) 2998 { 2999 vhost_add_used(vq, head, len); 3000 vhost_signal(dev, vq); 3001 } 3002 EXPORT_SYMBOL_GPL(vhost_add_used_and_signal); 3003 3004 /* multi-buffer version of vhost_add_used_and_signal */ 3005 void vhost_add_used_and_signal_n(struct vhost_dev *dev, 3006 struct vhost_virtqueue *vq, 3007 struct vring_used_elem *heads, unsigned count) 3008 { 3009 vhost_add_used_n(vq, heads, count); 3010 vhost_signal(dev, vq); 3011 } 3012 EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n); 3013 3014 /* Return true if we're sure that the available ring is empty. */ 3015 bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq) 3016 { 3017 __virtio16 avail_idx; 3018 int r; 3019 3020 if (vq->avail_idx != vq->last_avail_idx) 3021 return false; 3022 3023 r = vhost_get_avail_idx(vq, &avail_idx); 3024 if (unlikely(r)) 3025 return false; 3026 vq->avail_idx = vhost16_to_cpu(vq, avail_idx); 3027 3028 return vq->avail_idx == vq->last_avail_idx; 3029 } 3030 EXPORT_SYMBOL_GPL(vhost_vq_avail_empty); 3031 3032 /* OK, now we need to know about added descriptors. */ 3033 bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) 3034 { 3035 __virtio16 avail_idx; 3036 int r; 3037 3038 if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY)) 3039 return false; 3040 vq->used_flags &= ~VRING_USED_F_NO_NOTIFY; 3041 if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) { 3042 r = vhost_update_used_flags(vq); 3043 if (r) { 3044 vq_err(vq, "Failed to enable notification at %p: %d\n", 3045 &vq->used->flags, r); 3046 return false; 3047 } 3048 } else { 3049 r = vhost_update_avail_event(vq, vq->avail_idx); 3050 if (r) { 3051 vq_err(vq, "Failed to update avail event index at %p: %d\n", 3052 vhost_avail_event(vq), r); 3053 return false; 3054 } 3055 } 3056 /* They could have slipped one in as we were doing that: make 3057 * sure it's written, then check again. */ 3058 smp_mb(); 3059 r = vhost_get_avail_idx(vq, &avail_idx); 3060 if (r) { 3061 vq_err(vq, "Failed to check avail idx at %p: %d\n", 3062 &vq->avail->idx, r); 3063 return false; 3064 } 3065 3066 return vhost16_to_cpu(vq, avail_idx) != vq->avail_idx; 3067 } 3068 EXPORT_SYMBOL_GPL(vhost_enable_notify); 3069 3070 /* We don't need to be notified again. */ 3071 void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) 3072 { 3073 int r; 3074 3075 if (vq->used_flags & VRING_USED_F_NO_NOTIFY) 3076 return; 3077 vq->used_flags |= VRING_USED_F_NO_NOTIFY; 3078 if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) { 3079 r = vhost_update_used_flags(vq); 3080 if (r) 3081 vq_err(vq, "Failed to disable notification at %p: %d\n", 3082 &vq->used->flags, r); 3083 } 3084 } 3085 EXPORT_SYMBOL_GPL(vhost_disable_notify); 3086 3087 /* Create a new message. */ 3088 struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type) 3089 { 3090 struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL); 3091 if (!node) 3092 return NULL; 3093 3094 /* Make sure all padding within the structure is initialized.
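 *
 * (Editor's note: the message is later copied to userspace as-is through
 * the vhost character-device read path, so uninitialized padding could
 * otherwise leak kernel memory contents.)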
*/ 3095 memset(&node->msg, 0, sizeof node->msg); 3096 node->vq = vq; 3097 node->msg.type = type; 3098 return node; 3099 } 3100 EXPORT_SYMBOL_GPL(vhost_new_msg); 3101 3102 void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head, 3103 struct vhost_msg_node *node) 3104 { 3105 spin_lock(&dev->iotlb_lock); 3106 list_add_tail(&node->node, head); 3107 spin_unlock(&dev->iotlb_lock); 3108 3109 wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM); 3110 } 3111 EXPORT_SYMBOL_GPL(vhost_enqueue_msg); 3112 3113 struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev, 3114 struct list_head *head) 3115 { 3116 struct vhost_msg_node *node = NULL; 3117 3118 spin_lock(&dev->iotlb_lock); 3119 if (!list_empty(head)) { 3120 node = list_first_entry(head, struct vhost_msg_node, 3121 node); 3122 list_del(&node->node); 3123 } 3124 spin_unlock(&dev->iotlb_lock); 3125 3126 return node; 3127 } 3128 EXPORT_SYMBOL_GPL(vhost_dequeue_msg); 3129 3130 3131 static int __init vhost_init(void) 3132 { 3133 return 0; 3134 } 3135 3136 static void __exit vhost_exit(void) 3137 { 3138 } 3139 3140 module_init(vhost_init); 3141 module_exit(vhost_exit); 3142 3143 MODULE_VERSION("0.0.1"); 3144 MODULE_LICENSE("GPL v2"); 3145 MODULE_AUTHOR("Michael S. Tsirkin"); 3146 MODULE_DESCRIPTION("Host kernel accelerator for virtio"); 3147
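
/*
 * Editor's addition: a minimal, self-contained sketch of how a userspace
 * backend typically brings up one virtqueue through the ioctls handled in
 * this file. It is intentionally not compiled (guarded by #if 0); the file
 * descriptors, addresses and ring size are made up, and a real backend
 * would also negotiate features (VHOST_GET_FEATURES/VHOST_SET_FEATURES)
 * and attach a device-specific backend. See include/uapi/linux/vhost.h
 * for the authoritative ABI.
 */
#if 0
#include <sys/eventfd.h>
#include <sys/ioctl.h>
#include <linux/vhost.h>

static int setup_vq0(int vhost_fd, struct vhost_memory *mem,
		     void *desc, void *avail, void *used)
{
	struct vhost_vring_state num = { .index = 0, .num = 256 };
	struct vhost_vring_addr addr = {
		.index = 0,
		.desc_user_addr = (unsigned long)desc,
		.avail_user_addr = (unsigned long)avail,
		.used_user_addr = (unsigned long)used,
	};
	struct vhost_vring_file kick = { .index = 0, .fd = eventfd(0, 0) };
	struct vhost_vring_file call = { .index = 0, .fd = eventfd(0, 0) };

	if (ioctl(vhost_fd, VHOST_SET_OWNER, NULL) ||       /* become the owner */
	    ioctl(vhost_fd, VHOST_SET_MEM_TABLE, mem) ||    /* guest memory map */
	    ioctl(vhost_fd, VHOST_SET_VRING_NUM, &num) ||   /* ring size */
	    ioctl(vhost_fd, VHOST_SET_VRING_ADDR, &addr) || /* ring layout */
	    ioctl(vhost_fd, VHOST_SET_VRING_KICK, &kick) || /* guest->host doorbell */
	    ioctl(vhost_fd, VHOST_SET_VRING_CALL, &call))   /* host->guest interrupt */
		return -1;
	return 0;
}
#endif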