1 // SPDX-License-Identifier: GPL-2.0-only 2 /* Copyright (C) 2009 Red Hat, Inc. 3 * Copyright (C) 2006 Rusty Russell IBM Corporation 4 * 5 * Author: Michael S. Tsirkin <mst@redhat.com> 6 * 7 * Inspiration, some code, and most witty comments come from 8 * Documentation/virtual/lguest/lguest.c, by Rusty Russell 9 * 10 * Generic code for virtio server in host kernel. 11 */ 12 13 #include <linux/eventfd.h> 14 #include <linux/vhost.h> 15 #include <linux/uio.h> 16 #include <linux/mm.h> 17 #include <linux/miscdevice.h> 18 #include <linux/mutex.h> 19 #include <linux/poll.h> 20 #include <linux/file.h> 21 #include <linux/highmem.h> 22 #include <linux/slab.h> 23 #include <linux/vmalloc.h> 24 #include <linux/kthread.h> 25 #include <linux/module.h> 26 #include <linux/sort.h> 27 #include <linux/sched/mm.h> 28 #include <linux/sched/signal.h> 29 #include <linux/sched/vhost_task.h> 30 #include <linux/interval_tree_generic.h> 31 #include <linux/nospec.h> 32 #include <linux/kcov.h> 33 34 #include "vhost.h" 35 36 static ushort max_mem_regions = 64; 37 module_param(max_mem_regions, ushort, 0444); 38 MODULE_PARM_DESC(max_mem_regions, 39 "Maximum number of memory regions in memory map. (default: 64)"); 40 static int max_iotlb_entries = 2048; 41 module_param(max_iotlb_entries, int, 0444); 42 MODULE_PARM_DESC(max_iotlb_entries, 43 "Maximum number of iotlb entries. (default: 2048)"); 44 45 enum { 46 VHOST_MEMORY_F_LOG = 0x1, 47 }; 48 49 #define vhost_used_event(vq) ((__virtio16 __user *)&vq->avail->ring[vq->num]) 50 #define vhost_avail_event(vq) ((__virtio16 __user *)&vq->used->ring[vq->num]) 51 52 #ifdef CONFIG_VHOST_CROSS_ENDIAN_LEGACY 53 static void vhost_disable_cross_endian(struct vhost_virtqueue *vq) 54 { 55 vq->user_be = !virtio_legacy_is_little_endian(); 56 } 57 58 static void vhost_enable_cross_endian_big(struct vhost_virtqueue *vq) 59 { 60 vq->user_be = true; 61 } 62 63 static void vhost_enable_cross_endian_little(struct vhost_virtqueue *vq) 64 { 65 vq->user_be = false; 66 } 67 68 static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp) 69 { 70 struct vhost_vring_state s; 71 72 if (vq->private_data) 73 return -EBUSY; 74 75 if (copy_from_user(&s, argp, sizeof(s))) 76 return -EFAULT; 77 78 if (s.num != VHOST_VRING_LITTLE_ENDIAN && 79 s.num != VHOST_VRING_BIG_ENDIAN) 80 return -EINVAL; 81 82 if (s.num == VHOST_VRING_BIG_ENDIAN) 83 vhost_enable_cross_endian_big(vq); 84 else 85 vhost_enable_cross_endian_little(vq); 86 87 return 0; 88 } 89 90 static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx, 91 int __user *argp) 92 { 93 struct vhost_vring_state s = { 94 .index = idx, 95 .num = vq->user_be 96 }; 97 98 if (copy_to_user(argp, &s, sizeof(s))) 99 return -EFAULT; 100 101 return 0; 102 } 103 104 static void vhost_init_is_le(struct vhost_virtqueue *vq) 105 { 106 /* Note for legacy virtio: user_be is initialized at reset time 107 * according to the host endianness. If userspace does not set an 108 * explicit endianness, the default behavior is native endian, as 109 * expected by legacy virtio. 
110 */ 111 vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1) || !vq->user_be; 112 } 113 #else 114 static void vhost_disable_cross_endian(struct vhost_virtqueue *vq) 115 { 116 } 117 118 static long vhost_set_vring_endian(struct vhost_virtqueue *vq, int __user *argp) 119 { 120 return -ENOIOCTLCMD; 121 } 122 123 static long vhost_get_vring_endian(struct vhost_virtqueue *vq, u32 idx, 124 int __user *argp) 125 { 126 return -ENOIOCTLCMD; 127 } 128 129 static void vhost_init_is_le(struct vhost_virtqueue *vq) 130 { 131 vq->is_le = vhost_has_feature(vq, VIRTIO_F_VERSION_1) 132 || virtio_legacy_is_little_endian(); 133 } 134 #endif /* CONFIG_VHOST_CROSS_ENDIAN_LEGACY */ 135 136 static void vhost_reset_is_le(struct vhost_virtqueue *vq) 137 { 138 vhost_init_is_le(vq); 139 } 140 141 struct vhost_flush_struct { 142 struct vhost_work work; 143 struct completion wait_event; 144 }; 145 146 static void vhost_flush_work(struct vhost_work *work) 147 { 148 struct vhost_flush_struct *s; 149 150 s = container_of(work, struct vhost_flush_struct, work); 151 complete(&s->wait_event); 152 } 153 154 static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh, 155 poll_table *pt) 156 { 157 struct vhost_poll *poll; 158 159 poll = container_of(pt, struct vhost_poll, table); 160 poll->wqh = wqh; 161 add_wait_queue(wqh, &poll->wait); 162 } 163 164 static int vhost_poll_wakeup(wait_queue_entry_t *wait, unsigned mode, int sync, 165 void *key) 166 { 167 struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait); 168 struct vhost_work *work = &poll->work; 169 170 if (!(key_to_poll(key) & poll->mask)) 171 return 0; 172 173 if (!poll->dev->use_worker) 174 work->fn(work); 175 else 176 vhost_poll_queue(poll); 177 178 return 0; 179 } 180 181 void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn) 182 { 183 clear_bit(VHOST_WORK_QUEUED, &work->flags); 184 work->fn = fn; 185 } 186 EXPORT_SYMBOL_GPL(vhost_work_init); 187 188 /* Init poll structure */ 189 void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn, 190 __poll_t mask, struct vhost_dev *dev, 191 struct vhost_virtqueue *vq) 192 { 193 init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup); 194 init_poll_funcptr(&poll->table, vhost_poll_func); 195 poll->mask = mask; 196 poll->dev = dev; 197 poll->wqh = NULL; 198 poll->vq = vq; 199 200 vhost_work_init(&poll->work, fn); 201 } 202 EXPORT_SYMBOL_GPL(vhost_poll_init); 203 204 /* Start polling a file. We add ourselves to file's wait queue. The caller must 205 * keep a reference to a file until after vhost_poll_stop is called. */ 206 int vhost_poll_start(struct vhost_poll *poll, struct file *file) 207 { 208 __poll_t mask; 209 210 if (poll->wqh) 211 return 0; 212 213 mask = vfs_poll(file, &poll->table); 214 if (mask) 215 vhost_poll_wakeup(&poll->wait, 0, 0, poll_to_key(mask)); 216 if (mask & EPOLLERR) { 217 vhost_poll_stop(poll); 218 return -EINVAL; 219 } 220 221 return 0; 222 } 223 EXPORT_SYMBOL_GPL(vhost_poll_start); 224 225 /* Stop polling a file. After this function returns, it becomes safe to drop the 226 * file reference. You must also flush afterwards. 
*/ 227 void vhost_poll_stop(struct vhost_poll *poll) 228 { 229 if (poll->wqh) { 230 remove_wait_queue(poll->wqh, &poll->wait); 231 poll->wqh = NULL; 232 } 233 } 234 EXPORT_SYMBOL_GPL(vhost_poll_stop); 235 236 static void vhost_worker_queue(struct vhost_worker *worker, 237 struct vhost_work *work) 238 { 239 if (!test_and_set_bit(VHOST_WORK_QUEUED, &work->flags)) { 240 /* We can only add the work to the list after we're 241 * sure it was not in the list. 242 * test_and_set_bit() implies a memory barrier. 243 */ 244 llist_add(&work->node, &worker->work_list); 245 vhost_task_wake(worker->vtsk); 246 } 247 } 248 249 bool vhost_vq_work_queue(struct vhost_virtqueue *vq, struct vhost_work *work) 250 { 251 struct vhost_worker *worker; 252 bool queued = false; 253 254 rcu_read_lock(); 255 worker = rcu_dereference(vq->worker); 256 if (worker) { 257 queued = true; 258 vhost_worker_queue(worker, work); 259 } 260 rcu_read_unlock(); 261 262 return queued; 263 } 264 EXPORT_SYMBOL_GPL(vhost_vq_work_queue); 265 266 void vhost_vq_flush(struct vhost_virtqueue *vq) 267 { 268 struct vhost_flush_struct flush; 269 270 init_completion(&flush.wait_event); 271 vhost_work_init(&flush.work, vhost_flush_work); 272 273 if (vhost_vq_work_queue(vq, &flush.work)) 274 wait_for_completion(&flush.wait_event); 275 } 276 EXPORT_SYMBOL_GPL(vhost_vq_flush); 277 278 /** 279 * __vhost_worker_flush - flush a worker 280 * @worker: worker to flush 281 * 282 * The worker's flush_mutex must be held. 283 */ 284 static void __vhost_worker_flush(struct vhost_worker *worker) 285 { 286 struct vhost_flush_struct flush; 287 288 if (!worker->attachment_cnt || worker->killed) 289 return; 290 291 init_completion(&flush.wait_event); 292 vhost_work_init(&flush.work, vhost_flush_work); 293 294 vhost_worker_queue(worker, &flush.work); 295 /* 296 * Drop mutex in case our worker is killed and it needs to take the 297 * mutex to force cleanup. 
298 */ 299 mutex_unlock(&worker->mutex); 300 wait_for_completion(&flush.wait_event); 301 mutex_lock(&worker->mutex); 302 } 303 304 static void vhost_worker_flush(struct vhost_worker *worker) 305 { 306 mutex_lock(&worker->mutex); 307 __vhost_worker_flush(worker); 308 mutex_unlock(&worker->mutex); 309 } 310 311 void vhost_dev_flush(struct vhost_dev *dev) 312 { 313 struct vhost_worker *worker; 314 unsigned long i; 315 316 xa_for_each(&dev->worker_xa, i, worker) 317 vhost_worker_flush(worker); 318 } 319 EXPORT_SYMBOL_GPL(vhost_dev_flush); 320 321 /* A lockless hint for busy polling code to exit the loop */ 322 bool vhost_vq_has_work(struct vhost_virtqueue *vq) 323 { 324 struct vhost_worker *worker; 325 bool has_work = false; 326 327 rcu_read_lock(); 328 worker = rcu_dereference(vq->worker); 329 if (worker && !llist_empty(&worker->work_list)) 330 has_work = true; 331 rcu_read_unlock(); 332 333 return has_work; 334 } 335 EXPORT_SYMBOL_GPL(vhost_vq_has_work); 336 337 void vhost_poll_queue(struct vhost_poll *poll) 338 { 339 vhost_vq_work_queue(poll->vq, &poll->work); 340 } 341 EXPORT_SYMBOL_GPL(vhost_poll_queue); 342 343 static void __vhost_vq_meta_reset(struct vhost_virtqueue *vq) 344 { 345 int j; 346 347 for (j = 0; j < VHOST_NUM_ADDRS; j++) 348 vq->meta_iotlb[j] = NULL; 349 } 350 351 static void vhost_vq_meta_reset(struct vhost_dev *d) 352 { 353 int i; 354 355 for (i = 0; i < d->nvqs; ++i) 356 __vhost_vq_meta_reset(d->vqs[i]); 357 } 358 359 static void vhost_vring_call_reset(struct vhost_vring_call *call_ctx) 360 { 361 call_ctx->ctx = NULL; 362 memset(&call_ctx->producer, 0x0, sizeof(struct irq_bypass_producer)); 363 } 364 365 bool vhost_vq_is_setup(struct vhost_virtqueue *vq) 366 { 367 return vq->avail && vq->desc && vq->used && vhost_vq_access_ok(vq); 368 } 369 EXPORT_SYMBOL_GPL(vhost_vq_is_setup); 370 371 static void vhost_vq_reset(struct vhost_dev *dev, 372 struct vhost_virtqueue *vq) 373 { 374 vq->num = 1; 375 vq->desc = NULL; 376 vq->avail = NULL; 377 vq->used = NULL; 378 vq->last_avail_idx = 0; 379 vq->avail_idx = 0; 380 vq->last_used_idx = 0; 381 vq->signalled_used = 0; 382 vq->signalled_used_valid = false; 383 vq->used_flags = 0; 384 vq->log_used = false; 385 vq->log_addr = -1ull; 386 vq->private_data = NULL; 387 vq->acked_features = 0; 388 vq->acked_backend_features = 0; 389 vq->log_base = NULL; 390 vq->error_ctx = NULL; 391 vq->kick = NULL; 392 vq->log_ctx = NULL; 393 vhost_disable_cross_endian(vq); 394 vhost_reset_is_le(vq); 395 vq->busyloop_timeout = 0; 396 vq->umem = NULL; 397 vq->iotlb = NULL; 398 rcu_assign_pointer(vq->worker, NULL); 399 vhost_vring_call_reset(&vq->call_ctx); 400 __vhost_vq_meta_reset(vq); 401 } 402 403 static bool vhost_run_work_list(void *data) 404 { 405 struct vhost_worker *worker = data; 406 struct vhost_work *work, *work_next; 407 struct llist_node *node; 408 409 node = llist_del_all(&worker->work_list); 410 if (node) { 411 __set_current_state(TASK_RUNNING); 412 413 node = llist_reverse_order(node); 414 /* make sure flag is seen after deletion */ 415 smp_wmb(); 416 llist_for_each_entry_safe(work, work_next, node, node) { 417 clear_bit(VHOST_WORK_QUEUED, &work->flags); 418 kcov_remote_start_common(worker->kcov_handle); 419 work->fn(work); 420 kcov_remote_stop(); 421 cond_resched(); 422 } 423 } 424 425 return !!node; 426 } 427 428 static void vhost_worker_killed(void *data) 429 { 430 struct vhost_worker *worker = data; 431 struct vhost_dev *dev = worker->dev; 432 struct vhost_virtqueue *vq; 433 int i, attach_cnt = 0; 434 435 mutex_lock(&worker->mutex); 436 
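	/*
	 * Called by the vhost_task infrastructure when the worker task is
	 * being killed: mark the worker as dead, detach it from every
	 * virtqueue under vq->mutex so no new work can be queued to it,
	 * then drain whatever is already on its work list below.
	 */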
worker->killed = true; 437 438 for (i = 0; i < dev->nvqs; i++) { 439 vq = dev->vqs[i]; 440 441 mutex_lock(&vq->mutex); 442 if (worker == 443 rcu_dereference_check(vq->worker, 444 lockdep_is_held(&vq->mutex))) { 445 rcu_assign_pointer(vq->worker, NULL); 446 attach_cnt++; 447 } 448 mutex_unlock(&vq->mutex); 449 } 450 451 worker->attachment_cnt -= attach_cnt; 452 if (attach_cnt) 453 synchronize_rcu(); 454 /* 455 * Finish vhost_worker_flush calls and any other works that snuck in 456 * before the synchronize_rcu. 457 */ 458 vhost_run_work_list(worker); 459 mutex_unlock(&worker->mutex); 460 } 461 462 static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq) 463 { 464 kfree(vq->indirect); 465 vq->indirect = NULL; 466 kfree(vq->log); 467 vq->log = NULL; 468 kfree(vq->heads); 469 vq->heads = NULL; 470 } 471 472 /* Helper to allocate iovec buffers for all vqs. */ 473 static long vhost_dev_alloc_iovecs(struct vhost_dev *dev) 474 { 475 struct vhost_virtqueue *vq; 476 int i; 477 478 for (i = 0; i < dev->nvqs; ++i) { 479 vq = dev->vqs[i]; 480 vq->indirect = kmalloc_array(UIO_MAXIOV, 481 sizeof(*vq->indirect), 482 GFP_KERNEL); 483 vq->log = kmalloc_array(dev->iov_limit, sizeof(*vq->log), 484 GFP_KERNEL); 485 vq->heads = kmalloc_array(dev->iov_limit, sizeof(*vq->heads), 486 GFP_KERNEL); 487 if (!vq->indirect || !vq->log || !vq->heads) 488 goto err_nomem; 489 } 490 return 0; 491 492 err_nomem: 493 for (; i >= 0; --i) 494 vhost_vq_free_iovecs(dev->vqs[i]); 495 return -ENOMEM; 496 } 497 498 static void vhost_dev_free_iovecs(struct vhost_dev *dev) 499 { 500 int i; 501 502 for (i = 0; i < dev->nvqs; ++i) 503 vhost_vq_free_iovecs(dev->vqs[i]); 504 } 505 506 bool vhost_exceeds_weight(struct vhost_virtqueue *vq, 507 int pkts, int total_len) 508 { 509 struct vhost_dev *dev = vq->dev; 510 511 if ((dev->byte_weight && total_len >= dev->byte_weight) || 512 pkts >= dev->weight) { 513 vhost_poll_queue(&vq->poll); 514 return true; 515 } 516 517 return false; 518 } 519 EXPORT_SYMBOL_GPL(vhost_exceeds_weight); 520 521 static size_t vhost_get_avail_size(struct vhost_virtqueue *vq, 522 unsigned int num) 523 { 524 size_t event __maybe_unused = 525 vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0; 526 527 return size_add(struct_size(vq->avail, ring, num), event); 528 } 529 530 static size_t vhost_get_used_size(struct vhost_virtqueue *vq, 531 unsigned int num) 532 { 533 size_t event __maybe_unused = 534 vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX) ? 
	       2 : 0;

	return size_add(struct_size(vq->used, ring, num), event);
}

static size_t vhost_get_desc_size(struct vhost_virtqueue *vq,
				  unsigned int num)
{
	return sizeof(*vq->desc) * num;
}

void vhost_dev_init(struct vhost_dev *dev,
		    struct vhost_virtqueue **vqs, int nvqs,
		    int iov_limit, int weight, int byte_weight,
		    bool use_worker,
		    int (*msg_handler)(struct vhost_dev *dev, u32 asid,
				       struct vhost_iotlb_msg *msg))
{
	struct vhost_virtqueue *vq;
	int i;

	dev->vqs = vqs;
	dev->nvqs = nvqs;
	mutex_init(&dev->mutex);
	dev->log_ctx = NULL;
	dev->umem = NULL;
	dev->iotlb = NULL;
	dev->mm = NULL;
	dev->iov_limit = iov_limit;
	dev->weight = weight;
	dev->byte_weight = byte_weight;
	dev->use_worker = use_worker;
	dev->msg_handler = msg_handler;
	init_waitqueue_head(&dev->wait);
	INIT_LIST_HEAD(&dev->read_list);
	INIT_LIST_HEAD(&dev->pending_list);
	spin_lock_init(&dev->iotlb_lock);
	xa_init_flags(&dev->worker_xa, XA_FLAGS_ALLOC);

	for (i = 0; i < dev->nvqs; ++i) {
		vq = dev->vqs[i];
		vq->log = NULL;
		vq->indirect = NULL;
		vq->heads = NULL;
		vq->dev = dev;
		mutex_init(&vq->mutex);
		vhost_vq_reset(dev, vq);
		if (vq->handle_kick)
			vhost_poll_init(&vq->poll, vq->handle_kick,
					EPOLLIN, dev, vq);
	}
}
EXPORT_SYMBOL_GPL(vhost_dev_init);

/* Caller should have device mutex */
long vhost_dev_check_owner(struct vhost_dev *dev)
{
	/* Are you the owner? If not, I don't think you mean to do that */
	return dev->mm == current->mm ? 0 : -EPERM;
}
EXPORT_SYMBOL_GPL(vhost_dev_check_owner);

/* Caller should have device mutex */
bool vhost_dev_has_owner(struct vhost_dev *dev)
{
	return dev->mm;
}
EXPORT_SYMBOL_GPL(vhost_dev_has_owner);

static void vhost_attach_mm(struct vhost_dev *dev)
{
	/* No owner, become one */
	if (dev->use_worker) {
		dev->mm = get_task_mm(current);
	} else {
		/* A vDPA device does not use a worker thread, so there's
		 * no need to hold the address space for mm. This helps
		 * to avoid a deadlock in the case of mmap(), which may
		 * hold the refcnt of the file and depend on the release
		 * method to remove the vma.
		 */
		dev->mm = current->mm;
		mmgrab(dev->mm);
	}
}

static void vhost_detach_mm(struct vhost_dev *dev)
{
	if (!dev->mm)
		return;

	if (dev->use_worker)
		mmput(dev->mm);
	else
		mmdrop(dev->mm);

	dev->mm = NULL;
}

static void vhost_worker_destroy(struct vhost_dev *dev,
				 struct vhost_worker *worker)
{
	if (!worker)
		return;

	WARN_ON(!llist_empty(&worker->work_list));
	xa_erase(&dev->worker_xa, worker->id);
	vhost_task_stop(worker->vtsk);
	kfree(worker);
}

static void vhost_workers_free(struct vhost_dev *dev)
{
	struct vhost_worker *worker;
	unsigned long i;

	if (!dev->use_worker)
		return;

	for (i = 0; i < dev->nvqs; i++)
		rcu_assign_pointer(dev->vqs[i]->worker, NULL);
	/*
	 * Free the default worker we created and clean up workers userspace
	 * created but couldn't clean up (it forgot or crashed).
	 */
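	/*
	 * For context, a rough sketch of how userspace typically drives the
	 * worker ioctls handled further down in this file (vhost_fd and
	 * vq_index are illustrative names, not part of this driver):
	 *
	 *	struct vhost_worker_state state = {};
	 *	struct vhost_vring_worker ring_worker = {};
	 *
	 *	ioctl(vhost_fd, VHOST_NEW_WORKER, &state);
	 *	ring_worker.index = vq_index;
	 *	ring_worker.worker_id = state.worker_id;
	 *	ioctl(vhost_fd, VHOST_ATTACH_VRING_WORKER, &ring_worker);
	 *	...
	 *	(a worker can only be freed once no vring is attached to it)
	 *	ioctl(vhost_fd, VHOST_FREE_WORKER, &state);
	 */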
	xa_for_each(&dev->worker_xa, i, worker)
		vhost_worker_destroy(dev, worker);
	xa_destroy(&dev->worker_xa);
}

static struct vhost_worker *vhost_worker_create(struct vhost_dev *dev)
{
	struct vhost_worker *worker;
	struct vhost_task *vtsk;
	char name[TASK_COMM_LEN];
	int ret;
	u32 id;

	worker = kzalloc(sizeof(*worker), GFP_KERNEL_ACCOUNT);
	if (!worker)
		return NULL;

	worker->dev = dev;
	snprintf(name, sizeof(name), "vhost-%d", current->pid);

	vtsk = vhost_task_create(vhost_run_work_list, vhost_worker_killed,
				 worker, name);
	if (!vtsk)
		goto free_worker;

	mutex_init(&worker->mutex);
	init_llist_head(&worker->work_list);
	worker->kcov_handle = kcov_common_handle();
	worker->vtsk = vtsk;

	vhost_task_start(vtsk);

	ret = xa_alloc(&dev->worker_xa, &id, worker, xa_limit_32b, GFP_KERNEL);
	if (ret < 0)
		goto stop_worker;
	worker->id = id;

	return worker;

stop_worker:
	vhost_task_stop(vtsk);
free_worker:
	kfree(worker);
	return NULL;
}

/* Caller must have device mutex */
static void __vhost_vq_attach_worker(struct vhost_virtqueue *vq,
				     struct vhost_worker *worker)
{
	struct vhost_worker *old_worker;

	mutex_lock(&worker->mutex);
	if (worker->killed) {
		mutex_unlock(&worker->mutex);
		return;
	}

	mutex_lock(&vq->mutex);

	old_worker = rcu_dereference_check(vq->worker,
					   lockdep_is_held(&vq->mutex));
	rcu_assign_pointer(vq->worker, worker);
	worker->attachment_cnt++;

	if (!old_worker) {
		mutex_unlock(&vq->mutex);
		mutex_unlock(&worker->mutex);
		return;
	}
	mutex_unlock(&vq->mutex);
	mutex_unlock(&worker->mutex);

	/*
	 * Take the worker mutex to make sure we see the work queued from
	 * device wide flushes, which don't use RCU for execution.
	 */
	mutex_lock(&old_worker->mutex);
	if (old_worker->killed) {
		mutex_unlock(&old_worker->mutex);
		return;
	}

	/*
	 * We don't want to call synchronize_rcu for every vq during setup
	 * because it will slow down VM startup. If we haven't done
	 * VHOST_SET_VRING_KICK and not done the driver specific
	 * SET_ENDPOINT/RUNNING then we can skip the sync since there will
	 * not be any works queued for scsi and net.
	 */
	mutex_lock(&vq->mutex);
	if (!vhost_vq_get_backend(vq) && !vq->kick) {
		mutex_unlock(&vq->mutex);

		old_worker->attachment_cnt--;
		mutex_unlock(&old_worker->mutex);
		/*
		 * vsock can queue anytime after VHOST_VSOCK_SET_GUEST_CID.
		 * Warn if it adds support for multiple workers but forgets to
		 * handle the early queueing case.
759 */ 760 WARN_ON(!old_worker->attachment_cnt && 761 !llist_empty(&old_worker->work_list)); 762 return; 763 } 764 mutex_unlock(&vq->mutex); 765 766 /* Make sure new vq queue/flush/poll calls see the new worker */ 767 synchronize_rcu(); 768 /* Make sure whatever was queued gets run */ 769 __vhost_worker_flush(old_worker); 770 old_worker->attachment_cnt--; 771 mutex_unlock(&old_worker->mutex); 772 } 773 774 /* Caller must have device mutex */ 775 static int vhost_vq_attach_worker(struct vhost_virtqueue *vq, 776 struct vhost_vring_worker *info) 777 { 778 unsigned long index = info->worker_id; 779 struct vhost_dev *dev = vq->dev; 780 struct vhost_worker *worker; 781 782 if (!dev->use_worker) 783 return -EINVAL; 784 785 worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT); 786 if (!worker || worker->id != info->worker_id) 787 return -ENODEV; 788 789 __vhost_vq_attach_worker(vq, worker); 790 return 0; 791 } 792 793 /* Caller must have device mutex */ 794 static int vhost_new_worker(struct vhost_dev *dev, 795 struct vhost_worker_state *info) 796 { 797 struct vhost_worker *worker; 798 799 worker = vhost_worker_create(dev); 800 if (!worker) 801 return -ENOMEM; 802 803 info->worker_id = worker->id; 804 return 0; 805 } 806 807 /* Caller must have device mutex */ 808 static int vhost_free_worker(struct vhost_dev *dev, 809 struct vhost_worker_state *info) 810 { 811 unsigned long index = info->worker_id; 812 struct vhost_worker *worker; 813 814 worker = xa_find(&dev->worker_xa, &index, UINT_MAX, XA_PRESENT); 815 if (!worker || worker->id != info->worker_id) 816 return -ENODEV; 817 818 mutex_lock(&worker->mutex); 819 if (worker->attachment_cnt || worker->killed) { 820 mutex_unlock(&worker->mutex); 821 return -EBUSY; 822 } 823 /* 824 * A flush might have raced and snuck in before attachment_cnt was set 825 * to zero. Make sure flushes are flushed from the queue before 826 * freeing. 
827 */ 828 __vhost_worker_flush(worker); 829 mutex_unlock(&worker->mutex); 830 831 vhost_worker_destroy(dev, worker); 832 return 0; 833 } 834 835 static int vhost_get_vq_from_user(struct vhost_dev *dev, void __user *argp, 836 struct vhost_virtqueue **vq, u32 *id) 837 { 838 u32 __user *idxp = argp; 839 u32 idx; 840 long r; 841 842 r = get_user(idx, idxp); 843 if (r < 0) 844 return r; 845 846 if (idx >= dev->nvqs) 847 return -ENOBUFS; 848 849 idx = array_index_nospec(idx, dev->nvqs); 850 851 *vq = dev->vqs[idx]; 852 *id = idx; 853 return 0; 854 } 855 856 /* Caller must have device mutex */ 857 long vhost_worker_ioctl(struct vhost_dev *dev, unsigned int ioctl, 858 void __user *argp) 859 { 860 struct vhost_vring_worker ring_worker; 861 struct vhost_worker_state state; 862 struct vhost_worker *worker; 863 struct vhost_virtqueue *vq; 864 long ret; 865 u32 idx; 866 867 if (!dev->use_worker) 868 return -EINVAL; 869 870 if (!vhost_dev_has_owner(dev)) 871 return -EINVAL; 872 873 ret = vhost_dev_check_owner(dev); 874 if (ret) 875 return ret; 876 877 switch (ioctl) { 878 /* dev worker ioctls */ 879 case VHOST_NEW_WORKER: 880 ret = vhost_new_worker(dev, &state); 881 if (!ret && copy_to_user(argp, &state, sizeof(state))) 882 ret = -EFAULT; 883 return ret; 884 case VHOST_FREE_WORKER: 885 if (copy_from_user(&state, argp, sizeof(state))) 886 return -EFAULT; 887 return vhost_free_worker(dev, &state); 888 /* vring worker ioctls */ 889 case VHOST_ATTACH_VRING_WORKER: 890 case VHOST_GET_VRING_WORKER: 891 break; 892 default: 893 return -ENOIOCTLCMD; 894 } 895 896 ret = vhost_get_vq_from_user(dev, argp, &vq, &idx); 897 if (ret) 898 return ret; 899 900 switch (ioctl) { 901 case VHOST_ATTACH_VRING_WORKER: 902 if (copy_from_user(&ring_worker, argp, sizeof(ring_worker))) { 903 ret = -EFAULT; 904 break; 905 } 906 907 ret = vhost_vq_attach_worker(vq, &ring_worker); 908 break; 909 case VHOST_GET_VRING_WORKER: 910 worker = rcu_dereference_check(vq->worker, 911 lockdep_is_held(&dev->mutex)); 912 if (!worker) { 913 ret = -EINVAL; 914 break; 915 } 916 917 ring_worker.index = idx; 918 ring_worker.worker_id = worker->id; 919 920 if (copy_to_user(argp, &ring_worker, sizeof(ring_worker))) 921 ret = -EFAULT; 922 break; 923 default: 924 ret = -ENOIOCTLCMD; 925 break; 926 } 927 928 return ret; 929 } 930 EXPORT_SYMBOL_GPL(vhost_worker_ioctl); 931 932 /* Caller should have device mutex */ 933 long vhost_dev_set_owner(struct vhost_dev *dev) 934 { 935 struct vhost_worker *worker; 936 int err, i; 937 938 /* Is there an owner already? */ 939 if (vhost_dev_has_owner(dev)) { 940 err = -EBUSY; 941 goto err_mm; 942 } 943 944 vhost_attach_mm(dev); 945 946 err = vhost_dev_alloc_iovecs(dev); 947 if (err) 948 goto err_iovecs; 949 950 if (dev->use_worker) { 951 /* 952 * This should be done last, because vsock can queue work 953 * before VHOST_SET_OWNER so it simplifies the failure path 954 * below since we don't have to worry about vsock queueing 955 * while we free the worker. 
956 */ 957 worker = vhost_worker_create(dev); 958 if (!worker) { 959 err = -ENOMEM; 960 goto err_worker; 961 } 962 963 for (i = 0; i < dev->nvqs; i++) 964 __vhost_vq_attach_worker(dev->vqs[i], worker); 965 } 966 967 return 0; 968 969 err_worker: 970 vhost_dev_free_iovecs(dev); 971 err_iovecs: 972 vhost_detach_mm(dev); 973 err_mm: 974 return err; 975 } 976 EXPORT_SYMBOL_GPL(vhost_dev_set_owner); 977 978 static struct vhost_iotlb *iotlb_alloc(void) 979 { 980 return vhost_iotlb_alloc(max_iotlb_entries, 981 VHOST_IOTLB_FLAG_RETIRE); 982 } 983 984 struct vhost_iotlb *vhost_dev_reset_owner_prepare(void) 985 { 986 return iotlb_alloc(); 987 } 988 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner_prepare); 989 990 /* Caller should have device mutex */ 991 void vhost_dev_reset_owner(struct vhost_dev *dev, struct vhost_iotlb *umem) 992 { 993 int i; 994 995 vhost_dev_cleanup(dev); 996 997 dev->umem = umem; 998 /* We don't need VQ locks below since vhost_dev_cleanup makes sure 999 * VQs aren't running. 1000 */ 1001 for (i = 0; i < dev->nvqs; ++i) 1002 dev->vqs[i]->umem = umem; 1003 } 1004 EXPORT_SYMBOL_GPL(vhost_dev_reset_owner); 1005 1006 void vhost_dev_stop(struct vhost_dev *dev) 1007 { 1008 int i; 1009 1010 for (i = 0; i < dev->nvqs; ++i) { 1011 if (dev->vqs[i]->kick && dev->vqs[i]->handle_kick) 1012 vhost_poll_stop(&dev->vqs[i]->poll); 1013 } 1014 1015 vhost_dev_flush(dev); 1016 } 1017 EXPORT_SYMBOL_GPL(vhost_dev_stop); 1018 1019 void vhost_clear_msg(struct vhost_dev *dev) 1020 { 1021 struct vhost_msg_node *node, *n; 1022 1023 spin_lock(&dev->iotlb_lock); 1024 1025 list_for_each_entry_safe(node, n, &dev->read_list, node) { 1026 list_del(&node->node); 1027 kfree(node); 1028 } 1029 1030 list_for_each_entry_safe(node, n, &dev->pending_list, node) { 1031 list_del(&node->node); 1032 kfree(node); 1033 } 1034 1035 spin_unlock(&dev->iotlb_lock); 1036 } 1037 EXPORT_SYMBOL_GPL(vhost_clear_msg); 1038 1039 void vhost_dev_cleanup(struct vhost_dev *dev) 1040 { 1041 int i; 1042 1043 for (i = 0; i < dev->nvqs; ++i) { 1044 if (dev->vqs[i]->error_ctx) 1045 eventfd_ctx_put(dev->vqs[i]->error_ctx); 1046 if (dev->vqs[i]->kick) 1047 fput(dev->vqs[i]->kick); 1048 if (dev->vqs[i]->call_ctx.ctx) 1049 eventfd_ctx_put(dev->vqs[i]->call_ctx.ctx); 1050 vhost_vq_reset(dev, dev->vqs[i]); 1051 } 1052 vhost_dev_free_iovecs(dev); 1053 if (dev->log_ctx) 1054 eventfd_ctx_put(dev->log_ctx); 1055 dev->log_ctx = NULL; 1056 /* No one will access memory at this point */ 1057 vhost_iotlb_free(dev->umem); 1058 dev->umem = NULL; 1059 vhost_iotlb_free(dev->iotlb); 1060 dev->iotlb = NULL; 1061 vhost_clear_msg(dev); 1062 wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM); 1063 vhost_workers_free(dev); 1064 vhost_detach_mm(dev); 1065 } 1066 EXPORT_SYMBOL_GPL(vhost_dev_cleanup); 1067 1068 static bool log_access_ok(void __user *log_base, u64 addr, unsigned long sz) 1069 { 1070 u64 a = addr / VHOST_PAGE_SIZE / 8; 1071 1072 /* Make sure 64 bit math will not overflow. */ 1073 if (a > ULONG_MAX - (unsigned long)log_base || 1074 a + (unsigned long)log_base > ULONG_MAX) 1075 return false; 1076 1077 return access_ok(log_base + a, 1078 (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8); 1079 } 1080 1081 /* Make sure 64 bit math will not overflow. */ 1082 static bool vhost_overflow(u64 uaddr, u64 size) 1083 { 1084 if (uaddr > ULONG_MAX || size > ULONG_MAX) 1085 return true; 1086 1087 if (!size) 1088 return false; 1089 1090 return uaddr > ULONG_MAX - size + 1; 1091 } 1092 1093 /* Caller should have vq mutex and device mutex. 
 */
static bool vq_memory_access_ok(void __user *log_base, struct vhost_iotlb *umem,
				int log_all)
{
	struct vhost_iotlb_map *map;

	if (!umem)
		return false;

	list_for_each_entry(map, &umem->list, link) {
		unsigned long a = map->addr;

		if (vhost_overflow(map->addr, map->size))
			return false;

		if (!access_ok((void __user *)a, map->size))
			return false;
		else if (log_all && !log_access_ok(log_base,
						   map->start,
						   map->size))
			return false;
	}
	return true;
}

static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq,
					       u64 addr, unsigned int size,
					       int type)
{
	const struct vhost_iotlb_map *map = vq->meta_iotlb[type];

	if (!map)
		return NULL;

	return (void __user *)(uintptr_t)(map->addr + addr - map->start);
}

/* Can we switch to this memory table? */
/* Caller should have device mutex but not vq mutex */
static bool memory_access_ok(struct vhost_dev *d, struct vhost_iotlb *umem,
			     int log_all)
{
	int i;

	for (i = 0; i < d->nvqs; ++i) {
		bool ok;
		bool log;

		mutex_lock(&d->vqs[i]->mutex);
		log = log_all || vhost_has_feature(d->vqs[i], VHOST_F_LOG_ALL);
		/* If ring is inactive, will check when it's enabled. */
		if (d->vqs[i]->private_data)
			ok = vq_memory_access_ok(d->vqs[i]->log_base,
						 umem, log);
		else
			ok = true;
		mutex_unlock(&d->vqs[i]->mutex);
		if (!ok)
			return false;
	}
	return true;
}

static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len,
			  struct iovec iov[], int iov_size, int access);

static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to,
			      const void *from, unsigned size)
{
	int ret;

	if (!vq->iotlb)
		return __copy_to_user(to, from, size);
	else {
		/* This function should be called after iotlb
		 * prefetch, which means we're sure that all vq
		 * memory can be accessed through the iotlb. So
		 * -EAGAIN should not happen in this case.
		 */
		struct iov_iter t;
		void __user *uaddr = vhost_vq_meta_fetch(vq,
				     (u64)(uintptr_t)to, size,
				     VHOST_ADDR_USED);

		if (uaddr)
			return __copy_to_user(uaddr, from, size);

		ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
				     ARRAY_SIZE(vq->iotlb_iov),
				     VHOST_ACCESS_WO);
		if (ret < 0)
			goto out;
		iov_iter_init(&t, ITER_DEST, vq->iotlb_iov, ret, size);
		ret = copy_to_iter(from, size, &t);
		if (ret == size)
			ret = 0;
	}
out:
	return ret;
}

static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
				void __user *from, unsigned size)
{
	int ret;

	if (!vq->iotlb)
		return __copy_from_user(to, from, size);
	else {
		/* This function should be called after iotlb
		 * prefetch, which means we're sure that the vq
		 * can be accessed through the iotlb. So -EAGAIN
		 * should not happen in this case.
1207 */ 1208 void __user *uaddr = vhost_vq_meta_fetch(vq, 1209 (u64)(uintptr_t)from, size, 1210 VHOST_ADDR_DESC); 1211 struct iov_iter f; 1212 1213 if (uaddr) 1214 return __copy_from_user(to, uaddr, size); 1215 1216 ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov, 1217 ARRAY_SIZE(vq->iotlb_iov), 1218 VHOST_ACCESS_RO); 1219 if (ret < 0) { 1220 vq_err(vq, "IOTLB translation failure: uaddr " 1221 "%p size 0x%llx\n", from, 1222 (unsigned long long) size); 1223 goto out; 1224 } 1225 iov_iter_init(&f, ITER_SOURCE, vq->iotlb_iov, ret, size); 1226 ret = copy_from_iter(to, size, &f); 1227 if (ret == size) 1228 ret = 0; 1229 } 1230 1231 out: 1232 return ret; 1233 } 1234 1235 static void __user *__vhost_get_user_slow(struct vhost_virtqueue *vq, 1236 void __user *addr, unsigned int size, 1237 int type) 1238 { 1239 int ret; 1240 1241 ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov, 1242 ARRAY_SIZE(vq->iotlb_iov), 1243 VHOST_ACCESS_RO); 1244 if (ret < 0) { 1245 vq_err(vq, "IOTLB translation failure: uaddr " 1246 "%p size 0x%llx\n", addr, 1247 (unsigned long long) size); 1248 return NULL; 1249 } 1250 1251 if (ret != 1 || vq->iotlb_iov[0].iov_len != size) { 1252 vq_err(vq, "Non atomic userspace memory access: uaddr " 1253 "%p size 0x%llx\n", addr, 1254 (unsigned long long) size); 1255 return NULL; 1256 } 1257 1258 return vq->iotlb_iov[0].iov_base; 1259 } 1260 1261 /* This function should be called after iotlb 1262 * prefetch, which means we're sure that vq 1263 * could be access through iotlb. So -EAGAIN should 1264 * not happen in this case. 1265 */ 1266 static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq, 1267 void __user *addr, unsigned int size, 1268 int type) 1269 { 1270 void __user *uaddr = vhost_vq_meta_fetch(vq, 1271 (u64)(uintptr_t)addr, size, type); 1272 if (uaddr) 1273 return uaddr; 1274 1275 return __vhost_get_user_slow(vq, addr, size, type); 1276 } 1277 1278 #define vhost_put_user(vq, x, ptr) \ 1279 ({ \ 1280 int ret; \ 1281 if (!vq->iotlb) { \ 1282 ret = __put_user(x, ptr); \ 1283 } else { \ 1284 __typeof__(ptr) to = \ 1285 (__typeof__(ptr)) __vhost_get_user(vq, ptr, \ 1286 sizeof(*ptr), VHOST_ADDR_USED); \ 1287 if (to != NULL) \ 1288 ret = __put_user(x, to); \ 1289 else \ 1290 ret = -EFAULT; \ 1291 } \ 1292 ret; \ 1293 }) 1294 1295 static inline int vhost_put_avail_event(struct vhost_virtqueue *vq) 1296 { 1297 return vhost_put_user(vq, cpu_to_vhost16(vq, vq->avail_idx), 1298 vhost_avail_event(vq)); 1299 } 1300 1301 static inline int vhost_put_used(struct vhost_virtqueue *vq, 1302 struct vring_used_elem *head, int idx, 1303 int count) 1304 { 1305 return vhost_copy_to_user(vq, vq->used->ring + idx, head, 1306 count * sizeof(*head)); 1307 } 1308 1309 static inline int vhost_put_used_flags(struct vhost_virtqueue *vq) 1310 1311 { 1312 return vhost_put_user(vq, cpu_to_vhost16(vq, vq->used_flags), 1313 &vq->used->flags); 1314 } 1315 1316 static inline int vhost_put_used_idx(struct vhost_virtqueue *vq) 1317 1318 { 1319 return vhost_put_user(vq, cpu_to_vhost16(vq, vq->last_used_idx), 1320 &vq->used->idx); 1321 } 1322 1323 #define vhost_get_user(vq, x, ptr, type) \ 1324 ({ \ 1325 int ret; \ 1326 if (!vq->iotlb) { \ 1327 ret = __get_user(x, ptr); \ 1328 } else { \ 1329 __typeof__(ptr) from = \ 1330 (__typeof__(ptr)) __vhost_get_user(vq, ptr, \ 1331 sizeof(*ptr), \ 1332 type); \ 1333 if (from != NULL) \ 1334 ret = __get_user(x, from); \ 1335 else \ 1336 ret = -EFAULT; \ 1337 } \ 1338 ret; \ 1339 }) 1340 1341 #define vhost_get_avail(vq, x, ptr) \ 
1342 vhost_get_user(vq, x, ptr, VHOST_ADDR_AVAIL) 1343 1344 #define vhost_get_used(vq, x, ptr) \ 1345 vhost_get_user(vq, x, ptr, VHOST_ADDR_USED) 1346 1347 static void vhost_dev_lock_vqs(struct vhost_dev *d) 1348 { 1349 int i = 0; 1350 for (i = 0; i < d->nvqs; ++i) 1351 mutex_lock_nested(&d->vqs[i]->mutex, i); 1352 } 1353 1354 static void vhost_dev_unlock_vqs(struct vhost_dev *d) 1355 { 1356 int i = 0; 1357 for (i = 0; i < d->nvqs; ++i) 1358 mutex_unlock(&d->vqs[i]->mutex); 1359 } 1360 1361 static inline int vhost_get_avail_idx(struct vhost_virtqueue *vq, 1362 __virtio16 *idx) 1363 { 1364 return vhost_get_avail(vq, *idx, &vq->avail->idx); 1365 } 1366 1367 static inline int vhost_get_avail_head(struct vhost_virtqueue *vq, 1368 __virtio16 *head, int idx) 1369 { 1370 return vhost_get_avail(vq, *head, 1371 &vq->avail->ring[idx & (vq->num - 1)]); 1372 } 1373 1374 static inline int vhost_get_avail_flags(struct vhost_virtqueue *vq, 1375 __virtio16 *flags) 1376 { 1377 return vhost_get_avail(vq, *flags, &vq->avail->flags); 1378 } 1379 1380 static inline int vhost_get_used_event(struct vhost_virtqueue *vq, 1381 __virtio16 *event) 1382 { 1383 return vhost_get_avail(vq, *event, vhost_used_event(vq)); 1384 } 1385 1386 static inline int vhost_get_used_idx(struct vhost_virtqueue *vq, 1387 __virtio16 *idx) 1388 { 1389 return vhost_get_used(vq, *idx, &vq->used->idx); 1390 } 1391 1392 static inline int vhost_get_desc(struct vhost_virtqueue *vq, 1393 struct vring_desc *desc, int idx) 1394 { 1395 return vhost_copy_from_user(vq, desc, vq->desc + idx, sizeof(*desc)); 1396 } 1397 1398 static void vhost_iotlb_notify_vq(struct vhost_dev *d, 1399 struct vhost_iotlb_msg *msg) 1400 { 1401 struct vhost_msg_node *node, *n; 1402 1403 spin_lock(&d->iotlb_lock); 1404 1405 list_for_each_entry_safe(node, n, &d->pending_list, node) { 1406 struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb; 1407 if (msg->iova <= vq_msg->iova && 1408 msg->iova + msg->size - 1 >= vq_msg->iova && 1409 vq_msg->type == VHOST_IOTLB_MISS) { 1410 vhost_poll_queue(&node->vq->poll); 1411 list_del(&node->node); 1412 kfree(node); 1413 } 1414 } 1415 1416 spin_unlock(&d->iotlb_lock); 1417 } 1418 1419 static bool umem_access_ok(u64 uaddr, u64 size, int access) 1420 { 1421 unsigned long a = uaddr; 1422 1423 /* Make sure 64 bit math will not overflow. 
 */
	if (vhost_overflow(uaddr, size))
		return false;

	if ((access & VHOST_ACCESS_RO) &&
	    !access_ok((void __user *)a, size))
		return false;
	if ((access & VHOST_ACCESS_WO) &&
	    !access_ok((void __user *)a, size))
		return false;
	return true;
}

static int vhost_process_iotlb_msg(struct vhost_dev *dev, u32 asid,
				   struct vhost_iotlb_msg *msg)
{
	int ret = 0;

	if (asid != 0)
		return -EINVAL;

	mutex_lock(&dev->mutex);
	vhost_dev_lock_vqs(dev);
	switch (msg->type) {
	case VHOST_IOTLB_UPDATE:
		if (!dev->iotlb) {
			ret = -EFAULT;
			break;
		}
		if (!umem_access_ok(msg->uaddr, msg->size, msg->perm)) {
			ret = -EFAULT;
			break;
		}
		vhost_vq_meta_reset(dev);
		if (vhost_iotlb_add_range(dev->iotlb, msg->iova,
					  msg->iova + msg->size - 1,
					  msg->uaddr, msg->perm)) {
			ret = -ENOMEM;
			break;
		}
		vhost_iotlb_notify_vq(dev, msg);
		break;
	case VHOST_IOTLB_INVALIDATE:
		if (!dev->iotlb) {
			ret = -EFAULT;
			break;
		}
		vhost_vq_meta_reset(dev);
		vhost_iotlb_del_range(dev->iotlb, msg->iova,
				      msg->iova + msg->size - 1);
		break;
	default:
		ret = -EINVAL;
		break;
	}

	vhost_dev_unlock_vqs(dev);
	mutex_unlock(&dev->mutex);

	return ret;
}

ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
			     struct iov_iter *from)
{
	struct vhost_iotlb_msg msg;
	size_t offset;
	int type, ret;
	u32 asid = 0;

	ret = copy_from_iter(&type, sizeof(type), from);
	if (ret != sizeof(type)) {
		ret = -EINVAL;
		goto done;
	}

	switch (type) {
	case VHOST_IOTLB_MSG:
		/* There may be a hole after type for the V1 message type,
		 * so skip it here.
		 */
		offset = offsetof(struct vhost_msg, iotlb) - sizeof(int);
		break;
	case VHOST_IOTLB_MSG_V2:
		if (vhost_backend_has_feature(dev->vqs[0],
					      VHOST_BACKEND_F_IOTLB_ASID)) {
			ret = copy_from_iter(&asid, sizeof(asid), from);
			if (ret != sizeof(asid)) {
				ret = -EINVAL;
				goto done;
			}
			offset = 0;
		} else
			offset = sizeof(__u32);
		break;
	default:
		ret = -EINVAL;
		goto done;
	}

	iov_iter_advance(from, offset);
	ret = copy_from_iter(&msg, sizeof(msg), from);
	if (ret != sizeof(msg)) {
		ret = -EINVAL;
		goto done;
	}

	if (msg.type == VHOST_IOTLB_UPDATE && msg.size == 0) {
		ret = -EINVAL;
		goto done;
	}

	if (dev->msg_handler)
		ret = dev->msg_handler(dev, asid, &msg);
	else
		ret = vhost_process_iotlb_msg(dev, asid, &msg);
	if (ret) {
		ret = -EFAULT;
		goto done;
	}

	ret = (type == VHOST_IOTLB_MSG) ?
sizeof(struct vhost_msg) : 1544 sizeof(struct vhost_msg_v2); 1545 done: 1546 return ret; 1547 } 1548 EXPORT_SYMBOL(vhost_chr_write_iter); 1549 1550 __poll_t vhost_chr_poll(struct file *file, struct vhost_dev *dev, 1551 poll_table *wait) 1552 { 1553 __poll_t mask = 0; 1554 1555 poll_wait(file, &dev->wait, wait); 1556 1557 if (!list_empty(&dev->read_list)) 1558 mask |= EPOLLIN | EPOLLRDNORM; 1559 1560 return mask; 1561 } 1562 EXPORT_SYMBOL(vhost_chr_poll); 1563 1564 ssize_t vhost_chr_read_iter(struct vhost_dev *dev, struct iov_iter *to, 1565 int noblock) 1566 { 1567 DEFINE_WAIT(wait); 1568 struct vhost_msg_node *node; 1569 ssize_t ret = 0; 1570 unsigned size = sizeof(struct vhost_msg); 1571 1572 if (iov_iter_count(to) < size) 1573 return 0; 1574 1575 while (1) { 1576 if (!noblock) 1577 prepare_to_wait(&dev->wait, &wait, 1578 TASK_INTERRUPTIBLE); 1579 1580 node = vhost_dequeue_msg(dev, &dev->read_list); 1581 if (node) 1582 break; 1583 if (noblock) { 1584 ret = -EAGAIN; 1585 break; 1586 } 1587 if (signal_pending(current)) { 1588 ret = -ERESTARTSYS; 1589 break; 1590 } 1591 if (!dev->iotlb) { 1592 ret = -EBADFD; 1593 break; 1594 } 1595 1596 schedule(); 1597 } 1598 1599 if (!noblock) 1600 finish_wait(&dev->wait, &wait); 1601 1602 if (node) { 1603 struct vhost_iotlb_msg *msg; 1604 void *start = &node->msg; 1605 1606 switch (node->msg.type) { 1607 case VHOST_IOTLB_MSG: 1608 size = sizeof(node->msg); 1609 msg = &node->msg.iotlb; 1610 break; 1611 case VHOST_IOTLB_MSG_V2: 1612 size = sizeof(node->msg_v2); 1613 msg = &node->msg_v2.iotlb; 1614 break; 1615 default: 1616 BUG(); 1617 break; 1618 } 1619 1620 ret = copy_to_iter(start, size, to); 1621 if (ret != size || msg->type != VHOST_IOTLB_MISS) { 1622 kfree(node); 1623 return ret; 1624 } 1625 vhost_enqueue_msg(dev, &dev->pending_list, node); 1626 } 1627 1628 return ret; 1629 } 1630 EXPORT_SYMBOL_GPL(vhost_chr_read_iter); 1631 1632 static int vhost_iotlb_miss(struct vhost_virtqueue *vq, u64 iova, int access) 1633 { 1634 struct vhost_dev *dev = vq->dev; 1635 struct vhost_msg_node *node; 1636 struct vhost_iotlb_msg *msg; 1637 bool v2 = vhost_backend_has_feature(vq, VHOST_BACKEND_F_IOTLB_MSG_V2); 1638 1639 node = vhost_new_msg(vq, v2 ? VHOST_IOTLB_MSG_V2 : VHOST_IOTLB_MSG); 1640 if (!node) 1641 return -ENOMEM; 1642 1643 if (v2) { 1644 node->msg_v2.type = VHOST_IOTLB_MSG_V2; 1645 msg = &node->msg_v2.iotlb; 1646 } else { 1647 msg = &node->msg.iotlb; 1648 } 1649 1650 msg->type = VHOST_IOTLB_MISS; 1651 msg->iova = iova; 1652 msg->perm = access; 1653 1654 vhost_enqueue_msg(dev, &dev->read_list, node); 1655 1656 return 0; 1657 } 1658 1659 static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, 1660 vring_desc_t __user *desc, 1661 vring_avail_t __user *avail, 1662 vring_used_t __user *used) 1663 1664 { 1665 /* If an IOTLB device is present, the vring addresses are 1666 * GIOVAs. Access validation occurs at prefetch time. */ 1667 if (vq->iotlb) 1668 return true; 1669 1670 return access_ok(desc, vhost_get_desc_size(vq, num)) && 1671 access_ok(avail, vhost_get_avail_size(vq, num)) && 1672 access_ok(used, vhost_get_used_size(vq, num)); 1673 } 1674 1675 static void vhost_vq_meta_update(struct vhost_virtqueue *vq, 1676 const struct vhost_iotlb_map *map, 1677 int type) 1678 { 1679 int access = (type == VHOST_ADDR_USED) ? 
		     VHOST_ACCESS_WO : VHOST_ACCESS_RO;

	if (likely(map->perm & access))
		vq->meta_iotlb[type] = map;
}

static bool iotlb_access_ok(struct vhost_virtqueue *vq,
			    int access, u64 addr, u64 len, int type)
{
	const struct vhost_iotlb_map *map;
	struct vhost_iotlb *umem = vq->iotlb;
	u64 s = 0, size, orig_addr = addr, last = addr + len - 1;

	if (vhost_vq_meta_fetch(vq, addr, len, type))
		return true;

	while (len > s) {
		map = vhost_iotlb_itree_first(umem, addr, last);
		if (map == NULL || map->start > addr) {
			vhost_iotlb_miss(vq, addr, access);
			return false;
		} else if (!(map->perm & access)) {
			/* Report the possible access violation by
			 * requesting another translation from userspace.
			 */
			return false;
		}

		size = map->size - addr + map->start;

		if (orig_addr == addr && size >= len)
			vhost_vq_meta_update(vq, map, type);

		s += size;
		addr += size;
	}

	return true;
}

int vq_meta_prefetch(struct vhost_virtqueue *vq)
{
	unsigned int num = vq->num;

	if (!vq->iotlb)
		return 1;

	return iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->desc,
			       vhost_get_desc_size(vq, num), VHOST_ADDR_DESC) &&
	       iotlb_access_ok(vq, VHOST_MAP_RO, (u64)(uintptr_t)vq->avail,
			       vhost_get_avail_size(vq, num),
			       VHOST_ADDR_AVAIL) &&
	       iotlb_access_ok(vq, VHOST_MAP_WO, (u64)(uintptr_t)vq->used,
			       vhost_get_used_size(vq, num), VHOST_ADDR_USED);
}
EXPORT_SYMBOL_GPL(vq_meta_prefetch);

/* Can we log writes? */
/* Caller should have device mutex but not vq mutex */
bool vhost_log_access_ok(struct vhost_dev *dev)
{
	return memory_access_ok(dev, dev->umem, 1);
}
EXPORT_SYMBOL_GPL(vhost_log_access_ok);

static bool vq_log_used_access_ok(struct vhost_virtqueue *vq,
				  void __user *log_base,
				  bool log_used,
				  u64 log_addr)
{
	/* If an IOTLB device is present, log_addr is a GIOVA that
	 * will never be logged by log_used(). */
	if (vq->iotlb)
		return true;

	return !log_used || log_access_ok(log_base, log_addr,
					  vhost_get_used_size(vq, vq->num));
}

/* Verify access for write logging. */
/* Caller should have vq mutex and device mutex */
static bool vq_log_access_ok(struct vhost_virtqueue *vq,
			     void __user *log_base)
{
	return vq_memory_access_ok(log_base, vq->umem,
				   vhost_has_feature(vq, VHOST_F_LOG_ALL)) &&
		vq_log_used_access_ok(vq, log_base, vq->log_used, vq->log_addr);
}

/* Can we start vq?
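 * (i.e. is the ring memory accessible and, when logging is enabled, is the
 * log base usable for the used ring)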
*/ 1770 /* Caller should have vq mutex and device mutex */ 1771 bool vhost_vq_access_ok(struct vhost_virtqueue *vq) 1772 { 1773 if (!vq_log_access_ok(vq, vq->log_base)) 1774 return false; 1775 1776 return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used); 1777 } 1778 EXPORT_SYMBOL_GPL(vhost_vq_access_ok); 1779 1780 static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m) 1781 { 1782 struct vhost_memory mem, *newmem; 1783 struct vhost_memory_region *region; 1784 struct vhost_iotlb *newumem, *oldumem; 1785 unsigned long size = offsetof(struct vhost_memory, regions); 1786 int i; 1787 1788 if (copy_from_user(&mem, m, size)) 1789 return -EFAULT; 1790 if (mem.padding) 1791 return -EOPNOTSUPP; 1792 if (mem.nregions > max_mem_regions) 1793 return -E2BIG; 1794 newmem = kvzalloc(struct_size(newmem, regions, mem.nregions), 1795 GFP_KERNEL); 1796 if (!newmem) 1797 return -ENOMEM; 1798 1799 memcpy(newmem, &mem, size); 1800 if (copy_from_user(newmem->regions, m->regions, 1801 flex_array_size(newmem, regions, mem.nregions))) { 1802 kvfree(newmem); 1803 return -EFAULT; 1804 } 1805 1806 newumem = iotlb_alloc(); 1807 if (!newumem) { 1808 kvfree(newmem); 1809 return -ENOMEM; 1810 } 1811 1812 for (region = newmem->regions; 1813 region < newmem->regions + mem.nregions; 1814 region++) { 1815 if (vhost_iotlb_add_range(newumem, 1816 region->guest_phys_addr, 1817 region->guest_phys_addr + 1818 region->memory_size - 1, 1819 region->userspace_addr, 1820 VHOST_MAP_RW)) 1821 goto err; 1822 } 1823 1824 if (!memory_access_ok(d, newumem, 0)) 1825 goto err; 1826 1827 oldumem = d->umem; 1828 d->umem = newumem; 1829 1830 /* All memory accesses are done under some VQ mutex. */ 1831 for (i = 0; i < d->nvqs; ++i) { 1832 mutex_lock(&d->vqs[i]->mutex); 1833 d->vqs[i]->umem = newumem; 1834 mutex_unlock(&d->vqs[i]->mutex); 1835 } 1836 1837 kvfree(newmem); 1838 vhost_iotlb_free(oldumem); 1839 return 0; 1840 1841 err: 1842 vhost_iotlb_free(newumem); 1843 kvfree(newmem); 1844 return -EFAULT; 1845 } 1846 1847 static long vhost_vring_set_num(struct vhost_dev *d, 1848 struct vhost_virtqueue *vq, 1849 void __user *argp) 1850 { 1851 struct vhost_vring_state s; 1852 1853 /* Resizing ring with an active backend? 1854 * You don't want to do that. */ 1855 if (vq->private_data) 1856 return -EBUSY; 1857 1858 if (copy_from_user(&s, argp, sizeof s)) 1859 return -EFAULT; 1860 1861 if (!s.num || s.num > 0xffff || (s.num & (s.num - 1))) 1862 return -EINVAL; 1863 vq->num = s.num; 1864 1865 return 0; 1866 } 1867 1868 static long vhost_vring_set_addr(struct vhost_dev *d, 1869 struct vhost_virtqueue *vq, 1870 void __user *argp) 1871 { 1872 struct vhost_vring_addr a; 1873 1874 if (copy_from_user(&a, argp, sizeof a)) 1875 return -EFAULT; 1876 if (a.flags & ~(0x1 << VHOST_VRING_F_LOG)) 1877 return -EOPNOTSUPP; 1878 1879 /* For 32bit, verify that the top 32bits of the user 1880 data are set to zero. */ 1881 if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr || 1882 (u64)(unsigned long)a.used_user_addr != a.used_user_addr || 1883 (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr) 1884 return -EFAULT; 1885 1886 /* Make sure it's safe to cast pointers to vring types. 
*/ 1887 BUILD_BUG_ON(__alignof__ *vq->avail > VRING_AVAIL_ALIGN_SIZE); 1888 BUILD_BUG_ON(__alignof__ *vq->used > VRING_USED_ALIGN_SIZE); 1889 if ((a.avail_user_addr & (VRING_AVAIL_ALIGN_SIZE - 1)) || 1890 (a.used_user_addr & (VRING_USED_ALIGN_SIZE - 1)) || 1891 (a.log_guest_addr & (VRING_USED_ALIGN_SIZE - 1))) 1892 return -EINVAL; 1893 1894 /* We only verify access here if backend is configured. 1895 * If it is not, we don't as size might not have been setup. 1896 * We will verify when backend is configured. */ 1897 if (vq->private_data) { 1898 if (!vq_access_ok(vq, vq->num, 1899 (void __user *)(unsigned long)a.desc_user_addr, 1900 (void __user *)(unsigned long)a.avail_user_addr, 1901 (void __user *)(unsigned long)a.used_user_addr)) 1902 return -EINVAL; 1903 1904 /* Also validate log access for used ring if enabled. */ 1905 if (!vq_log_used_access_ok(vq, vq->log_base, 1906 a.flags & (0x1 << VHOST_VRING_F_LOG), 1907 a.log_guest_addr)) 1908 return -EINVAL; 1909 } 1910 1911 vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG)); 1912 vq->desc = (void __user *)(unsigned long)a.desc_user_addr; 1913 vq->avail = (void __user *)(unsigned long)a.avail_user_addr; 1914 vq->log_addr = a.log_guest_addr; 1915 vq->used = (void __user *)(unsigned long)a.used_user_addr; 1916 1917 return 0; 1918 } 1919 1920 static long vhost_vring_set_num_addr(struct vhost_dev *d, 1921 struct vhost_virtqueue *vq, 1922 unsigned int ioctl, 1923 void __user *argp) 1924 { 1925 long r; 1926 1927 mutex_lock(&vq->mutex); 1928 1929 switch (ioctl) { 1930 case VHOST_SET_VRING_NUM: 1931 r = vhost_vring_set_num(d, vq, argp); 1932 break; 1933 case VHOST_SET_VRING_ADDR: 1934 r = vhost_vring_set_addr(d, vq, argp); 1935 break; 1936 default: 1937 BUG(); 1938 } 1939 1940 mutex_unlock(&vq->mutex); 1941 1942 return r; 1943 } 1944 long vhost_vring_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp) 1945 { 1946 struct file *eventfp, *filep = NULL; 1947 bool pollstart = false, pollstop = false; 1948 struct eventfd_ctx *ctx = NULL; 1949 struct vhost_virtqueue *vq; 1950 struct vhost_vring_state s; 1951 struct vhost_vring_file f; 1952 u32 idx; 1953 long r; 1954 1955 r = vhost_get_vq_from_user(d, argp, &vq, &idx); 1956 if (r < 0) 1957 return r; 1958 1959 if (ioctl == VHOST_SET_VRING_NUM || 1960 ioctl == VHOST_SET_VRING_ADDR) { 1961 return vhost_vring_set_num_addr(d, vq, ioctl, argp); 1962 } 1963 1964 mutex_lock(&vq->mutex); 1965 1966 switch (ioctl) { 1967 case VHOST_SET_VRING_BASE: 1968 /* Moving base with an active backend? 1969 * You don't want to do that. */ 1970 if (vq->private_data) { 1971 r = -EBUSY; 1972 break; 1973 } 1974 if (copy_from_user(&s, argp, sizeof s)) { 1975 r = -EFAULT; 1976 break; 1977 } 1978 if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) { 1979 vq->last_avail_idx = s.num & 0xffff; 1980 vq->last_used_idx = (s.num >> 16) & 0xffff; 1981 } else { 1982 if (s.num > 0xffff) { 1983 r = -EINVAL; 1984 break; 1985 } 1986 vq->last_avail_idx = s.num; 1987 } 1988 /* Forget the cached index value. */ 1989 vq->avail_idx = vq->last_avail_idx; 1990 break; 1991 case VHOST_GET_VRING_BASE: 1992 s.index = idx; 1993 if (vhost_has_feature(vq, VIRTIO_F_RING_PACKED)) 1994 s.num = (u32)vq->last_avail_idx | ((u32)vq->last_used_idx << 16); 1995 else 1996 s.num = vq->last_avail_idx; 1997 if (copy_to_user(argp, &s, sizeof s)) 1998 r = -EFAULT; 1999 break; 2000 case VHOST_SET_VRING_KICK: 2001 if (copy_from_user(&f, argp, sizeof f)) { 2002 r = -EFAULT; 2003 break; 2004 } 2005 eventfp = f.fd == VHOST_FILE_UNBIND ? 
NULL : eventfd_fget(f.fd); 2006 if (IS_ERR(eventfp)) { 2007 r = PTR_ERR(eventfp); 2008 break; 2009 } 2010 if (eventfp != vq->kick) { 2011 pollstop = (filep = vq->kick) != NULL; 2012 pollstart = (vq->kick = eventfp) != NULL; 2013 } else 2014 filep = eventfp; 2015 break; 2016 case VHOST_SET_VRING_CALL: 2017 if (copy_from_user(&f, argp, sizeof f)) { 2018 r = -EFAULT; 2019 break; 2020 } 2021 ctx = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(f.fd); 2022 if (IS_ERR(ctx)) { 2023 r = PTR_ERR(ctx); 2024 break; 2025 } 2026 2027 swap(ctx, vq->call_ctx.ctx); 2028 break; 2029 case VHOST_SET_VRING_ERR: 2030 if (copy_from_user(&f, argp, sizeof f)) { 2031 r = -EFAULT; 2032 break; 2033 } 2034 ctx = f.fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(f.fd); 2035 if (IS_ERR(ctx)) { 2036 r = PTR_ERR(ctx); 2037 break; 2038 } 2039 swap(ctx, vq->error_ctx); 2040 break; 2041 case VHOST_SET_VRING_ENDIAN: 2042 r = vhost_set_vring_endian(vq, argp); 2043 break; 2044 case VHOST_GET_VRING_ENDIAN: 2045 r = vhost_get_vring_endian(vq, idx, argp); 2046 break; 2047 case VHOST_SET_VRING_BUSYLOOP_TIMEOUT: 2048 if (copy_from_user(&s, argp, sizeof(s))) { 2049 r = -EFAULT; 2050 break; 2051 } 2052 vq->busyloop_timeout = s.num; 2053 break; 2054 case VHOST_GET_VRING_BUSYLOOP_TIMEOUT: 2055 s.index = idx; 2056 s.num = vq->busyloop_timeout; 2057 if (copy_to_user(argp, &s, sizeof(s))) 2058 r = -EFAULT; 2059 break; 2060 default: 2061 r = -ENOIOCTLCMD; 2062 } 2063 2064 if (pollstop && vq->handle_kick) 2065 vhost_poll_stop(&vq->poll); 2066 2067 if (!IS_ERR_OR_NULL(ctx)) 2068 eventfd_ctx_put(ctx); 2069 if (filep) 2070 fput(filep); 2071 2072 if (pollstart && vq->handle_kick) 2073 r = vhost_poll_start(&vq->poll, vq->kick); 2074 2075 mutex_unlock(&vq->mutex); 2076 2077 if (pollstop && vq->handle_kick) 2078 vhost_dev_flush(vq->poll.dev); 2079 return r; 2080 } 2081 EXPORT_SYMBOL_GPL(vhost_vring_ioctl); 2082 2083 int vhost_init_device_iotlb(struct vhost_dev *d) 2084 { 2085 struct vhost_iotlb *niotlb, *oiotlb; 2086 int i; 2087 2088 niotlb = iotlb_alloc(); 2089 if (!niotlb) 2090 return -ENOMEM; 2091 2092 oiotlb = d->iotlb; 2093 d->iotlb = niotlb; 2094 2095 for (i = 0; i < d->nvqs; ++i) { 2096 struct vhost_virtqueue *vq = d->vqs[i]; 2097 2098 mutex_lock(&vq->mutex); 2099 vq->iotlb = niotlb; 2100 __vhost_vq_meta_reset(vq); 2101 mutex_unlock(&vq->mutex); 2102 } 2103 2104 vhost_iotlb_free(oiotlb); 2105 2106 return 0; 2107 } 2108 EXPORT_SYMBOL_GPL(vhost_init_device_iotlb); 2109 2110 /* Caller must have device mutex */ 2111 long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, void __user *argp) 2112 { 2113 struct eventfd_ctx *ctx; 2114 u64 p; 2115 long r; 2116 int i, fd; 2117 2118 /* If you are not the owner, you can become one */ 2119 if (ioctl == VHOST_SET_OWNER) { 2120 r = vhost_dev_set_owner(d); 2121 goto done; 2122 } 2123 2124 /* You must be the owner to do anything else */ 2125 r = vhost_dev_check_owner(d); 2126 if (r) 2127 goto done; 2128 2129 switch (ioctl) { 2130 case VHOST_SET_MEM_TABLE: 2131 r = vhost_set_memory(d, argp); 2132 break; 2133 case VHOST_SET_LOG_BASE: 2134 if (copy_from_user(&p, argp, sizeof p)) { 2135 r = -EFAULT; 2136 break; 2137 } 2138 if ((u64)(unsigned long)p != p) { 2139 r = -EFAULT; 2140 break; 2141 } 2142 for (i = 0; i < d->nvqs; ++i) { 2143 struct vhost_virtqueue *vq; 2144 void __user *base = (void __user *)(unsigned long)p; 2145 vq = d->vqs[i]; 2146 mutex_lock(&vq->mutex); 2147 /* If ring is inactive, will check when it's enabled. 
			 */
			if (vq->private_data && !vq_log_access_ok(vq, base))
				r = -EFAULT;
			else
				vq->log_base = base;
			mutex_unlock(&vq->mutex);
		}
		break;
	case VHOST_SET_LOG_FD:
		r = get_user(fd, (int __user *)argp);
		if (r < 0)
			break;
		ctx = fd == VHOST_FILE_UNBIND ? NULL : eventfd_ctx_fdget(fd);
		if (IS_ERR(ctx)) {
			r = PTR_ERR(ctx);
			break;
		}
		swap(ctx, d->log_ctx);
		for (i = 0; i < d->nvqs; ++i) {
			mutex_lock(&d->vqs[i]->mutex);
			d->vqs[i]->log_ctx = d->log_ctx;
			mutex_unlock(&d->vqs[i]->mutex);
		}
		if (ctx)
			eventfd_ctx_put(ctx);
		break;
	default:
		r = -ENOIOCTLCMD;
		break;
	}
done:
	return r;
}
EXPORT_SYMBOL_GPL(vhost_dev_ioctl);

/* TODO: This is really inefficient. We need something like get_user()
 * (instruction directly accesses the data, with an exception table entry
 * returning -EFAULT). See Documentation/arch/x86/exception-tables.rst.
 */
static int set_bit_to_user(int nr, void __user *addr)
{
	unsigned long log = (unsigned long)addr;
	struct page *page;
	void *base;
	int bit = nr + (log % PAGE_SIZE) * 8;
	int r;

	r = pin_user_pages_fast(log, 1, FOLL_WRITE, &page);
	if (r < 0)
		return r;
	BUG_ON(r != 1);
	base = kmap_atomic(page);
	set_bit(bit, base);
	kunmap_atomic(base);
	unpin_user_pages_dirty_lock(&page, 1, true);
	return 0;
}

static int log_write(void __user *log_base,
		     u64 write_address, u64 write_length)
{
	u64 write_page = write_address / VHOST_PAGE_SIZE;
	int r;

	if (!write_length)
		return 0;
	write_length += write_address % VHOST_PAGE_SIZE;
	for (;;) {
		u64 base = (u64)(unsigned long)log_base;
		u64 log = base + write_page / 8;
		int bit = write_page % 8;
		if ((u64)(unsigned long)log != log)
			return -EFAULT;
		r = set_bit_to_user(bit, (void __user *)(unsigned long)log);
		if (r < 0)
			return r;
		if (write_length <= VHOST_PAGE_SIZE)
			break;
		write_length -= VHOST_PAGE_SIZE;
		write_page += 1;
	}
	return r;
}

static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
{
	struct vhost_iotlb *umem = vq->umem;
	struct vhost_iotlb_map *u;
	u64 start, end, l, min;
	int r;
	bool hit = false;

	while (len) {
		min = len;
		/* More than one GPA can be mapped into a single HVA. So
		 * iterate all possible umems here to be safe.
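		 * For example, if userspace registers two guest memory
		 * regions backed by the same host mapping, a single dirty
		 * HVA range may have to be logged under more than one GPA.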
2243 */ 2244 list_for_each_entry(u, &umem->list, link) { 2245 if (u->addr > hva - 1 + len || 2246 u->addr - 1 + u->size < hva) 2247 continue; 2248 start = max(u->addr, hva); 2249 end = min(u->addr - 1 + u->size, hva - 1 + len); 2250 l = end - start + 1; 2251 r = log_write(vq->log_base, 2252 u->start + start - u->addr, 2253 l); 2254 if (r < 0) 2255 return r; 2256 hit = true; 2257 min = min(l, min); 2258 } 2259 2260 if (!hit) 2261 return -EFAULT; 2262 2263 len -= min; 2264 hva += min; 2265 } 2266 2267 return 0; 2268 } 2269 2270 static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len) 2271 { 2272 struct iovec *iov = vq->log_iov; 2273 int i, ret; 2274 2275 if (!vq->iotlb) 2276 return log_write(vq->log_base, vq->log_addr + used_offset, len); 2277 2278 ret = translate_desc(vq, (uintptr_t)vq->used + used_offset, 2279 len, iov, 64, VHOST_ACCESS_WO); 2280 if (ret < 0) 2281 return ret; 2282 2283 for (i = 0; i < ret; i++) { 2284 ret = log_write_hva(vq, (uintptr_t)iov[i].iov_base, 2285 iov[i].iov_len); 2286 if (ret) 2287 return ret; 2288 } 2289 2290 return 0; 2291 } 2292 2293 int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log, 2294 unsigned int log_num, u64 len, struct iovec *iov, int count) 2295 { 2296 int i, r; 2297 2298 /* Make sure data written is seen before log. */ 2299 smp_wmb(); 2300 2301 if (vq->iotlb) { 2302 for (i = 0; i < count; i++) { 2303 r = log_write_hva(vq, (uintptr_t)iov[i].iov_base, 2304 iov[i].iov_len); 2305 if (r < 0) 2306 return r; 2307 } 2308 return 0; 2309 } 2310 2311 for (i = 0; i < log_num; ++i) { 2312 u64 l = min(log[i].len, len); 2313 r = log_write(vq->log_base, log[i].addr, l); 2314 if (r < 0) 2315 return r; 2316 len -= l; 2317 if (!len) { 2318 if (vq->log_ctx) 2319 eventfd_signal(vq->log_ctx, 1); 2320 return 0; 2321 } 2322 } 2323 /* Length written exceeds what we have stored. This is a bug. */ 2324 BUG(); 2325 return 0; 2326 } 2327 EXPORT_SYMBOL_GPL(vhost_log_write); 2328 2329 static int vhost_update_used_flags(struct vhost_virtqueue *vq) 2330 { 2331 void __user *used; 2332 if (vhost_put_used_flags(vq)) 2333 return -EFAULT; 2334 if (unlikely(vq->log_used)) { 2335 /* Make sure the flag is seen before log. */ 2336 smp_wmb(); 2337 /* Log used flag write. */ 2338 used = &vq->used->flags; 2339 log_used(vq, (used - (void __user *)vq->used), 2340 sizeof vq->used->flags); 2341 if (vq->log_ctx) 2342 eventfd_signal(vq->log_ctx, 1); 2343 } 2344 return 0; 2345 } 2346 2347 static int vhost_update_avail_event(struct vhost_virtqueue *vq) 2348 { 2349 if (vhost_put_avail_event(vq)) 2350 return -EFAULT; 2351 if (unlikely(vq->log_used)) { 2352 void __user *used; 2353 /* Make sure the event is seen before log. 
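 * (The smp_wmb() below orders the avail event store against the dirty-log
 * update that follows it.)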
*/ 2354 smp_wmb(); 2355 /* Log avail event write */ 2356 used = vhost_avail_event(vq); 2357 log_used(vq, (used - (void __user *)vq->used), 2358 sizeof *vhost_avail_event(vq)); 2359 if (vq->log_ctx) 2360 eventfd_signal(vq->log_ctx, 1); 2361 } 2362 return 0; 2363 } 2364 2365 int vhost_vq_init_access(struct vhost_virtqueue *vq) 2366 { 2367 __virtio16 last_used_idx; 2368 int r; 2369 bool is_le = vq->is_le; 2370 2371 if (!vq->private_data) 2372 return 0; 2373 2374 vhost_init_is_le(vq); 2375 2376 r = vhost_update_used_flags(vq); 2377 if (r) 2378 goto err; 2379 vq->signalled_used_valid = false; 2380 if (!vq->iotlb && 2381 !access_ok(&vq->used->idx, sizeof vq->used->idx)) { 2382 r = -EFAULT; 2383 goto err; 2384 } 2385 r = vhost_get_used_idx(vq, &last_used_idx); 2386 if (r) { 2387 vq_err(vq, "Can't access used idx at %p\n", 2388 &vq->used->idx); 2389 goto err; 2390 } 2391 vq->last_used_idx = vhost16_to_cpu(vq, last_used_idx); 2392 return 0; 2393 2394 err: 2395 vq->is_le = is_le; 2396 return r; 2397 } 2398 EXPORT_SYMBOL_GPL(vhost_vq_init_access); 2399 2400 static int translate_desc(struct vhost_virtqueue *vq, u64 addr, u32 len, 2401 struct iovec iov[], int iov_size, int access) 2402 { 2403 const struct vhost_iotlb_map *map; 2404 struct vhost_dev *dev = vq->dev; 2405 struct vhost_iotlb *umem = dev->iotlb ? dev->iotlb : dev->umem; 2406 struct iovec *_iov; 2407 u64 s = 0, last = addr + len - 1; 2408 int ret = 0; 2409 2410 while ((u64)len > s) { 2411 u64 size; 2412 if (unlikely(ret >= iov_size)) { 2413 ret = -ENOBUFS; 2414 break; 2415 } 2416 2417 map = vhost_iotlb_itree_first(umem, addr, last); 2418 if (map == NULL || map->start > addr) { 2419 if (umem != dev->iotlb) { 2420 ret = -EFAULT; 2421 break; 2422 } 2423 ret = -EAGAIN; 2424 break; 2425 } else if (!(map->perm & access)) { 2426 ret = -EPERM; 2427 break; 2428 } 2429 2430 _iov = iov + ret; 2431 size = map->size - addr + map->start; 2432 _iov->iov_len = min((u64)len - s, size); 2433 _iov->iov_base = (void __user *)(unsigned long) 2434 (map->addr + addr - map->start); 2435 s += size; 2436 addr += size; 2437 ++ret; 2438 } 2439 2440 if (ret == -EAGAIN) 2441 vhost_iotlb_miss(vq, addr, access); 2442 return ret; 2443 } 2444 2445 /* Each buffer in the virtqueues is actually a chain of descriptors. This 2446 * function returns the next descriptor in the chain, 2447 * or -1U if we're at the end. */ 2448 static unsigned next_desc(struct vhost_virtqueue *vq, struct vring_desc *desc) 2449 { 2450 unsigned int next; 2451 2452 /* If this descriptor says it doesn't chain, we're done. */ 2453 if (!(desc->flags & cpu_to_vhost16(vq, VRING_DESC_F_NEXT))) 2454 return -1U; 2455 2456 /* Check they're not leading us off end of descriptors. 
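 * (READ_ONCE() guarantees desc->next is fetched exactly once; the bounds
 * checking itself is done by the calling loops.)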
*/ 2457 next = vhost16_to_cpu(vq, READ_ONCE(desc->next)); 2458 return next; 2459 } 2460 2461 static int get_indirect(struct vhost_virtqueue *vq, 2462 struct iovec iov[], unsigned int iov_size, 2463 unsigned int *out_num, unsigned int *in_num, 2464 struct vhost_log *log, unsigned int *log_num, 2465 struct vring_desc *indirect) 2466 { 2467 struct vring_desc desc; 2468 unsigned int i = 0, count, found = 0; 2469 u32 len = vhost32_to_cpu(vq, indirect->len); 2470 struct iov_iter from; 2471 int ret, access; 2472 2473 /* Sanity check */ 2474 if (unlikely(len % sizeof desc)) { 2475 vq_err(vq, "Invalid length in indirect descriptor: " 2476 "len 0x%llx not multiple of 0x%zx\n", 2477 (unsigned long long)len, 2478 sizeof desc); 2479 return -EINVAL; 2480 } 2481 2482 ret = translate_desc(vq, vhost64_to_cpu(vq, indirect->addr), len, vq->indirect, 2483 UIO_MAXIOV, VHOST_ACCESS_RO); 2484 if (unlikely(ret < 0)) { 2485 if (ret != -EAGAIN) 2486 vq_err(vq, "Translation failure %d in indirect.\n", ret); 2487 return ret; 2488 } 2489 iov_iter_init(&from, ITER_SOURCE, vq->indirect, ret, len); 2490 count = len / sizeof desc; 2491 /* Buffers are chained via a 16 bit next field, so 2492 * we can have at most 2^16 of these. */ 2493 if (unlikely(count > USHRT_MAX + 1)) { 2494 vq_err(vq, "Indirect buffer length too big: %d\n", 2495 indirect->len); 2496 return -E2BIG; 2497 } 2498 2499 do { 2500 unsigned iov_count = *in_num + *out_num; 2501 if (unlikely(++found > count)) { 2502 vq_err(vq, "Loop detected: last one at %u " 2503 "indirect size %u\n", 2504 i, count); 2505 return -EINVAL; 2506 } 2507 if (unlikely(!copy_from_iter_full(&desc, sizeof(desc), &from))) { 2508 vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n", 2509 i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc); 2510 return -EINVAL; 2511 } 2512 if (unlikely(desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT))) { 2513 vq_err(vq, "Nested indirect descriptor: idx %d, %zx\n", 2514 i, (size_t)vhost64_to_cpu(vq, indirect->addr) + i * sizeof desc); 2515 return -EINVAL; 2516 } 2517 2518 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) 2519 access = VHOST_ACCESS_WO; 2520 else 2521 access = VHOST_ACCESS_RO; 2522 2523 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr), 2524 vhost32_to_cpu(vq, desc.len), iov + iov_count, 2525 iov_size - iov_count, access); 2526 if (unlikely(ret < 0)) { 2527 if (ret != -EAGAIN) 2528 vq_err(vq, "Translation failure %d indirect idx %d\n", 2529 ret, i); 2530 return ret; 2531 } 2532 /* If this is an input descriptor, increment that count. */ 2533 if (access == VHOST_ACCESS_WO) { 2534 *in_num += ret; 2535 if (unlikely(log && ret)) { 2536 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr); 2537 log[*log_num].len = vhost32_to_cpu(vq, desc.len); 2538 ++*log_num; 2539 } 2540 } else { 2541 /* If it's an output descriptor, they're all supposed 2542 * to come before any input descriptors. */ 2543 if (unlikely(*in_num)) { 2544 vq_err(vq, "Indirect descriptor " 2545 "has out after in: idx %d\n", i); 2546 return -EINVAL; 2547 } 2548 *out_num += ret; 2549 } 2550 } while ((i = next_desc(vq, &desc)) != -1); 2551 return 0; 2552 } 2553 2554 /* This looks in the virtqueue and for the first available buffer, and converts 2555 * it to an iovec for convenient access. Since descriptors consist of some 2556 * number of output then some number of input descriptors, it's actually two 2557 * iovecs, but we pack them into one and note how many of each there were. 
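 * A backend normally calls this with vq->mutex held and guest notifications
 * disabled; an illustrative usage sketch is appended at the end of this file.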
 *
 * This function returns the descriptor number found, or vq->num (which is
 * never a valid descriptor number) if none was found. A negative code is
 * returned on error. */
int vhost_get_vq_desc(struct vhost_virtqueue *vq,
                      struct iovec iov[], unsigned int iov_size,
                      unsigned int *out_num, unsigned int *in_num,
                      struct vhost_log *log, unsigned int *log_num)
{
        struct vring_desc desc;
        unsigned int i, head, found = 0;
        u16 last_avail_idx;
        __virtio16 avail_idx;
        __virtio16 ring_head;
        int ret, access;

        /* Check it isn't doing very strange things with descriptor numbers. */
        last_avail_idx = vq->last_avail_idx;

        if (vq->avail_idx == vq->last_avail_idx) {
                if (unlikely(vhost_get_avail_idx(vq, &avail_idx))) {
                        vq_err(vq, "Failed to access avail idx at %p\n",
                               &vq->avail->idx);
                        return -EFAULT;
                }
                vq->avail_idx = vhost16_to_cpu(vq, avail_idx);

                if (unlikely((u16)(vq->avail_idx - last_avail_idx) > vq->num)) {
                        vq_err(vq, "Guest moved used index from %u to %u",
                               last_avail_idx, vq->avail_idx);
                        return -EFAULT;
                }

                /* If there's nothing new since we last looked, return
                 * invalid.
                 */
                if (vq->avail_idx == last_avail_idx)
                        return vq->num;

                /* Only get avail ring entries after they have been
                 * exposed by the guest.
                 */
                smp_rmb();
        }

        /* Grab the next descriptor number they're advertising, and increment
         * the index we've seen. */
        if (unlikely(vhost_get_avail_head(vq, &ring_head, last_avail_idx))) {
                vq_err(vq, "Failed to read head: idx %d address %p\n",
                       last_avail_idx,
                       &vq->avail->ring[last_avail_idx % vq->num]);
                return -EFAULT;
        }

        head = vhost16_to_cpu(vq, ring_head);

        /* If their number is silly, that's an error. */
        if (unlikely(head >= vq->num)) {
                vq_err(vq, "Guest says index %u > %u is available",
                       head, vq->num);
                return -EINVAL;
        }

        /* When we start there are no input or output descriptors yet.
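         * (*log_num is likewise reset, but only when a log array was supplied.)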
*/ 2622 *out_num = *in_num = 0; 2623 if (unlikely(log)) 2624 *log_num = 0; 2625 2626 i = head; 2627 do { 2628 unsigned iov_count = *in_num + *out_num; 2629 if (unlikely(i >= vq->num)) { 2630 vq_err(vq, "Desc index is %u > %u, head = %u", 2631 i, vq->num, head); 2632 return -EINVAL; 2633 } 2634 if (unlikely(++found > vq->num)) { 2635 vq_err(vq, "Loop detected: last one at %u " 2636 "vq size %u head %u\n", 2637 i, vq->num, head); 2638 return -EINVAL; 2639 } 2640 ret = vhost_get_desc(vq, &desc, i); 2641 if (unlikely(ret)) { 2642 vq_err(vq, "Failed to get descriptor: idx %d addr %p\n", 2643 i, vq->desc + i); 2644 return -EFAULT; 2645 } 2646 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_INDIRECT)) { 2647 ret = get_indirect(vq, iov, iov_size, 2648 out_num, in_num, 2649 log, log_num, &desc); 2650 if (unlikely(ret < 0)) { 2651 if (ret != -EAGAIN) 2652 vq_err(vq, "Failure detected " 2653 "in indirect descriptor at idx %d\n", i); 2654 return ret; 2655 } 2656 continue; 2657 } 2658 2659 if (desc.flags & cpu_to_vhost16(vq, VRING_DESC_F_WRITE)) 2660 access = VHOST_ACCESS_WO; 2661 else 2662 access = VHOST_ACCESS_RO; 2663 ret = translate_desc(vq, vhost64_to_cpu(vq, desc.addr), 2664 vhost32_to_cpu(vq, desc.len), iov + iov_count, 2665 iov_size - iov_count, access); 2666 if (unlikely(ret < 0)) { 2667 if (ret != -EAGAIN) 2668 vq_err(vq, "Translation failure %d descriptor idx %d\n", 2669 ret, i); 2670 return ret; 2671 } 2672 if (access == VHOST_ACCESS_WO) { 2673 /* If this is an input descriptor, 2674 * increment that count. */ 2675 *in_num += ret; 2676 if (unlikely(log && ret)) { 2677 log[*log_num].addr = vhost64_to_cpu(vq, desc.addr); 2678 log[*log_num].len = vhost32_to_cpu(vq, desc.len); 2679 ++*log_num; 2680 } 2681 } else { 2682 /* If it's an output descriptor, they're all supposed 2683 * to come before any input descriptors. */ 2684 if (unlikely(*in_num)) { 2685 vq_err(vq, "Descriptor has out after in: " 2686 "idx %d\n", i); 2687 return -EINVAL; 2688 } 2689 *out_num += ret; 2690 } 2691 } while ((i = next_desc(vq, &desc)) != -1); 2692 2693 /* On success, increment avail index. */ 2694 vq->last_avail_idx++; 2695 2696 /* Assume notifications from guest are disabled at this point, 2697 * if they aren't we would need to update avail_event index. */ 2698 BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY)); 2699 return head; 2700 } 2701 EXPORT_SYMBOL_GPL(vhost_get_vq_desc); 2702 2703 /* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */ 2704 void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n) 2705 { 2706 vq->last_avail_idx -= n; 2707 } 2708 EXPORT_SYMBOL_GPL(vhost_discard_vq_desc); 2709 2710 /* After we've used one of their buffers, we tell them about it. We'll then 2711 * want to notify the guest, using eventfd. 
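 * (vhost_add_used() is simply the single-element wrapper around
 * vhost_add_used_n() below.)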
*/ 2712 int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len) 2713 { 2714 struct vring_used_elem heads = { 2715 cpu_to_vhost32(vq, head), 2716 cpu_to_vhost32(vq, len) 2717 }; 2718 2719 return vhost_add_used_n(vq, &heads, 1); 2720 } 2721 EXPORT_SYMBOL_GPL(vhost_add_used); 2722 2723 static int __vhost_add_used_n(struct vhost_virtqueue *vq, 2724 struct vring_used_elem *heads, 2725 unsigned count) 2726 { 2727 vring_used_elem_t __user *used; 2728 u16 old, new; 2729 int start; 2730 2731 start = vq->last_used_idx & (vq->num - 1); 2732 used = vq->used->ring + start; 2733 if (vhost_put_used(vq, heads, start, count)) { 2734 vq_err(vq, "Failed to write used"); 2735 return -EFAULT; 2736 } 2737 if (unlikely(vq->log_used)) { 2738 /* Make sure data is seen before log. */ 2739 smp_wmb(); 2740 /* Log used ring entry write. */ 2741 log_used(vq, ((void __user *)used - (void __user *)vq->used), 2742 count * sizeof *used); 2743 } 2744 old = vq->last_used_idx; 2745 new = (vq->last_used_idx += count); 2746 /* If the driver never bothers to signal in a very long while, 2747 * used index might wrap around. If that happens, invalidate 2748 * signalled_used index we stored. TODO: make sure driver 2749 * signals at least once in 2^16 and remove this. */ 2750 if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old))) 2751 vq->signalled_used_valid = false; 2752 return 0; 2753 } 2754 2755 /* After we've used one of their buffers, we tell them about it. We'll then 2756 * want to notify the guest, using eventfd. */ 2757 int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads, 2758 unsigned count) 2759 { 2760 int start, n, r; 2761 2762 start = vq->last_used_idx & (vq->num - 1); 2763 n = vq->num - start; 2764 if (n < count) { 2765 r = __vhost_add_used_n(vq, heads, n); 2766 if (r < 0) 2767 return r; 2768 heads += n; 2769 count -= n; 2770 } 2771 r = __vhost_add_used_n(vq, heads, count); 2772 2773 /* Make sure buffer is written before we update index. */ 2774 smp_wmb(); 2775 if (vhost_put_used_idx(vq)) { 2776 vq_err(vq, "Failed to increment used idx"); 2777 return -EFAULT; 2778 } 2779 if (unlikely(vq->log_used)) { 2780 /* Make sure used idx is seen before log. */ 2781 smp_wmb(); 2782 /* Log used index update. */ 2783 log_used(vq, offsetof(struct vring_used, idx), 2784 sizeof vq->used->idx); 2785 if (vq->log_ctx) 2786 eventfd_signal(vq->log_ctx, 1); 2787 } 2788 return r; 2789 } 2790 EXPORT_SYMBOL_GPL(vhost_add_used_n); 2791 2792 static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) 2793 { 2794 __u16 old, new; 2795 __virtio16 event; 2796 bool v; 2797 /* Flush out used index updates. This is paired 2798 * with the barrier that the Guest executes when enabling 2799 * interrupts. 
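 * (The guest driver writes its avail flags or used event index, executes
 * that barrier, and only then re-checks the used ring; the smp_mb() below
 * is the host-side counterpart.)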
         */
        smp_mb();

        if (vhost_has_feature(vq, VIRTIO_F_NOTIFY_ON_EMPTY) &&
            unlikely(vq->avail_idx == vq->last_avail_idx))
                return true;

        if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
                __virtio16 flags;
                if (vhost_get_avail_flags(vq, &flags)) {
                        vq_err(vq, "Failed to get flags");
                        return true;
                }
                return !(flags & cpu_to_vhost16(vq, VRING_AVAIL_F_NO_INTERRUPT));
        }
        old = vq->signalled_used;
        v = vq->signalled_used_valid;
        new = vq->signalled_used = vq->last_used_idx;
        vq->signalled_used_valid = true;

        if (unlikely(!v))
                return true;

        if (vhost_get_used_event(vq, &event)) {
                vq_err(vq, "Failed to get used event idx");
                return true;
        }
        return vring_need_event(vhost16_to_cpu(vq, event), new, old);
}

/* This actually signals the guest, using eventfd. */
void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
        /* Signal the Guest, telling them we used something up. */
        if (vq->call_ctx.ctx && vhost_notify(dev, vq))
                eventfd_signal(vq->call_ctx.ctx, 1);
}
EXPORT_SYMBOL_GPL(vhost_signal);

/* And here's the combo meal deal. Supersize me! */
void vhost_add_used_and_signal(struct vhost_dev *dev,
                               struct vhost_virtqueue *vq,
                               unsigned int head, int len)
{
        vhost_add_used(vq, head, len);
        vhost_signal(dev, vq);
}
EXPORT_SYMBOL_GPL(vhost_add_used_and_signal);

/* multi-buffer version of vhost_add_used_and_signal */
void vhost_add_used_and_signal_n(struct vhost_dev *dev,
                                 struct vhost_virtqueue *vq,
                                 struct vring_used_elem *heads, unsigned count)
{
        vhost_add_used_n(vq, heads, count);
        vhost_signal(dev, vq);
}
EXPORT_SYMBOL_GPL(vhost_add_used_and_signal_n);

/* return true if we're sure that the available ring is empty */
bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
        __virtio16 avail_idx;
        int r;

        if (vq->avail_idx != vq->last_avail_idx)
                return false;

        r = vhost_get_avail_idx(vq, &avail_idx);
        if (unlikely(r))
                return false;

        vq->avail_idx = vhost16_to_cpu(vq, avail_idx);
        if (vq->avail_idx != vq->last_avail_idx) {
                /* Since we have updated avail_idx, the following
                 * call to vhost_get_vq_desc() will read available
                 * ring entries. Make sure that read happens after
                 * the avail_idx read.
                 */
                smp_rmb();
                return false;
        }

        return true;
}
EXPORT_SYMBOL_GPL(vhost_vq_avail_empty);

/* OK, now we need to know about added descriptors. */
bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
        __virtio16 avail_idx;
        int r;

        if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
                return false;
        vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
        if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) {
                r = vhost_update_used_flags(vq);
                if (r) {
                        vq_err(vq, "Failed to enable notification at %p: %d\n",
                               &vq->used->flags, r);
                        return false;
                }
        } else {
                r = vhost_update_avail_event(vq);
                if (r) {
                        vq_err(vq, "Failed to update avail event index at %p: %d\n",
                               vhost_avail_event(vq), r);
                        return false;
                }
        }
        /* They could have slipped one in as we were doing that: make
         * sure it's written, then check again.
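         * (The smp_mb() below orders our used-flags/avail-event update against
         * the re-read of the avail index.)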
*/ 2912 smp_mb(); 2913 r = vhost_get_avail_idx(vq, &avail_idx); 2914 if (r) { 2915 vq_err(vq, "Failed to check avail idx at %p: %d\n", 2916 &vq->avail->idx, r); 2917 return false; 2918 } 2919 2920 vq->avail_idx = vhost16_to_cpu(vq, avail_idx); 2921 if (vq->avail_idx != vq->last_avail_idx) { 2922 /* Since we have updated avail_idx, the following 2923 * call to vhost_get_vq_desc() will read available 2924 * ring entries. Make sure that read happens after 2925 * the avail_idx read. 2926 */ 2927 smp_rmb(); 2928 return true; 2929 } 2930 2931 return false; 2932 } 2933 EXPORT_SYMBOL_GPL(vhost_enable_notify); 2934 2935 /* We don't need to be notified again. */ 2936 void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq) 2937 { 2938 int r; 2939 2940 if (vq->used_flags & VRING_USED_F_NO_NOTIFY) 2941 return; 2942 vq->used_flags |= VRING_USED_F_NO_NOTIFY; 2943 if (!vhost_has_feature(vq, VIRTIO_RING_F_EVENT_IDX)) { 2944 r = vhost_update_used_flags(vq); 2945 if (r) 2946 vq_err(vq, "Failed to disable notification at %p: %d\n", 2947 &vq->used->flags, r); 2948 } 2949 } 2950 EXPORT_SYMBOL_GPL(vhost_disable_notify); 2951 2952 /* Create a new message. */ 2953 struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type) 2954 { 2955 /* Make sure all padding within the structure is initialized. */ 2956 struct vhost_msg_node *node = kzalloc(sizeof(*node), GFP_KERNEL); 2957 if (!node) 2958 return NULL; 2959 2960 node->vq = vq; 2961 node->msg.type = type; 2962 return node; 2963 } 2964 EXPORT_SYMBOL_GPL(vhost_new_msg); 2965 2966 void vhost_enqueue_msg(struct vhost_dev *dev, struct list_head *head, 2967 struct vhost_msg_node *node) 2968 { 2969 spin_lock(&dev->iotlb_lock); 2970 list_add_tail(&node->node, head); 2971 spin_unlock(&dev->iotlb_lock); 2972 2973 wake_up_interruptible_poll(&dev->wait, EPOLLIN | EPOLLRDNORM); 2974 } 2975 EXPORT_SYMBOL_GPL(vhost_enqueue_msg); 2976 2977 struct vhost_msg_node *vhost_dequeue_msg(struct vhost_dev *dev, 2978 struct list_head *head) 2979 { 2980 struct vhost_msg_node *node = NULL; 2981 2982 spin_lock(&dev->iotlb_lock); 2983 if (!list_empty(head)) { 2984 node = list_first_entry(head, struct vhost_msg_node, 2985 node); 2986 list_del(&node->node); 2987 } 2988 spin_unlock(&dev->iotlb_lock); 2989 2990 return node; 2991 } 2992 EXPORT_SYMBOL_GPL(vhost_dequeue_msg); 2993 2994 void vhost_set_backend_features(struct vhost_dev *dev, u64 features) 2995 { 2996 struct vhost_virtqueue *vq; 2997 int i; 2998 2999 mutex_lock(&dev->mutex); 3000 for (i = 0; i < dev->nvqs; ++i) { 3001 vq = dev->vqs[i]; 3002 mutex_lock(&vq->mutex); 3003 vq->acked_backend_features = features; 3004 mutex_unlock(&vq->mutex); 3005 } 3006 mutex_unlock(&dev->mutex); 3007 } 3008 EXPORT_SYMBOL_GPL(vhost_set_backend_features); 3009 3010 static int __init vhost_init(void) 3011 { 3012 return 0; 3013 } 3014 3015 static void __exit vhost_exit(void) 3016 { 3017 } 3018 3019 module_init(vhost_init); 3020 module_exit(vhost_exit); 3021 3022 MODULE_VERSION("0.0.1"); 3023 MODULE_LICENSE("GPL v2"); 3024 MODULE_AUTHOR("Michael S. Tsirkin"); 3025 MODULE_DESCRIPTION("Host kernel accelerator for virtio"); 3026
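
/*
 * Illustrative usage sketch (not part of the upstream file): the loop below
 * shows how a backend typically drives the API exported above, loosely
 * following the pattern used by simple backends such as drivers/vhost/test.c.
 * The function name handle_vq_kick() and the "consume the buffer" step are
 * hypothetical placeholders.
 *
 *	static void handle_vq_kick(struct vhost_work *work)
 *	{
 *		struct vhost_virtqueue *vq = container_of(work,
 *					struct vhost_virtqueue, poll.work);
 *		struct vhost_dev *dev = vq->dev;
 *		unsigned int out, in;
 *		int head;
 *
 *		mutex_lock(&vq->mutex);
 *		vhost_disable_notify(dev, vq);
 *		for (;;) {
 *			head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
 *						 &out, &in, NULL, NULL);
 *			if (head < 0)
 *				break;			// translation/access error
 *			if (head == vq->num) {		// ring is drained
 *				if (unlikely(vhost_enable_notify(dev, vq))) {
 *					// a buffer raced in; keep polling
 *					vhost_disable_notify(dev, vq);
 *					continue;
 *				}
 *				break;
 *			}
 *			// ... consume iov[0..out+in) here (backend specific) ...
 *			vhost_add_used_and_signal(dev, vq, head, 0);
 *		}
 *		mutex_unlock(&vq->mutex);
 *	}
 *
 * Re-enabling notifications before the final exit and then re-checking the
 * ring closes the race where the guest adds a buffer after the last
 * vhost_get_vq_desc() call but before notifications are turned back on.
 */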