// SPDX-License-Identifier: GPL-2.0-only
/*
 * vhost transport for vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 */
#include <linux/miscdevice.h>
#include <linux/atomic.h>
#include <linux/module.h>
#include <linux/mutex.h>
#include <linux/vmalloc.h>
#include <net/sock.h>
#include <linux/virtio_vsock.h>
#include <linux/vhost.h>
#include <linux/hashtable.h>

#include <net/af_vsock.h>
#include "vhost.h"

#define VHOST_VSOCK_DEFAULT_HOST_CID	2
/* Max number of bytes transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others. */
#define VHOST_VSOCK_WEIGHT 0x80000
/* Max number of packets transferred before requeueing the job.
 * Using this limit prevents one virtqueue from starving others with
 * small pkts.
 */
#define VHOST_VSOCK_PKT_WEIGHT 256

enum {
	VHOST_VSOCK_FEATURES = VHOST_FEATURES,
};

/* Used to track all the vhost_vsock instances on the system. */
static DEFINE_MUTEX(vhost_vsock_mutex);
static DEFINE_READ_MOSTLY_HASHTABLE(vhost_vsock_hash, 8);

struct vhost_vsock {
	struct vhost_dev dev;
	struct vhost_virtqueue vqs[2];

	/* Link to global vhost_vsock_hash, writes use vhost_vsock_mutex */
	struct hlist_node hash;

	struct vhost_work send_pkt_work;
	spinlock_t send_pkt_list_lock;
	struct list_head send_pkt_list;	/* host->guest pending packets */

	atomic_t queued_replies;

	u32 guest_cid;
};

static u32 vhost_transport_get_local_cid(void)
{
	return VHOST_VSOCK_DEFAULT_HOST_CID;
}

/* Callers that dereference the return value must hold vhost_vsock_mutex or the
 * RCU read lock.
 */
static struct vhost_vsock *vhost_vsock_get(u32 guest_cid)
{
	struct vhost_vsock *vsock;

	hash_for_each_possible_rcu(vhost_vsock_hash, vsock, hash, guest_cid) {
		u32 other_cid = vsock->guest_cid;

		/* Skip instances that have no CID yet */
		if (other_cid == 0)
			continue;

		if (other_cid == guest_cid)
			return vsock;
	}

	return NULL;
}

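/* Worker body for host->guest transmission: drain vsock->send_pkt_list into
 * the guest's RX virtqueue.  A packet that does not fit in a single guest
 * buffer is copied out in pieces, with pkt->off tracking how much has been
 * sent, and is requeued until fully delivered.  The loop stops once the
 * byte/packet weight is exceeded so one virtqueue cannot starve the others,
 * and the TX virtqueue handler is scheduled again if freeing reply packets
 * made room for more replies.
 */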
static void
vhost_transport_do_send_pkt(struct vhost_vsock *vsock,
			    struct vhost_virtqueue *vq)
{
	struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
	int pkts = 0, total_len = 0;
	bool added = false;
	bool restart_tx = false;

	mutex_lock(&vq->mutex);

	if (!vhost_vq_get_backend(vq))
		goto out;

	/* Avoid further vmexits, we're already processing the virtqueue */
	vhost_disable_notify(&vsock->dev, vq);

	do {
		struct virtio_vsock_pkt *pkt;
		struct iov_iter iov_iter;
		unsigned out, in;
		size_t nbytes;
		size_t iov_len, payload_len;
		int head;

		spin_lock_bh(&vsock->send_pkt_list_lock);
		if (list_empty(&vsock->send_pkt_list)) {
			spin_unlock_bh(&vsock->send_pkt_list_lock);
			vhost_enable_notify(&vsock->dev, vq);
			break;
		}

		pkt = list_first_entry(&vsock->send_pkt_list,
				       struct virtio_vsock_pkt, list);
		list_del_init(&pkt->list);
		spin_unlock_bh(&vsock->send_pkt_list_lock);

		head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
					 &out, &in, NULL, NULL);
		if (head < 0) {
			spin_lock_bh(&vsock->send_pkt_list_lock);
			list_add(&pkt->list, &vsock->send_pkt_list);
			spin_unlock_bh(&vsock->send_pkt_list_lock);
			break;
		}

		if (head == vq->num) {
			spin_lock_bh(&vsock->send_pkt_list_lock);
			list_add(&pkt->list, &vsock->send_pkt_list);
			spin_unlock_bh(&vsock->send_pkt_list_lock);

			/* We cannot finish yet if more buffers snuck in while
			 * re-enabling notify.
			 */
			if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
				vhost_disable_notify(&vsock->dev, vq);
				continue;
			}
			break;
		}

		if (out) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Expected 0 output buffers, got %u\n", out);
			break;
		}

		iov_len = iov_length(&vq->iov[out], in);
		if (iov_len < sizeof(pkt->hdr)) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Buffer len [%zu] too small\n", iov_len);
			break;
		}

		iov_iter_init(&iov_iter, READ, &vq->iov[out], in, iov_len);
		payload_len = pkt->len - pkt->off;

		/* If the packet is greater than the space available in the
		 * buffer, we split it using multiple buffers.
		 */
		if (payload_len > iov_len - sizeof(pkt->hdr))
			payload_len = iov_len - sizeof(pkt->hdr);

		/* Set the correct length in the header */
		pkt->hdr.len = cpu_to_le32(payload_len);

		nbytes = copy_to_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
		if (nbytes != sizeof(pkt->hdr)) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Faulted on copying pkt hdr\n");
			break;
		}

		nbytes = copy_to_iter(pkt->buf + pkt->off, payload_len,
				      &iov_iter);
		if (nbytes != payload_len) {
			virtio_transport_free_pkt(pkt);
			vq_err(vq, "Faulted on copying pkt buf\n");
			break;
		}

		/* Deliver to monitoring devices all packets that we
		 * will transmit.
		 */
		virtio_transport_deliver_tap_pkt(pkt);

		vhost_add_used(vq, head, sizeof(pkt->hdr) + payload_len);
		added = true;

		pkt->off += payload_len;
		total_len += payload_len;

		/* If we didn't send all the payload we can requeue the packet
		 * to send it with the next available buffer.
		 */
		if (pkt->off < pkt->len) {
			/* We are queueing the same virtio_vsock_pkt to handle
			 * the remaining bytes, and we want to deliver it
			 * to monitoring devices in the next iteration.
			 */
			pkt->tap_delivered = false;

			spin_lock_bh(&vsock->send_pkt_list_lock);
			list_add(&pkt->list, &vsock->send_pkt_list);
			spin_unlock_bh(&vsock->send_pkt_list_lock);
		} else {
			if (pkt->reply) {
				int val;

				val = atomic_dec_return(&vsock->queued_replies);

				/* Do we have resources to resume tx
				 * processing?
				 */
				if (val + 1 == tx_vq->num)
					restart_tx = true;
			}

			virtio_transport_free_pkt(pkt);
		}
	} while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));
	if (added)
		vhost_signal(&vsock->dev, vq);

out:
	mutex_unlock(&vq->mutex);

	if (restart_tx)
		vhost_poll_queue(&tx_vq->poll);
}

static void vhost_transport_send_pkt_work(struct vhost_work *work)
{
	struct vhost_virtqueue *vq;
	struct vhost_vsock *vsock;

	vsock = container_of(work, struct vhost_vsock, send_pkt_work);
	vq = &vsock->vqs[VSOCK_VQ_RX];

	vhost_transport_do_send_pkt(vsock, vq);
}

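/* .send_pkt callback used by the common virtio transport code to hand us a
 * packet destined for the guest.  Look up the target vhost_vsock instance by
 * destination CID under the RCU read lock, account reply packets in
 * queued_replies (used to throttle the TX virtqueue), queue the packet on
 * send_pkt_list and schedule the send worker.
 */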
static int
vhost_transport_send_pkt(struct virtio_vsock_pkt *pkt)
{
	struct vhost_vsock *vsock;
	int len = pkt->len;

	rcu_read_lock();

	/* Find the vhost_vsock according to guest context id */
	vsock = vhost_vsock_get(le64_to_cpu(pkt->hdr.dst_cid));
	if (!vsock) {
		rcu_read_unlock();
		virtio_transport_free_pkt(pkt);
		return -ENODEV;
	}

	if (pkt->reply)
		atomic_inc(&vsock->queued_replies);

	spin_lock_bh(&vsock->send_pkt_list_lock);
	list_add_tail(&pkt->list, &vsock->send_pkt_list);
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);

	rcu_read_unlock();
	return len;
}

static int
vhost_transport_cancel_pkt(struct vsock_sock *vsk)
{
	struct vhost_vsock *vsock;
	struct virtio_vsock_pkt *pkt, *n;
	int cnt = 0;
	int ret = -ENODEV;
	LIST_HEAD(freeme);

	rcu_read_lock();

	/* Find the vhost_vsock according to guest context id */
	vsock = vhost_vsock_get(vsk->remote_addr.svm_cid);
	if (!vsock)
		goto out;

	spin_lock_bh(&vsock->send_pkt_list_lock);
	list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) {
		if (pkt->vsk != vsk)
			continue;
		list_move(&pkt->list, &freeme);
	}
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	list_for_each_entry_safe(pkt, n, &freeme, list) {
		if (pkt->reply)
			cnt++;
		list_del(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}

	if (cnt) {
		struct vhost_virtqueue *tx_vq = &vsock->vqs[VSOCK_VQ_TX];
		int new_cnt;

		new_cnt = atomic_sub_return(cnt, &vsock->queued_replies);
		if (new_cnt + cnt >= tx_vq->num && new_cnt < tx_vq->num)
			vhost_poll_queue(&tx_vq->poll);
	}

	ret = 0;
out:
	rcu_read_unlock();
	return ret;
}

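/* Build a virtio_vsock_pkt from a descriptor chain on the guest's TX
 * virtqueue.  The chain must contain only device-readable buffers; copy the
 * header first, bound the payload length by VIRTIO_VSOCK_MAX_PKT_BUF_SIZE,
 * then allocate and copy the payload.  Returns NULL on any error.
 */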
static struct virtio_vsock_pkt *
vhost_vsock_alloc_pkt(struct vhost_virtqueue *vq,
		      unsigned int out, unsigned int in)
{
	struct virtio_vsock_pkt *pkt;
	struct iov_iter iov_iter;
	size_t nbytes;
	size_t len;

	if (in != 0) {
		vq_err(vq, "Expected 0 input buffers, got %u\n", in);
		return NULL;
	}

	pkt = kzalloc(sizeof(*pkt), GFP_KERNEL);
	if (!pkt)
		return NULL;

	len = iov_length(vq->iov, out);
	iov_iter_init(&iov_iter, WRITE, vq->iov, out, len);

	nbytes = copy_from_iter(&pkt->hdr, sizeof(pkt->hdr), &iov_iter);
	if (nbytes != sizeof(pkt->hdr)) {
		vq_err(vq, "Expected %zu bytes for pkt->hdr, got %zu bytes\n",
		       sizeof(pkt->hdr), nbytes);
		kfree(pkt);
		return NULL;
	}

	if (le16_to_cpu(pkt->hdr.type) == VIRTIO_VSOCK_TYPE_STREAM)
		pkt->len = le32_to_cpu(pkt->hdr.len);

	/* No payload */
	if (!pkt->len)
		return pkt;

	/* The pkt is too big */
	if (pkt->len > VIRTIO_VSOCK_MAX_PKT_BUF_SIZE) {
		kfree(pkt);
		return NULL;
	}

	pkt->buf = kmalloc(pkt->len, GFP_KERNEL);
	if (!pkt->buf) {
		kfree(pkt);
		return NULL;
	}

	pkt->buf_len = pkt->len;

	nbytes = copy_from_iter(pkt->buf, pkt->len, &iov_iter);
	if (nbytes != pkt->len) {
		vq_err(vq, "Expected %u byte payload, got %zu bytes\n",
		       pkt->len, nbytes);
		virtio_transport_free_pkt(pkt);
		return NULL;
	}

	return pkt;
}

/* Is there space left for replies to rx packets? */
static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
{
	struct vhost_virtqueue *vq = &vsock->vqs[VSOCK_VQ_TX];
	int val;

	smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
	val = atomic_read(&vsock->queued_replies);

	return val < vq->num;
}

static struct virtio_transport vhost_transport = {
	.transport = {
		.module = THIS_MODULE,

		.get_local_cid = vhost_transport_get_local_cid,

		.init = virtio_transport_do_socket_init,
		.destruct = virtio_transport_destruct,
		.release = virtio_transport_release,
		.connect = virtio_transport_connect,
		.shutdown = virtio_transport_shutdown,
		.cancel_pkt = vhost_transport_cancel_pkt,

		.dgram_enqueue = virtio_transport_dgram_enqueue,
		.dgram_dequeue = virtio_transport_dgram_dequeue,
		.dgram_bind = virtio_transport_dgram_bind,
		.dgram_allow = virtio_transport_dgram_allow,

		.stream_enqueue = virtio_transport_stream_enqueue,
		.stream_dequeue = virtio_transport_stream_dequeue,
		.stream_has_data = virtio_transport_stream_has_data,
		.stream_has_space = virtio_transport_stream_has_space,
		.stream_rcvhiwat = virtio_transport_stream_rcvhiwat,
		.stream_is_active = virtio_transport_stream_is_active,
		.stream_allow = virtio_transport_stream_allow,

		.notify_poll_in = virtio_transport_notify_poll_in,
		.notify_poll_out = virtio_transport_notify_poll_out,
		.notify_recv_init = virtio_transport_notify_recv_init,
		.notify_recv_pre_block = virtio_transport_notify_recv_pre_block,
		.notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue,
		.notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
		.notify_send_init = virtio_transport_notify_send_init,
		.notify_send_pre_block = virtio_transport_notify_send_pre_block,
		.notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue,
		.notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
		.notify_buffer_size = virtio_transport_notify_buffer_size,
	},

	.send_pkt = vhost_transport_send_pkt,
};

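/* Handle a kick on the guest's TX virtqueue: pop descriptor chains, turn
 * them into packets and feed correctly addressed ones to the core via
 * virtio_transport_recv_pkt(); everything else is dropped.  When
 * vhost_vsock_more_replies() reports that further replies could not be
 * queued, processing stops with notifications left disabled;
 * vhost_transport_do_send_pkt() restarts this virtqueue once replies have
 * been freed.
 */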
static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
						 dev);
	struct virtio_vsock_pkt *pkt;
	int head, pkts = 0, total_len = 0;
	unsigned int out, in;
	bool added = false;

	mutex_lock(&vq->mutex);

	if (!vhost_vq_get_backend(vq))
		goto out;

	vhost_disable_notify(&vsock->dev, vq);
	do {
		u32 len;

		if (!vhost_vsock_more_replies(vsock)) {
			/* Stop tx until the device processes already
			 * pending replies. Leave tx virtqueue
			 * callbacks disabled.
			 */
			goto no_more_replies;
		}

		head = vhost_get_vq_desc(vq, vq->iov, ARRAY_SIZE(vq->iov),
					 &out, &in, NULL, NULL);
		if (head < 0)
			break;

		if (head == vq->num) {
			if (unlikely(vhost_enable_notify(&vsock->dev, vq))) {
				vhost_disable_notify(&vsock->dev, vq);
				continue;
			}
			break;
		}

		pkt = vhost_vsock_alloc_pkt(vq, out, in);
		if (!pkt) {
			vq_err(vq, "Faulted on pkt\n");
			continue;
		}

		len = pkt->len;

		/* Deliver to monitoring devices all received packets */
		virtio_transport_deliver_tap_pkt(pkt);

		/* Only accept correctly addressed packets */
		if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid &&
		    le64_to_cpu(pkt->hdr.dst_cid) ==
		    vhost_transport_get_local_cid())
			virtio_transport_recv_pkt(&vhost_transport, pkt);
		else
			virtio_transport_free_pkt(pkt);

		len += sizeof(pkt->hdr);
		vhost_add_used(vq, head, len);
		total_len += len;
		added = true;
	} while (likely(!vhost_exceeds_weight(vq, ++pkts, total_len)));

no_more_replies:
	if (added)
		vhost_signal(&vsock->dev, vq);

out:
	mutex_unlock(&vq->mutex);
}

static void vhost_vsock_handle_rx_kick(struct vhost_work *work)
{
	struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
						  poll.work);
	struct vhost_vsock *vsock = container_of(vq->dev, struct vhost_vsock,
						 dev);

	vhost_transport_do_send_pkt(vsock, vq);
}

static int vhost_vsock_start(struct vhost_vsock *vsock)
{
	struct vhost_virtqueue *vq;
	size_t i;
	int ret;

	mutex_lock(&vsock->dev.mutex);

	ret = vhost_dev_check_owner(&vsock->dev);
	if (ret)
		goto err;

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		vq = &vsock->vqs[i];

		mutex_lock(&vq->mutex);

		if (!vhost_vq_access_ok(vq)) {
			ret = -EFAULT;
			goto err_vq;
		}

		if (!vhost_vq_get_backend(vq)) {
			vhost_vq_set_backend(vq, vsock);
			ret = vhost_vq_init_access(vq);
			if (ret)
				goto err_vq;
		}

		mutex_unlock(&vq->mutex);
	}

	/* Some packets may have been queued before the device was started,
	 * let's kick the send worker to send them.
	 */
	vhost_work_queue(&vsock->dev, &vsock->send_pkt_work);

	mutex_unlock(&vsock->dev.mutex);
	return 0;

err_vq:
	vhost_vq_set_backend(vq, NULL);
	mutex_unlock(&vq->mutex);

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		vq = &vsock->vqs[i];

		mutex_lock(&vq->mutex);
		vhost_vq_set_backend(vq, NULL);
		mutex_unlock(&vq->mutex);
	}
err:
	mutex_unlock(&vsock->dev.mutex);
	return ret;
}

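/* Called on VHOST_VSOCK_SET_RUNNING 0 and on release: detach the backend
 * from both virtqueues under vq->mutex so that any subsequent handler
 * invocation sees a NULL backend and returns without touching the rings.
 */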
static int vhost_vsock_stop(struct vhost_vsock *vsock)
{
	size_t i;
	int ret;

	mutex_lock(&vsock->dev.mutex);

	ret = vhost_dev_check_owner(&vsock->dev);
	if (ret)
		goto err;

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		struct vhost_virtqueue *vq = &vsock->vqs[i];

		mutex_lock(&vq->mutex);
		vhost_vq_set_backend(vq, NULL);
		mutex_unlock(&vq->mutex);
	}

err:
	mutex_unlock(&vsock->dev.mutex);
	return ret;
}

static void vhost_vsock_free(struct vhost_vsock *vsock)
{
	kvfree(vsock);
}

static int vhost_vsock_dev_open(struct inode *inode, struct file *file)
{
	struct vhost_virtqueue **vqs;
	struct vhost_vsock *vsock;
	int ret;

	/* This struct is large and allocation could fail, fall back to vmalloc
	 * if there is no other way.
	 */
	vsock = kvmalloc(sizeof(*vsock), GFP_KERNEL | __GFP_RETRY_MAYFAIL);
	if (!vsock)
		return -ENOMEM;

	vqs = kmalloc_array(ARRAY_SIZE(vsock->vqs), sizeof(*vqs), GFP_KERNEL);
	if (!vqs) {
		ret = -ENOMEM;
		goto out;
	}

	vsock->guest_cid = 0; /* no CID assigned yet */

	atomic_set(&vsock->queued_replies, 0);

	vqs[VSOCK_VQ_TX] = &vsock->vqs[VSOCK_VQ_TX];
	vqs[VSOCK_VQ_RX] = &vsock->vqs[VSOCK_VQ_RX];
	vsock->vqs[VSOCK_VQ_TX].handle_kick = vhost_vsock_handle_tx_kick;
	vsock->vqs[VSOCK_VQ_RX].handle_kick = vhost_vsock_handle_rx_kick;

	vhost_dev_init(&vsock->dev, vqs, ARRAY_SIZE(vsock->vqs),
		       UIO_MAXIOV, VHOST_VSOCK_PKT_WEIGHT,
		       VHOST_VSOCK_WEIGHT, NULL);

	file->private_data = vsock;
	spin_lock_init(&vsock->send_pkt_list_lock);
	INIT_LIST_HEAD(&vsock->send_pkt_list);
	vhost_work_init(&vsock->send_pkt_work, vhost_transport_send_pkt_work);
	return 0;

out:
	vhost_vsock_free(vsock);
	return ret;
}

static void vhost_vsock_flush(struct vhost_vsock *vsock)
{
	int i;

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++)
		if (vsock->vqs[i].handle_kick)
			vhost_poll_flush(&vsock->vqs[i].poll);
	vhost_work_flush(&vsock->dev, &vsock->send_pkt_work);
}

static void vhost_vsock_reset_orphans(struct sock *sk)
{
	struct vsock_sock *vsk = vsock_sk(sk);

	/* vmci_transport.c doesn't take sk_lock here either. At least we're
	 * under vsock_table_lock so the sock cannot disappear while we're
	 * executing.
	 */

	/* If the peer is still valid, no need to reset connection */
	if (vhost_vsock_get(vsk->remote_addr.svm_cid))
		return;

	/* If the close timeout is pending, let it expire. This avoids races
	 * with the timeout callback.
	 */
	if (vsk->close_work_scheduled)
		return;

	sock_set_flag(sk, SOCK_DONE);
	vsk->peer_shutdown = SHUTDOWN_MASK;
	sk->sk_state = SS_UNCONNECTED;
	sk->sk_err = ECONNRESET;
	sk->sk_error_report(sk);
}

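/* Teardown order matters here: unhash the instance and wait for an RCU grace
 * period so vhost_transport_send_pkt() and vhost_transport_cancel_pkt() can
 * no longer look it up, reset sockets orphaned by the disappearing CID, stop
 * the virtqueues and flush outstanding work, and only then free the packets
 * still sitting on send_pkt_list and the device itself.
 */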
static int vhost_vsock_dev_release(struct inode *inode, struct file *file)
{
	struct vhost_vsock *vsock = file->private_data;

	mutex_lock(&vhost_vsock_mutex);
	if (vsock->guest_cid)
		hash_del_rcu(&vsock->hash);
	mutex_unlock(&vhost_vsock_mutex);

	/* Wait for other CPUs to finish using vsock */
	synchronize_rcu();

	/* Iterating over all connections for all CIDs to find orphans is
	 * inefficient. Room for improvement here. */
	vsock_for_each_connected_socket(vhost_vsock_reset_orphans);

	vhost_vsock_stop(vsock);
	vhost_vsock_flush(vsock);
	vhost_dev_stop(&vsock->dev);

	spin_lock_bh(&vsock->send_pkt_list_lock);
	while (!list_empty(&vsock->send_pkt_list)) {
		struct virtio_vsock_pkt *pkt;

		pkt = list_first_entry(&vsock->send_pkt_list,
				       struct virtio_vsock_pkt, list);
		list_del_init(&pkt->list);
		virtio_transport_free_pkt(pkt);
	}
	spin_unlock_bh(&vsock->send_pkt_list_lock);

	vhost_dev_cleanup(&vsock->dev);
	kfree(vsock->dev.vqs);
	vhost_vsock_free(vsock);
	return 0;
}

static int vhost_vsock_set_cid(struct vhost_vsock *vsock, u64 guest_cid)
{
	struct vhost_vsock *other;

	/* Refuse reserved CIDs */
	if (guest_cid <= VMADDR_CID_HOST ||
	    guest_cid == U32_MAX)
		return -EINVAL;

	/* 64-bit CIDs are not yet supported */
	if (guest_cid > U32_MAX)
		return -EINVAL;

	/* Refuse if CID is assigned to the guest->host transport (i.e. nested
	 * VM), to make the loopback work.
	 */
	if (vsock_find_cid(guest_cid))
		return -EADDRINUSE;

	/* Refuse if CID is already in use */
	mutex_lock(&vhost_vsock_mutex);
	other = vhost_vsock_get(guest_cid);
	if (other && other != vsock) {
		mutex_unlock(&vhost_vsock_mutex);
		return -EADDRINUSE;
	}

	if (vsock->guest_cid)
		hash_del_rcu(&vsock->hash);

	vsock->guest_cid = guest_cid;
	hash_add_rcu(vhost_vsock_hash, &vsock->hash, vsock->guest_cid);
	mutex_unlock(&vhost_vsock_mutex);

	return 0;
}

static int vhost_vsock_set_features(struct vhost_vsock *vsock, u64 features)
{
	struct vhost_virtqueue *vq;
	int i;

	if (features & ~VHOST_VSOCK_FEATURES)
		return -EOPNOTSUPP;

	mutex_lock(&vsock->dev.mutex);
	if ((features & (1 << VHOST_F_LOG_ALL)) &&
	    !vhost_log_access_ok(&vsock->dev)) {
		mutex_unlock(&vsock->dev.mutex);
		return -EFAULT;
	}

	for (i = 0; i < ARRAY_SIZE(vsock->vqs); i++) {
		vq = &vsock->vqs[i];
		mutex_lock(&vq->mutex);
		vq->acked_features = features;
		mutex_unlock(&vq->mutex);
	}
	mutex_unlock(&vsock->dev.mutex);
	return 0;
}

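/* A rough sketch (not part of this driver) of how a VMM typically drives
 * /dev/vhost-vsock; the exact ordering and vring setup details vary by VMM:
 *
 *	fd = open("/dev/vhost-vsock", O_RDWR);
 *	ioctl(fd, VHOST_SET_OWNER, NULL);
 *	ioctl(fd, VHOST_GET_FEATURES, &features);
 *	ioctl(fd, VHOST_SET_FEATURES, &features);
 *	ioctl(fd, VHOST_SET_MEM_TABLE, mem);
 *	for each of the two virtqueues:
 *		ioctl(fd, VHOST_SET_VRING_NUM/ADDR/BASE/KICK/CALL, ...);
 *	ioctl(fd, VHOST_VSOCK_SET_GUEST_CID, &cid);	(u64, must be > 2)
 *	ioctl(fd, VHOST_VSOCK_SET_RUNNING, &one);	(int)
 */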
static long vhost_vsock_dev_ioctl(struct file *f, unsigned int ioctl,
				  unsigned long arg)
{
	struct vhost_vsock *vsock = f->private_data;
	void __user *argp = (void __user *)arg;
	u64 guest_cid;
	u64 features;
	int start;
	int r;

	switch (ioctl) {
	case VHOST_VSOCK_SET_GUEST_CID:
		if (copy_from_user(&guest_cid, argp, sizeof(guest_cid)))
			return -EFAULT;
		return vhost_vsock_set_cid(vsock, guest_cid);
	case VHOST_VSOCK_SET_RUNNING:
		if (copy_from_user(&start, argp, sizeof(start)))
			return -EFAULT;
		if (start)
			return vhost_vsock_start(vsock);
		else
			return vhost_vsock_stop(vsock);
	case VHOST_GET_FEATURES:
		features = VHOST_VSOCK_FEATURES;
		if (copy_to_user(argp, &features, sizeof(features)))
			return -EFAULT;
		return 0;
	case VHOST_SET_FEATURES:
		if (copy_from_user(&features, argp, sizeof(features)))
			return -EFAULT;
		return vhost_vsock_set_features(vsock, features);
	default:
		mutex_lock(&vsock->dev.mutex);
		r = vhost_dev_ioctl(&vsock->dev, ioctl, argp);
		if (r == -ENOIOCTLCMD)
			r = vhost_vring_ioctl(&vsock->dev, ioctl, argp);
		else
			vhost_vsock_flush(vsock);
		mutex_unlock(&vsock->dev.mutex);
		return r;
	}
}

static const struct file_operations vhost_vsock_fops = {
	.owner = THIS_MODULE,
	.open = vhost_vsock_dev_open,
	.release = vhost_vsock_dev_release,
	.llseek = noop_llseek,
	.unlocked_ioctl = vhost_vsock_dev_ioctl,
	.compat_ioctl = compat_ptr_ioctl,
};

static struct miscdevice vhost_vsock_misc = {
	.minor = VHOST_VSOCK_MINOR,
	.name = "vhost-vsock",
	.fops = &vhost_vsock_fops,
};

static int __init vhost_vsock_init(void)
{
	int ret;

	ret = vsock_core_register(&vhost_transport.transport,
				  VSOCK_TRANSPORT_F_H2G);
	if (ret < 0)
		return ret;
	return misc_register(&vhost_vsock_misc);
}

static void __exit vhost_vsock_exit(void)
{
	misc_deregister(&vhost_vsock_misc);
	vsock_core_unregister(&vhost_transport.transport);
}

module_init(vhost_vsock_init);
module_exit(vhost_vsock_exit);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("vhost transport for vsock");
MODULE_ALIAS_MISCDEV(VHOST_VSOCK_MINOR);
MODULE_ALIAS("devname:vhost-vsock");