/* Copyright (C) 2009 Red Hat, Inc.
 * Copyright (C) 2006 Rusty Russell IBM Corporation
 *
 * Author: Michael S. Tsirkin <mst@redhat.com>
 *
 * Inspiration, some code, and most witty comments come from
 * Documentation/virtual/lguest/lguest.c, by Rusty Russell
 *
 * This work is licensed under the terms of the GNU GPL, version 2.
 *
 * Generic code for virtio server in host kernel.
 */

#include <linux/eventfd.h>
#include <linux/vhost.h>
#include <linux/virtio_net.h>
#include <linux/mm.h>
#include <linux/mmu_context.h>
#include <linux/miscdevice.h>
#include <linux/mutex.h>
#include <linux/rcupdate.h>
#include <linux/poll.h>
#include <linux/file.h>
#include <linux/highmem.h>
#include <linux/slab.h>
#include <linux/kthread.h>
#include <linux/cgroup.h>

#include <linux/net.h>
#include <linux/if_packet.h>
#include <linux/if_arp.h>

#include "vhost.h"

enum {
	VHOST_MEMORY_MAX_NREGIONS = 64,
	VHOST_MEMORY_F_LOG = 0x1,
};

static unsigned vhost_zcopy_mask __read_mostly;

#define vhost_used_event(vq) ((u16 __user *)&vq->avail->ring[vq->num])
#define vhost_avail_event(vq) ((u16 __user *)&vq->used->ring[vq->num])

static void vhost_poll_func(struct file *file, wait_queue_head_t *wqh,
			    poll_table *pt)
{
	struct vhost_poll *poll;

	poll = container_of(pt, struct vhost_poll, table);
	poll->wqh = wqh;
	add_wait_queue(wqh, &poll->wait);
}

static int vhost_poll_wakeup(wait_queue_t *wait, unsigned mode, int sync,
			     void *key)
{
	struct vhost_poll *poll = container_of(wait, struct vhost_poll, wait);

	if (!((unsigned long)key & poll->mask))
		return 0;

	vhost_poll_queue(poll);
	return 0;
}

void vhost_work_init(struct vhost_work *work, vhost_work_fn_t fn)
{
	INIT_LIST_HEAD(&work->node);
	work->fn = fn;
	init_waitqueue_head(&work->done);
	work->flushing = 0;
	work->queue_seq = work->done_seq = 0;
}

/* Init poll structure */
void vhost_poll_init(struct vhost_poll *poll, vhost_work_fn_t fn,
		     unsigned long mask, struct vhost_dev *dev)
{
	init_waitqueue_func_entry(&poll->wait, vhost_poll_wakeup);
	init_poll_funcptr(&poll->table, vhost_poll_func);
	poll->mask = mask;
	poll->dev = dev;

	vhost_work_init(&poll->work, fn);
}

/* Start polling a file. We add ourselves to file's wait queue. The caller must
 * keep a reference to a file until after vhost_poll_stop is called. */
void vhost_poll_start(struct vhost_poll *poll, struct file *file)
{
	unsigned long mask;

	mask = file->f_op->poll(file, &poll->table);
	if (mask)
		vhost_poll_wakeup(&poll->wait, 0, 0, (void *)mask);
}

/* Stop polling a file. After this function returns, it becomes safe to drop the
 * file reference. You must also flush afterwards. */
void vhost_poll_stop(struct vhost_poll *poll)
{
	remove_wait_queue(poll->wqh, &poll->wait);
}

static bool vhost_work_seq_done(struct vhost_dev *dev, struct vhost_work *work,
				unsigned seq)
{
	int left;

	spin_lock_irq(&dev->work_lock);
	left = seq - work->done_seq;
	spin_unlock_irq(&dev->work_lock);
	return left <= 0;
}

static void vhost_work_flush(struct vhost_dev *dev, struct vhost_work *work)
{
	unsigned seq;
	int flushing;

	spin_lock_irq(&dev->work_lock);
	seq = work->queue_seq;
	work->flushing++;
	spin_unlock_irq(&dev->work_lock);
	wait_event(work->done, vhost_work_seq_done(dev, work, seq));
	spin_lock_irq(&dev->work_lock);
	flushing = --work->flushing;
	spin_unlock_irq(&dev->work_lock);
	BUG_ON(flushing < 0);
}

/* Flush any work that has been scheduled. When calling this, don't hold any
 * locks that are also used by the callback. */
void vhost_poll_flush(struct vhost_poll *poll)
{
	vhost_work_flush(poll->dev, &poll->work);
}

void vhost_work_queue(struct vhost_dev *dev, struct vhost_work *work)
{
	unsigned long flags;

	spin_lock_irqsave(&dev->work_lock, flags);
	if (list_empty(&work->node)) {
		list_add_tail(&work->node, &dev->work_list);
		work->queue_seq++;
		wake_up_process(dev->worker);
	}
	spin_unlock_irqrestore(&dev->work_lock, flags);
}

void vhost_poll_queue(struct vhost_poll *poll)
{
	vhost_work_queue(poll->dev, &poll->work);
}

static void vhost_vq_reset(struct vhost_dev *dev,
			   struct vhost_virtqueue *vq)
{
	vq->num = 1;
	vq->desc = NULL;
	vq->avail = NULL;
	vq->used = NULL;
	vq->last_avail_idx = 0;
	vq->avail_idx = 0;
	vq->last_used_idx = 0;
	vq->signalled_used = 0;
	vq->signalled_used_valid = false;
	vq->used_flags = 0;
	vq->log_used = false;
	vq->log_addr = -1ull;
	vq->vhost_hlen = 0;
	vq->sock_hlen = 0;
	vq->private_data = NULL;
	vq->log_base = NULL;
	vq->error_ctx = NULL;
	vq->error = NULL;
	vq->kick = NULL;
	vq->call_ctx = NULL;
	vq->call = NULL;
	vq->log_ctx = NULL;
	vq->upend_idx = 0;
	vq->done_idx = 0;
	vq->ubufs = NULL;
}

static int vhost_worker(void *data)
{
	struct vhost_dev *dev = data;
	struct vhost_work *work = NULL;
	unsigned uninitialized_var(seq);
	mm_segment_t oldfs = get_fs();

	set_fs(USER_DS);
	use_mm(dev->mm);

	for (;;) {
		/* mb paired w/ kthread_stop */
		set_current_state(TASK_INTERRUPTIBLE);

		spin_lock_irq(&dev->work_lock);
		if (work) {
			work->done_seq = seq;
			if (work->flushing)
				wake_up_all(&work->done);
		}

		if (kthread_should_stop()) {
			spin_unlock_irq(&dev->work_lock);
			__set_current_state(TASK_RUNNING);
			break;
		}
		if (!list_empty(&dev->work_list)) {
			work = list_first_entry(&dev->work_list,
						struct vhost_work, node);
			list_del_init(&work->node);
			seq = work->queue_seq;
		} else
			work = NULL;
		spin_unlock_irq(&dev->work_lock);

		if (work) {
			__set_current_state(TASK_RUNNING);
			work->fn(work);
			if (need_resched())
				schedule();
		} else
			schedule();

	}
	unuse_mm(dev->mm);
	set_fs(oldfs);
	return 0;
}

static void vhost_vq_free_iovecs(struct vhost_virtqueue *vq)
{
	kfree(vq->indirect);
	vq->indirect = NULL;
	kfree(vq->log);
	vq->log = NULL;
	kfree(vq->heads);
	vq->heads = NULL;
	kfree(vq->ubuf_info);
	vq->ubuf_info = NULL;
}

void vhost_enable_zcopy(int vq)
{
	vhost_zcopy_mask |= 0x1 << vq;
}

/* Helper to allocate iovec buffers for all vqs. */
static long vhost_dev_alloc_iovecs(struct vhost_dev *dev)
{
	int i;
	bool zcopy;

	for (i = 0; i < dev->nvqs; ++i) {
		dev->vqs[i].indirect = kmalloc(sizeof *dev->vqs[i].indirect *
					       UIO_MAXIOV, GFP_KERNEL);
		dev->vqs[i].log = kmalloc(sizeof *dev->vqs[i].log * UIO_MAXIOV,
					  GFP_KERNEL);
		dev->vqs[i].heads = kmalloc(sizeof *dev->vqs[i].heads *
					    UIO_MAXIOV, GFP_KERNEL);
		zcopy = vhost_zcopy_mask & (0x1 << i);
		if (zcopy)
			dev->vqs[i].ubuf_info =
				kmalloc(sizeof *dev->vqs[i].ubuf_info *
					UIO_MAXIOV, GFP_KERNEL);
		if (!dev->vqs[i].indirect || !dev->vqs[i].log ||
		    !dev->vqs[i].heads ||
		    (zcopy && !dev->vqs[i].ubuf_info))
			goto err_nomem;
	}
	return 0;

err_nomem:
	for (; i >= 0; --i)
		vhost_vq_free_iovecs(&dev->vqs[i]);
	return -ENOMEM;
}

static void vhost_dev_free_iovecs(struct vhost_dev *dev)
{
	int i;

	for (i = 0; i < dev->nvqs; ++i)
		vhost_vq_free_iovecs(&dev->vqs[i]);
}

long vhost_dev_init(struct vhost_dev *dev,
		    struct vhost_virtqueue *vqs, int nvqs)
{
	int i;

	dev->vqs = vqs;
	dev->nvqs = nvqs;
	mutex_init(&dev->mutex);
	dev->log_ctx = NULL;
	dev->log_file = NULL;
	dev->memory = NULL;
	dev->mm = NULL;
	spin_lock_init(&dev->work_lock);
	INIT_LIST_HEAD(&dev->work_list);
	dev->worker = NULL;

	for (i = 0; i < dev->nvqs; ++i) {
		dev->vqs[i].log = NULL;
		dev->vqs[i].indirect = NULL;
		dev->vqs[i].heads = NULL;
		dev->vqs[i].ubuf_info = NULL;
		dev->vqs[i].dev = dev;
		mutex_init(&dev->vqs[i].mutex);
		vhost_vq_reset(dev, dev->vqs + i);
		if (dev->vqs[i].handle_kick)
			vhost_poll_init(&dev->vqs[i].poll,
					dev->vqs[i].handle_kick, POLLIN, dev);
	}

	return 0;
}

/* Caller should have device mutex */
long vhost_dev_check_owner(struct vhost_dev *dev)
{
	/* Are you the owner? If not, I don't think you mean to do that */
	return dev->mm == current->mm ? 0 : -EPERM;
}

struct vhost_attach_cgroups_struct {
	struct vhost_work work;
	struct task_struct *owner;
	int ret;
};

static void vhost_attach_cgroups_work(struct vhost_work *work)
{
	struct vhost_attach_cgroups_struct *s;

	s = container_of(work, struct vhost_attach_cgroups_struct, work);
	s->ret = cgroup_attach_task_all(s->owner, current);
}

static int vhost_attach_cgroups(struct vhost_dev *dev)
{
	struct vhost_attach_cgroups_struct attach;

	attach.owner = current;
	vhost_work_init(&attach.work, vhost_attach_cgroups_work);
	vhost_work_queue(dev, &attach.work);
	vhost_work_flush(dev, &attach.work);
	return attach.ret;
}

/* Caller should have device mutex */
static long vhost_dev_set_owner(struct vhost_dev *dev)
{
	struct task_struct *worker;
	int err;

	/* Is there an owner already? */
	if (dev->mm) {
		err = -EBUSY;
		goto err_mm;
	}

	/* No owner, become one */
	dev->mm = get_task_mm(current);
	worker = kthread_create(vhost_worker, dev, "vhost-%d", current->pid);
	if (IS_ERR(worker)) {
		err = PTR_ERR(worker);
		goto err_worker;
	}

	dev->worker = worker;
	wake_up_process(worker);	/* avoid contributing to loadavg */

	err = vhost_attach_cgroups(dev);
	if (err)
		goto err_cgroup;

	err = vhost_dev_alloc_iovecs(dev);
	if (err)
		goto err_cgroup;

	return 0;
err_cgroup:
	kthread_stop(worker);
	dev->worker = NULL;
err_worker:
	if (dev->mm)
		mmput(dev->mm);
	dev->mm = NULL;
err_mm:
	return err;
}

/* Caller should have device mutex */
long vhost_dev_reset_owner(struct vhost_dev *dev)
{
	struct vhost_memory *memory;

	/* Restore memory to default empty mapping. */
	memory = kmalloc(offsetof(struct vhost_memory, regions), GFP_KERNEL);
	if (!memory)
		return -ENOMEM;

	vhost_dev_cleanup(dev, true);

	memory->nregions = 0;
	RCU_INIT_POINTER(dev->memory, memory);
	return 0;
}

/* In case of DMA done not in order in lower device driver for some reason.
 * upend_idx is used to track end of used idx, done_idx is used to track head
 * of used idx. Once lower device DMA done contiguously, we will signal KVM
 * guest used idx.
 */
int vhost_zerocopy_signal_used(struct vhost_virtqueue *vq)
{
	int i;
	int j = 0;

	for (i = vq->done_idx; i != vq->upend_idx; i = (i + 1) % UIO_MAXIOV) {
		if ((vq->heads[i].len == VHOST_DMA_DONE_LEN)) {
			vq->heads[i].len = VHOST_DMA_CLEAR_LEN;
			vhost_add_used_and_signal(vq->dev, vq,
						  vq->heads[i].id, 0);
			++j;
		} else
			break;
	}
	if (j)
		vq->done_idx = i;
	return j;
}

/* Caller should have device mutex if and only if locked is set */
void vhost_dev_cleanup(struct vhost_dev *dev, bool locked)
{
	int i;

	for (i = 0; i < dev->nvqs; ++i) {
		if (dev->vqs[i].kick && dev->vqs[i].handle_kick) {
			vhost_poll_stop(&dev->vqs[i].poll);
			vhost_poll_flush(&dev->vqs[i].poll);
		}
		/* Wait for all lower device DMAs done. */
		if (dev->vqs[i].ubufs)
			vhost_ubuf_put_and_wait(dev->vqs[i].ubufs);

		/* Signal guest as appropriate. */
		vhost_zerocopy_signal_used(&dev->vqs[i]);

		if (dev->vqs[i].error_ctx)
			eventfd_ctx_put(dev->vqs[i].error_ctx);
		if (dev->vqs[i].error)
			fput(dev->vqs[i].error);
		if (dev->vqs[i].kick)
			fput(dev->vqs[i].kick);
		if (dev->vqs[i].call_ctx)
			eventfd_ctx_put(dev->vqs[i].call_ctx);
		if (dev->vqs[i].call)
			fput(dev->vqs[i].call);
		vhost_vq_reset(dev, dev->vqs + i);
	}
	vhost_dev_free_iovecs(dev);
	if (dev->log_ctx)
		eventfd_ctx_put(dev->log_ctx);
	dev->log_ctx = NULL;
	if (dev->log_file)
		fput(dev->log_file);
	dev->log_file = NULL;
	/* No one will access memory at this point */
	kfree(rcu_dereference_protected(dev->memory,
					locked ==
						lockdep_is_held(&dev->mutex)));
	RCU_INIT_POINTER(dev->memory, NULL);
	WARN_ON(!list_empty(&dev->work_list));
	if (dev->worker) {
		kthread_stop(dev->worker);
		dev->worker = NULL;
	}
	if (dev->mm)
		mmput(dev->mm);
	dev->mm = NULL;
}

static int log_access_ok(void __user *log_base, u64 addr, unsigned long sz)
{
	u64 a = addr / VHOST_PAGE_SIZE / 8;

	/* Make sure 64 bit math will not overflow. */
	if (a > ULONG_MAX - (unsigned long)log_base ||
	    a + (unsigned long)log_base > ULONG_MAX)
		return 0;

	return access_ok(VERIFY_WRITE, log_base + a,
			 (sz + VHOST_PAGE_SIZE * 8 - 1) / VHOST_PAGE_SIZE / 8);
}

/* Caller should have vq mutex and device mutex. */
static int vq_memory_access_ok(void __user *log_base, struct vhost_memory *mem,
			       int log_all)
{
	int i;

	if (!mem)
		return 0;

	for (i = 0; i < mem->nregions; ++i) {
		struct vhost_memory_region *m = mem->regions + i;
		unsigned long a = m->userspace_addr;
		if (m->memory_size > ULONG_MAX)
			return 0;
		else if (!access_ok(VERIFY_WRITE, (void __user *)a,
				    m->memory_size))
			return 0;
		else if (log_all && !log_access_ok(log_base,
						   m->guest_phys_addr,
						   m->memory_size))
			return 0;
	}
	return 1;
}

/* Can we switch to this memory table? */
/* Caller should have device mutex but not vq mutex */
static int memory_access_ok(struct vhost_dev *d, struct vhost_memory *mem,
			    int log_all)
{
	int i;

	for (i = 0; i < d->nvqs; ++i) {
		int ok;
		mutex_lock(&d->vqs[i].mutex);
		/* If ring is inactive, will check when it's enabled. */
		if (d->vqs[i].private_data)
			ok = vq_memory_access_ok(d->vqs[i].log_base, mem,
						 log_all);
		else
			ok = 1;
		mutex_unlock(&d->vqs[i].mutex);
		if (!ok)
			return 0;
	}
	return 1;
}

static int vq_access_ok(struct vhost_dev *d, unsigned int num,
			struct vring_desc __user *desc,
			struct vring_avail __user *avail,
			struct vring_used __user *used)
{
	size_t s = vhost_has_feature(d, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;
	return access_ok(VERIFY_READ, desc, num * sizeof *desc) &&
	       access_ok(VERIFY_READ, avail,
			 sizeof *avail + num * sizeof *avail->ring + s) &&
	       access_ok(VERIFY_WRITE, used,
			 sizeof *used + num * sizeof *used->ring + s);
}

/* Can we log writes? */
/* Caller should have device mutex but not vq mutex */
int vhost_log_access_ok(struct vhost_dev *dev)
{
	struct vhost_memory *mp;

	mp = rcu_dereference_protected(dev->memory,
				       lockdep_is_held(&dev->mutex));
	return memory_access_ok(dev, mp, 1);
}

/* Verify access for write logging. */
/* Caller should have vq mutex and device mutex */
static int vq_log_access_ok(struct vhost_dev *d, struct vhost_virtqueue *vq,
			    void __user *log_base)
{
	struct vhost_memory *mp;
	size_t s = vhost_has_feature(d, VIRTIO_RING_F_EVENT_IDX) ? 2 : 0;

	mp = rcu_dereference_protected(vq->dev->memory,
				       lockdep_is_held(&vq->mutex));
	return vq_memory_access_ok(log_base, mp,
			    vhost_has_feature(vq->dev, VHOST_F_LOG_ALL)) &&
		(!vq->log_used || log_access_ok(log_base, vq->log_addr,
					sizeof *vq->used +
					vq->num * sizeof *vq->used->ring + s));
}

/* Can we start vq? */
/* Caller should have vq mutex and device mutex */
int vhost_vq_access_ok(struct vhost_virtqueue *vq)
{
	return vq_access_ok(vq->dev, vq->num, vq->desc, vq->avail, vq->used) &&
		vq_log_access_ok(vq->dev, vq, vq->log_base);
}

static long vhost_set_memory(struct vhost_dev *d, struct vhost_memory __user *m)
{
	struct vhost_memory mem, *newmem, *oldmem;
	unsigned long size = offsetof(struct vhost_memory, regions);

	if (copy_from_user(&mem, m, size))
		return -EFAULT;
	if (mem.padding)
		return -EOPNOTSUPP;
	if (mem.nregions > VHOST_MEMORY_MAX_NREGIONS)
		return -E2BIG;
	newmem = kmalloc(size + mem.nregions * sizeof *m->regions, GFP_KERNEL);
	if (!newmem)
		return -ENOMEM;

	memcpy(newmem, &mem, size);
	if (copy_from_user(newmem->regions, m->regions,
			   mem.nregions * sizeof *m->regions)) {
		kfree(newmem);
		return -EFAULT;
	}

	if (!memory_access_ok(d, newmem,
			      vhost_has_feature(d, VHOST_F_LOG_ALL))) {
		kfree(newmem);
		return -EFAULT;
	}
	oldmem = rcu_dereference_protected(d->memory,
					   lockdep_is_held(&d->mutex));
	rcu_assign_pointer(d->memory, newmem);
	synchronize_rcu();
	kfree(oldmem);
	return 0;
}

static long vhost_set_vring(struct vhost_dev *d, int ioctl, void __user *argp)
{
	struct file *eventfp, *filep = NULL,
		    *pollstart = NULL, *pollstop = NULL;
	struct eventfd_ctx *ctx = NULL;
	u32 __user *idxp = argp;
	struct vhost_virtqueue *vq;
	struct vhost_vring_state s;
	struct vhost_vring_file f;
	struct vhost_vring_addr a;
	u32 idx;
	long r;

	r = get_user(idx, idxp);
	if (r < 0)
		return r;
	if (idx >= d->nvqs)
		return -ENOBUFS;

	vq = d->vqs + idx;

	mutex_lock(&vq->mutex);

	switch (ioctl) {
	case VHOST_SET_VRING_NUM:
		/* Resizing ring with an active backend?
		 * You don't want to do that. */
		if (vq->private_data) {
			r = -EBUSY;
			break;
		}
		if (copy_from_user(&s, argp, sizeof s)) {
			r = -EFAULT;
			break;
		}
		if (!s.num || s.num > 0xffff || (s.num & (s.num - 1))) {
			r = -EINVAL;
			break;
		}
		vq->num = s.num;
		break;
	case VHOST_SET_VRING_BASE:
		/* Moving base with an active backend?
		 * You don't want to do that. */
		if (vq->private_data) {
			r = -EBUSY;
			break;
		}
		if (copy_from_user(&s, argp, sizeof s)) {
			r = -EFAULT;
			break;
		}
		if (s.num > 0xffff) {
			r = -EINVAL;
			break;
		}
		vq->last_avail_idx = s.num;
		/* Forget the cached index value. */
		vq->avail_idx = vq->last_avail_idx;
		break;
	case VHOST_GET_VRING_BASE:
		s.index = idx;
		s.num = vq->last_avail_idx;
		if (copy_to_user(argp, &s, sizeof s))
			r = -EFAULT;
		break;
	case VHOST_SET_VRING_ADDR:
		if (copy_from_user(&a, argp, sizeof a)) {
			r = -EFAULT;
			break;
		}
		if (a.flags & ~(0x1 << VHOST_VRING_F_LOG)) {
			r = -EOPNOTSUPP;
			break;
		}
		/* For 32bit, verify that the top 32bits of the user
		   data are set to zero. */
		if ((u64)(unsigned long)a.desc_user_addr != a.desc_user_addr ||
		    (u64)(unsigned long)a.used_user_addr != a.used_user_addr ||
		    (u64)(unsigned long)a.avail_user_addr != a.avail_user_addr) {
			r = -EFAULT;
			break;
		}
		if ((a.avail_user_addr & (sizeof *vq->avail->ring - 1)) ||
		    (a.used_user_addr & (sizeof *vq->used->ring - 1)) ||
		    (a.log_guest_addr & (sizeof *vq->used->ring - 1))) {
			r = -EINVAL;
			break;
		}

		/* We only verify access here if backend is configured.
		 * If it is not, we don't as size might not have been setup.
		 * We will verify when backend is configured. */
		if (vq->private_data) {
			if (!vq_access_ok(d, vq->num,
				(void __user *)(unsigned long)a.desc_user_addr,
				(void __user *)(unsigned long)a.avail_user_addr,
				(void __user *)(unsigned long)a.used_user_addr)) {
				r = -EINVAL;
				break;
			}

			/* Also validate log access for used ring if enabled. */
			if ((a.flags & (0x1 << VHOST_VRING_F_LOG)) &&
			    !log_access_ok(vq->log_base, a.log_guest_addr,
					   sizeof *vq->used +
					   vq->num * sizeof *vq->used->ring)) {
				r = -EINVAL;
				break;
			}
		}

		vq->log_used = !!(a.flags & (0x1 << VHOST_VRING_F_LOG));
		vq->desc = (void __user *)(unsigned long)a.desc_user_addr;
		vq->avail = (void __user *)(unsigned long)a.avail_user_addr;
		vq->log_addr = a.log_guest_addr;
		vq->used = (void __user *)(unsigned long)a.used_user_addr;
		break;
	case VHOST_SET_VRING_KICK:
		if (copy_from_user(&f, argp, sizeof f)) {
			r = -EFAULT;
			break;
		}
		eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
		if (IS_ERR(eventfp)) {
			r = PTR_ERR(eventfp);
			break;
		}
		if (eventfp != vq->kick) {
			pollstop = filep = vq->kick;
			pollstart = vq->kick = eventfp;
		} else
			filep = eventfp;
		break;
	case VHOST_SET_VRING_CALL:
		if (copy_from_user(&f, argp, sizeof f)) {
			r = -EFAULT;
			break;
		}
		eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
		if (IS_ERR(eventfp)) {
			r = PTR_ERR(eventfp);
			break;
		}
		if (eventfp != vq->call) {
			filep = vq->call;
			ctx = vq->call_ctx;
			vq->call = eventfp;
			vq->call_ctx = eventfp ?
				eventfd_ctx_fileget(eventfp) : NULL;
		} else
			filep = eventfp;
		break;
	case VHOST_SET_VRING_ERR:
		if (copy_from_user(&f, argp, sizeof f)) {
			r = -EFAULT;
			break;
		}
		eventfp = f.fd == -1 ? NULL : eventfd_fget(f.fd);
		if (IS_ERR(eventfp)) {
			r = PTR_ERR(eventfp);
			break;
		}
		if (eventfp != vq->error) {
			filep = vq->error;
			vq->error = eventfp;
			ctx = vq->error_ctx;
			vq->error_ctx = eventfp ?
				eventfd_ctx_fileget(eventfp) : NULL;
		} else
			filep = eventfp;
		break;
	default:
		r = -ENOIOCTLCMD;
	}

	if (pollstop && vq->handle_kick)
		vhost_poll_stop(&vq->poll);

	if (ctx)
		eventfd_ctx_put(ctx);
	if (filep)
		fput(filep);

	if (pollstart && vq->handle_kick)
		vhost_poll_start(&vq->poll, vq->kick);

	mutex_unlock(&vq->mutex);

	if (pollstop && vq->handle_kick)
		vhost_poll_flush(&vq->poll);
	return r;
}

/* Caller must have device mutex */
long vhost_dev_ioctl(struct vhost_dev *d, unsigned int ioctl, unsigned long arg)
{
	void __user *argp = (void __user *)arg;
	struct file *eventfp, *filep = NULL;
	struct eventfd_ctx *ctx = NULL;
	u64 p;
	long r;
	int i, fd;

	/* If you are not the owner, you can become one */
	if (ioctl == VHOST_SET_OWNER) {
		r = vhost_dev_set_owner(d);
		goto done;
	}

	/* You must be the owner to do anything else */
	r = vhost_dev_check_owner(d);
	if (r)
		goto done;

	switch (ioctl) {
	case VHOST_SET_MEM_TABLE:
		r = vhost_set_memory(d, argp);
		break;
	case VHOST_SET_LOG_BASE:
		if (copy_from_user(&p, argp, sizeof p)) {
			r = -EFAULT;
			break;
		}
		if ((u64)(unsigned long)p != p) {
			r = -EFAULT;
			break;
		}
		for (i = 0; i < d->nvqs; ++i) {
			struct vhost_virtqueue *vq;
			void __user *base = (void __user *)(unsigned long)p;
			vq = d->vqs + i;
			mutex_lock(&vq->mutex);
			/* If ring is inactive, will check when it's enabled. */
			if (vq->private_data && !vq_log_access_ok(d, vq, base))
				r = -EFAULT;
			else
				vq->log_base = base;
			mutex_unlock(&vq->mutex);
		}
		break;
	case VHOST_SET_LOG_FD:
		r = get_user(fd, (int __user *)argp);
		if (r < 0)
			break;
		eventfp = fd == -1 ? NULL : eventfd_fget(fd);
		if (IS_ERR(eventfp)) {
			r = PTR_ERR(eventfp);
			break;
		}
		if (eventfp != d->log_file) {
			filep = d->log_file;
			ctx = d->log_ctx;
			d->log_ctx = eventfp ?
				eventfd_ctx_fileget(eventfp) : NULL;
		} else
			filep = eventfp;
		for (i = 0; i < d->nvqs; ++i) {
			mutex_lock(&d->vqs[i].mutex);
			d->vqs[i].log_ctx = d->log_ctx;
			mutex_unlock(&d->vqs[i].mutex);
		}
		if (ctx)
			eventfd_ctx_put(ctx);
		if (filep)
			fput(filep);
		break;
	default:
		r = vhost_set_vring(d, ioctl, argp);
		break;
	}
done:
	return r;
}

static const struct vhost_memory_region *find_region(struct vhost_memory *mem,
						     __u64 addr, __u32 len)
{
	struct vhost_memory_region *reg;
	int i;

	/* linear search is not brilliant, but we really have on the order of 6
	 * regions in practice */
	for (i = 0; i < mem->nregions; ++i) {
		reg = mem->regions + i;
		if (reg->guest_phys_addr <= addr &&
		    reg->guest_phys_addr + reg->memory_size - 1 >= addr)
			return reg;
	}
	return NULL;
}

/* TODO: This is really inefficient. We need something like get_user()
 * (instruction directly accesses the data, with an exception table entry
 * returning -EFAULT). See Documentation/x86/exception-tables.txt.
 */
static int set_bit_to_user(int nr, void __user *addr)
{
	unsigned long log = (unsigned long)addr;
	struct page *page;
	void *base;
	int bit = nr + (log % PAGE_SIZE) * 8;
	int r;

	r = get_user_pages_fast(log, 1, 1, &page);
	if (r < 0)
		return r;
	BUG_ON(r != 1);
	base = kmap_atomic(page);
	set_bit(bit, base);
	kunmap_atomic(base);
	set_page_dirty_lock(page);
	put_page(page);
	return 0;
}

static int log_write(void __user *log_base,
		     u64 write_address, u64 write_length)
{
	u64 write_page = write_address / VHOST_PAGE_SIZE;
	int r;

	if (!write_length)
		return 0;
	write_length += write_address % VHOST_PAGE_SIZE;
	for (;;) {
		u64 base = (u64)(unsigned long)log_base;
		u64 log = base + write_page / 8;
		int bit = write_page % 8;
		if ((u64)(unsigned long)log != log)
			return -EFAULT;
		r = set_bit_to_user(bit, (void __user *)(unsigned long)log);
		if (r < 0)
			return r;
		if (write_length <= VHOST_PAGE_SIZE)
			break;
		write_length -= VHOST_PAGE_SIZE;
		write_page += 1;
	}
	return r;
}

int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
		    unsigned int log_num, u64 len)
{
	int i, r;

	/* Make sure data written is seen before log. */
	smp_wmb();
	for (i = 0; i < log_num; ++i) {
		u64 l = min(log[i].len, len);
		r = log_write(vq->log_base, log[i].addr, l);
		if (r < 0)
			return r;
		len -= l;
		if (!len) {
			if (vq->log_ctx)
				eventfd_signal(vq->log_ctx, 1);
			return 0;
		}
	}
	/* Length written exceeds what we have stored. This is a bug. */
	BUG();
	return 0;
}

static int vhost_update_used_flags(struct vhost_virtqueue *vq)
{
	void __user *used;
	if (__put_user(vq->used_flags, &vq->used->flags) < 0)
		return -EFAULT;
	if (unlikely(vq->log_used)) {
		/* Make sure the flag is seen before log. */
		smp_wmb();
		/* Log used flag write. */
		used = &vq->used->flags;
		log_write(vq->log_base, vq->log_addr +
			  (used - (void __user *)vq->used),
			  sizeof vq->used->flags);
		if (vq->log_ctx)
			eventfd_signal(vq->log_ctx, 1);
	}
	return 0;
}

static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
{
	if (__put_user(vq->avail_idx, vhost_avail_event(vq)))
		return -EFAULT;
	if (unlikely(vq->log_used)) {
		void __user *used;
		/* Make sure the event is seen before log. */
		smp_wmb();
		/* Log avail event write */
		used = vhost_avail_event(vq);
		log_write(vq->log_base, vq->log_addr +
			  (used - (void __user *)vq->used),
			  sizeof *vhost_avail_event(vq));
		if (vq->log_ctx)
			eventfd_signal(vq->log_ctx, 1);
	}
	return 0;
}

int vhost_init_used(struct vhost_virtqueue *vq)
{
	int r;
	if (!vq->private_data)
		return 0;

	r = vhost_update_used_flags(vq);
	if (r)
		return r;
	vq->signalled_used_valid = false;
	return get_user(vq->last_used_idx, &vq->used->idx);
}

static int translate_desc(struct vhost_dev *dev, u64 addr, u32 len,
			  struct iovec iov[], int iov_size)
{
	const struct vhost_memory_region *reg;
	struct vhost_memory *mem;
	struct iovec *_iov;
	u64 s = 0;
	int ret = 0;

	rcu_read_lock();

	mem = rcu_dereference(dev->memory);
	while ((u64)len > s) {
		u64 size;
		if (unlikely(ret >= iov_size)) {
			ret = -ENOBUFS;
			break;
		}
		reg = find_region(mem, addr, len);
		if (unlikely(!reg)) {
			ret = -EFAULT;
			break;
		}
		_iov = iov + ret;
		size = reg->memory_size - addr + reg->guest_phys_addr;
		_iov->iov_len = min((u64)len, size);
		_iov->iov_base = (void __user *)(unsigned long)
			(reg->userspace_addr + addr - reg->guest_phys_addr);
		s += size;
		addr += size;
		++ret;
	}

	rcu_read_unlock();
	return ret;
}

/* Each buffer in the virtqueues is actually a chain of descriptors. This
 * function returns the next descriptor in the chain,
 * or -1U if we're at the end. */
static unsigned next_desc(struct vring_desc *desc)
{
	unsigned int next;

	/* If this descriptor says it doesn't chain, we're done. */
	if (!(desc->flags & VRING_DESC_F_NEXT))
		return -1U;

	/* Check they're not leading us off end of descriptors. */
	next = desc->next;
	/* Make sure compiler knows to grab that: we don't want it changing! */
	/* We will use the result as an index in an array, so most
	 * architectures only need a compiler barrier here. */
	read_barrier_depends();

	return next;
}

static int get_indirect(struct vhost_dev *dev, struct vhost_virtqueue *vq,
			struct iovec iov[], unsigned int iov_size,
			unsigned int *out_num, unsigned int *in_num,
			struct vhost_log *log, unsigned int *log_num,
			struct vring_desc *indirect)
{
	struct vring_desc desc;
	unsigned int i = 0, count, found = 0;
	int ret;

	/* Sanity check */
	if (unlikely(indirect->len % sizeof desc)) {
		vq_err(vq, "Invalid length in indirect descriptor: "
		       "len 0x%llx not multiple of 0x%zx\n",
		       (unsigned long long)indirect->len,
		       sizeof desc);
		return -EINVAL;
	}

	ret = translate_desc(dev, indirect->addr, indirect->len, vq->indirect,
			     UIO_MAXIOV);
	if (unlikely(ret < 0)) {
		vq_err(vq, "Translation failure %d in indirect.\n", ret);
		return ret;
	}

	/* We will use the result as an address to read from, so most
	 * architectures only need a compiler barrier here. */
	read_barrier_depends();

	count = indirect->len / sizeof desc;
	/* Buffers are chained via a 16 bit next field, so
	 * we can have at most 2^16 of these. */
	if (unlikely(count > USHRT_MAX + 1)) {
		vq_err(vq, "Indirect buffer length too big: %d\n",
		       indirect->len);
		return -E2BIG;
	}

	do {
		unsigned iov_count = *in_num + *out_num;
		if (unlikely(++found > count)) {
			vq_err(vq, "Loop detected: last one at %u "
			       "indirect size %u\n",
			       i, count);
			return -EINVAL;
		}
		if (unlikely(memcpy_fromiovec((unsigned char *)&desc,
					      vq->indirect, sizeof desc))) {
			vq_err(vq, "Failed indirect descriptor: idx %d, %zx\n",
			       i, (size_t)indirect->addr + i * sizeof desc);
			return -EINVAL;
		}
		if (unlikely(desc.flags & VRING_DESC_F_INDIRECT)) {
			vq_err(vq, "Nested indirect descriptor: idx %d, %zx\n",
			       i, (size_t)indirect->addr + i * sizeof desc);
			return -EINVAL;
		}

		ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count,
				     iov_size - iov_count);
		if (unlikely(ret < 0)) {
			vq_err(vq, "Translation failure %d indirect idx %d\n",
			       ret, i);
			return ret;
		}
		/* If this is an input descriptor, increment that count. */
		if (desc.flags & VRING_DESC_F_WRITE) {
			*in_num += ret;
			if (unlikely(log)) {
				log[*log_num].addr = desc.addr;
				log[*log_num].len = desc.len;
				++*log_num;
			}
		} else {
			/* If it's an output descriptor, they're all supposed
			 * to come before any input descriptors. */
			if (unlikely(*in_num)) {
				vq_err(vq, "Indirect descriptor "
				       "has out after in: idx %d\n", i);
				return -EINVAL;
			}
			*out_num += ret;
		}
	} while ((i = next_desc(&desc)) != -1);
	return 0;
}

/* This looks in the virtqueue and for the first available buffer, and converts
 * it to an iovec for convenient access. Since descriptors consist of some
 * number of output then some number of input descriptors, it's actually two
 * iovecs, but we pack them into one and note how many of each there were.
 *
 * This function returns the descriptor number found, or vq->num (which is
 * never a valid descriptor number) if none was found. A negative code is
 * returned on error. */
int vhost_get_vq_desc(struct vhost_dev *dev, struct vhost_virtqueue *vq,
		      struct iovec iov[], unsigned int iov_size,
		      unsigned int *out_num, unsigned int *in_num,
		      struct vhost_log *log, unsigned int *log_num)
{
	struct vring_desc desc;
	unsigned int i, head, found = 0;
	u16 last_avail_idx;
	int ret;

	/* Check it isn't doing very strange things with descriptor numbers. */
	last_avail_idx = vq->last_avail_idx;
	if (unlikely(__get_user(vq->avail_idx, &vq->avail->idx))) {
		vq_err(vq, "Failed to access avail idx at %p\n",
		       &vq->avail->idx);
		return -EFAULT;
	}

	if (unlikely((u16)(vq->avail_idx - last_avail_idx) > vq->num)) {
		vq_err(vq, "Guest moved used index from %u to %u",
		       last_avail_idx, vq->avail_idx);
		return -EFAULT;
	}

	/* If there's nothing new since last we looked, return invalid. */
	if (vq->avail_idx == last_avail_idx)
		return vq->num;

	/* Only get avail ring entries after they have been exposed by guest. */
	smp_rmb();

	/* Grab the next descriptor number they're advertising, and increment
	 * the index we've seen. */
	if (unlikely(__get_user(head,
				&vq->avail->ring[last_avail_idx % vq->num]))) {
		vq_err(vq, "Failed to read head: idx %d address %p\n",
		       last_avail_idx,
		       &vq->avail->ring[last_avail_idx % vq->num]);
		return -EFAULT;
	}

	/* If their number is silly, that's an error. */
	if (unlikely(head >= vq->num)) {
		vq_err(vq, "Guest says index %u > %u is available",
		       head, vq->num);
		return -EINVAL;
	}

	/* When we start there are none of either input nor output. */
	*out_num = *in_num = 0;
	if (unlikely(log))
		*log_num = 0;

	i = head;
	do {
		unsigned iov_count = *in_num + *out_num;
		if (unlikely(i >= vq->num)) {
			vq_err(vq, "Desc index is %u > %u, head = %u",
			       i, vq->num, head);
			return -EINVAL;
		}
		if (unlikely(++found > vq->num)) {
			vq_err(vq, "Loop detected: last one at %u "
			       "vq size %u head %u\n",
			       i, vq->num, head);
			return -EINVAL;
		}
		ret = __copy_from_user(&desc, vq->desc + i, sizeof desc);
		if (unlikely(ret)) {
			vq_err(vq, "Failed to get descriptor: idx %d addr %p\n",
			       i, vq->desc + i);
			return -EFAULT;
		}
		if (desc.flags & VRING_DESC_F_INDIRECT) {
			ret = get_indirect(dev, vq, iov, iov_size,
					   out_num, in_num,
					   log, log_num, &desc);
			if (unlikely(ret < 0)) {
				vq_err(vq, "Failure detected "
				       "in indirect descriptor at idx %d\n", i);
				return ret;
			}
			continue;
		}

		ret = translate_desc(dev, desc.addr, desc.len, iov + iov_count,
				     iov_size - iov_count);
		if (unlikely(ret < 0)) {
			vq_err(vq, "Translation failure %d descriptor idx %d\n",
			       ret, i);
			return ret;
		}
		if (desc.flags & VRING_DESC_F_WRITE) {
			/* If this is an input descriptor,
			 * increment that count. */
			*in_num += ret;
			if (unlikely(log)) {
				log[*log_num].addr = desc.addr;
				log[*log_num].len = desc.len;
				++*log_num;
			}
		} else {
			/* If it's an output descriptor, they're all supposed
			 * to come before any input descriptors. */
			if (unlikely(*in_num)) {
				vq_err(vq, "Descriptor has out after in: "
				       "idx %d\n", i);
				return -EINVAL;
			}
			*out_num += ret;
		}
	} while ((i = next_desc(&desc)) != -1);

	/* On success, increment avail index. */
	vq->last_avail_idx++;

	/* Assume notifications from guest are disabled at this point,
	 * if they aren't we would need to update avail_event index. */
	BUG_ON(!(vq->used_flags & VRING_USED_F_NO_NOTIFY));
	return head;
}

/* Reverse the effect of vhost_get_vq_desc. Useful for error handling. */
void vhost_discard_vq_desc(struct vhost_virtqueue *vq, int n)
{
	vq->last_avail_idx -= n;
}

/* After we've used one of their buffers, we tell them about it. We'll then
 * want to notify the guest, using eventfd. */
int vhost_add_used(struct vhost_virtqueue *vq, unsigned int head, int len)
{
	struct vring_used_elem __user *used;

	/* The virtqueue contains a ring of used buffers. Get a pointer to the
	 * next entry in that used ring. */
	used = &vq->used->ring[vq->last_used_idx % vq->num];
	if (__put_user(head, &used->id)) {
		vq_err(vq, "Failed to write used id");
		return -EFAULT;
	}
	if (__put_user(len, &used->len)) {
		vq_err(vq, "Failed to write used len");
		return -EFAULT;
	}
	/* Make sure buffer is written before we update index. */
	smp_wmb();
	if (__put_user(vq->last_used_idx + 1, &vq->used->idx)) {
		vq_err(vq, "Failed to increment used idx");
		return -EFAULT;
	}
	if (unlikely(vq->log_used)) {
		/* Make sure data is seen before log. */
		smp_wmb();
		/* Log used ring entry write. */
		log_write(vq->log_base,
			  vq->log_addr +
			   ((void __user *)used - (void __user *)vq->used),
			  sizeof *used);
		/* Log used index update. */
		log_write(vq->log_base,
			  vq->log_addr + offsetof(struct vring_used, idx),
			  sizeof vq->used->idx);
		if (vq->log_ctx)
			eventfd_signal(vq->log_ctx, 1);
	}
	vq->last_used_idx++;
	/* If the driver never bothers to signal in a very long while,
	 * used index might wrap around. If that happens, invalidate
	 * signalled_used index we stored. TODO: make sure driver
	 * signals at least once in 2^16 and remove this. */
	if (unlikely(vq->last_used_idx == vq->signalled_used))
		vq->signalled_used_valid = false;
	return 0;
}

static int __vhost_add_used_n(struct vhost_virtqueue *vq,
			      struct vring_used_elem *heads,
			      unsigned count)
{
	struct vring_used_elem __user *used;
	u16 old, new;
	int start;

	start = vq->last_used_idx % vq->num;
	used = vq->used->ring + start;
	if (__copy_to_user(used, heads, count * sizeof *used)) {
		vq_err(vq, "Failed to write used");
		return -EFAULT;
	}
	if (unlikely(vq->log_used)) {
		/* Make sure data is seen before log. */
		smp_wmb();
		/* Log used ring entry write. */
		log_write(vq->log_base,
			  vq->log_addr +
			   ((void __user *)used - (void __user *)vq->used),
			  count * sizeof *used);
	}
	old = vq->last_used_idx;
	new = (vq->last_used_idx += count);
	/* If the driver never bothers to signal in a very long while,
	 * used index might wrap around. If that happens, invalidate
	 * signalled_used index we stored. TODO: make sure driver
	 * signals at least once in 2^16 and remove this. */
	if (unlikely((u16)(new - vq->signalled_used) < (u16)(new - old)))
		vq->signalled_used_valid = false;
	return 0;
}

/* After we've used one of their buffers, we tell them about it. We'll then
 * want to notify the guest, using eventfd. */
int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
		     unsigned count)
{
	int start, n, r;

	start = vq->last_used_idx % vq->num;
	n = vq->num - start;
	if (n < count) {
		r = __vhost_add_used_n(vq, heads, n);
		if (r < 0)
			return r;
		heads += n;
		count -= n;
	}
	r = __vhost_add_used_n(vq, heads, count);

	/* Make sure buffer is written before we update index. */
	smp_wmb();
	if (put_user(vq->last_used_idx, &vq->used->idx)) {
		vq_err(vq, "Failed to increment used idx");
		return -EFAULT;
	}
	if (unlikely(vq->log_used)) {
		/* Log used index update. */
		log_write(vq->log_base,
			  vq->log_addr + offsetof(struct vring_used, idx),
			  sizeof vq->used->idx);
		if (vq->log_ctx)
			eventfd_signal(vq->log_ctx, 1);
	}
	return r;
}

static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
	__u16 old, new, event;
	bool v;
	/* Flush out used index updates. This is paired
	 * with the barrier that the Guest executes when enabling
	 * interrupts. */
	smp_mb();

	if (vhost_has_feature(dev, VIRTIO_F_NOTIFY_ON_EMPTY) &&
	    unlikely(vq->avail_idx == vq->last_avail_idx))
		return true;

	if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
		__u16 flags;
		if (__get_user(flags, &vq->avail->flags)) {
			vq_err(vq, "Failed to get flags");
			return true;
		}
		return !(flags & VRING_AVAIL_F_NO_INTERRUPT);
	}
	old = vq->signalled_used;
	v = vq->signalled_used_valid;
	new = vq->signalled_used = vq->last_used_idx;
	vq->signalled_used_valid = true;

	if (unlikely(!v))
		return true;

	if (get_user(event, vhost_used_event(vq))) {
		vq_err(vq, "Failed to get used event idx");
		return true;
	}
	return vring_need_event(event, new, old);
}

/* This actually signals the guest, using eventfd. */
void vhost_signal(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
	/* Signal the Guest tell them we used something up. */
	if (vq->call_ctx && vhost_notify(dev, vq))
		eventfd_signal(vq->call_ctx, 1);
}

/* And here's the combo meal deal. Supersize me! */
void vhost_add_used_and_signal(struct vhost_dev *dev,
			       struct vhost_virtqueue *vq,
			       unsigned int head, int len)
{
	vhost_add_used(vq, head, len);
	vhost_signal(dev, vq);
}

/* multi-buffer version of vhost_add_used_and_signal */
void vhost_add_used_and_signal_n(struct vhost_dev *dev,
				 struct vhost_virtqueue *vq,
				 struct vring_used_elem *heads, unsigned count)
{
	vhost_add_used_n(vq, heads, count);
	vhost_signal(dev, vq);
}

/* OK, now we need to know about added descriptors. */
bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
	u16 avail_idx;
	int r;

	if (!(vq->used_flags & VRING_USED_F_NO_NOTIFY))
		return false;
	vq->used_flags &= ~VRING_USED_F_NO_NOTIFY;
	if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
		r = vhost_update_used_flags(vq);
		if (r) {
			vq_err(vq, "Failed to enable notification at %p: %d\n",
			       &vq->used->flags, r);
			return false;
		}
	} else {
		r = vhost_update_avail_event(vq, vq->avail_idx);
		if (r) {
			vq_err(vq, "Failed to update avail event index at %p: %d\n",
			       vhost_avail_event(vq), r);
			return false;
		}
	}
	/* They could have slipped one in as we were doing that: make
	 * sure it's written, then check again. */
	smp_mb();
	r = __get_user(avail_idx, &vq->avail->idx);
	if (r) {
		vq_err(vq, "Failed to check avail idx at %p: %d\n",
		       &vq->avail->idx, r);
		return false;
	}

	return avail_idx != vq->avail_idx;
}

/* We don't need to be notified again. */
void vhost_disable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
{
	int r;

	if (vq->used_flags & VRING_USED_F_NO_NOTIFY)
		return;
	vq->used_flags |= VRING_USED_F_NO_NOTIFY;
	if (!vhost_has_feature(dev, VIRTIO_RING_F_EVENT_IDX)) {
		r = vhost_update_used_flags(vq);
		if (r)
			vq_err(vq, "Failed to enable notification at %p: %d\n",
			       &vq->used->flags, r);
	}
}

static void vhost_zerocopy_done_signal(struct kref *kref)
{
	struct vhost_ubuf_ref *ubufs = container_of(kref, struct vhost_ubuf_ref,
						    kref);
	wake_up(&ubufs->wait);
}

struct vhost_ubuf_ref *vhost_ubuf_alloc(struct vhost_virtqueue *vq,
					bool zcopy)
{
	struct vhost_ubuf_ref *ubufs;
	/* No zero copy backend? Nothing to count. */
	if (!zcopy)
		return NULL;
	ubufs = kmalloc(sizeof *ubufs, GFP_KERNEL);
	if (!ubufs)
		return ERR_PTR(-ENOMEM);
	kref_init(&ubufs->kref);
	init_waitqueue_head(&ubufs->wait);
	ubufs->vq = vq;
	return ubufs;
}

void vhost_ubuf_put(struct vhost_ubuf_ref *ubufs)
{
	kref_put(&ubufs->kref, vhost_zerocopy_done_signal);
}

void vhost_ubuf_put_and_wait(struct vhost_ubuf_ref *ubufs)
{
	kref_put(&ubufs->kref, vhost_zerocopy_done_signal);
	wait_event(ubufs->wait, !atomic_read(&ubufs->kref.refcount));
	kfree(ubufs);
}

void vhost_zerocopy_callback(struct ubuf_info *ubuf)
{
	struct vhost_ubuf_ref *ubufs = ubuf->ctx;
	struct vhost_virtqueue *vq = ubufs->vq;

	vhost_poll_queue(&vq->poll);
	/* set len = 1 to mark this desc buffers done DMA */
	vq->heads[ubuf->desc].len = VHOST_DMA_DONE_LEN;
	kref_put(&ubufs->kref, vhost_zerocopy_done_signal);
}