1 // SPDX-License-Identifier: GPL-2.0-only 2 /* 3 * virtio transport for vsock 4 * 5 * Copyright (C) 2013-2015 Red Hat, Inc. 6 * Author: Asias He <asias@redhat.com> 7 * Stefan Hajnoczi <stefanha@redhat.com> 8 * 9 * Some of the code is take from Gerd Hoffmann <kraxel@redhat.com>'s 10 * early virtio-vsock proof-of-concept bits. 11 */ 12 #include <linux/spinlock.h> 13 #include <linux/module.h> 14 #include <linux/list.h> 15 #include <linux/atomic.h> 16 #include <linux/virtio.h> 17 #include <linux/virtio_ids.h> 18 #include <linux/virtio_config.h> 19 #include <linux/virtio_vsock.h> 20 #include <net/sock.h> 21 #include <linux/mutex.h> 22 #include <net/af_vsock.h> 23 24 static struct workqueue_struct *virtio_vsock_workqueue; 25 static struct virtio_vsock __rcu *the_virtio_vsock; 26 static DEFINE_MUTEX(the_virtio_vsock_mutex); /* protects the_virtio_vsock */ 27 static struct virtio_transport virtio_transport; /* forward declaration */ 28 29 struct virtio_vsock { 30 struct virtio_device *vdev; 31 struct virtqueue *vqs[VSOCK_VQ_MAX]; 32 33 /* Virtqueue processing is deferred to a workqueue */ 34 struct work_struct tx_work; 35 struct work_struct rx_work; 36 struct work_struct event_work; 37 38 /* The following fields are protected by tx_lock. vqs[VSOCK_VQ_TX] 39 * must be accessed with tx_lock held. 40 */ 41 struct mutex tx_lock; 42 bool tx_run; 43 44 struct work_struct send_pkt_work; 45 spinlock_t send_pkt_list_lock; 46 struct list_head send_pkt_list; 47 48 atomic_t queued_replies; 49 50 /* The following fields are protected by rx_lock. vqs[VSOCK_VQ_RX] 51 * must be accessed with rx_lock held. 52 */ 53 struct mutex rx_lock; 54 bool rx_run; 55 int rx_buf_nr; 56 int rx_buf_max_nr; 57 58 /* The following fields are protected by event_lock. 59 * vqs[VSOCK_VQ_EVENT] must be accessed with event_lock held. 60 */ 61 struct mutex event_lock; 62 bool event_run; 63 struct virtio_vsock_event event_list[8]; 64 65 u32 guest_cid; 66 bool seqpacket_allow; 67 }; 68 69 static u32 virtio_transport_get_local_cid(void) 70 { 71 struct virtio_vsock *vsock; 72 u32 ret; 73 74 rcu_read_lock(); 75 vsock = rcu_dereference(the_virtio_vsock); 76 if (!vsock) { 77 ret = VMADDR_CID_ANY; 78 goto out_rcu; 79 } 80 81 ret = vsock->guest_cid; 82 out_rcu: 83 rcu_read_unlock(); 84 return ret; 85 } 86 87 static void 88 virtio_transport_send_pkt_work(struct work_struct *work) 89 { 90 struct virtio_vsock *vsock = 91 container_of(work, struct virtio_vsock, send_pkt_work); 92 struct virtqueue *vq; 93 bool added = false; 94 bool restart_rx = false; 95 96 mutex_lock(&vsock->tx_lock); 97 98 if (!vsock->tx_run) 99 goto out; 100 101 vq = vsock->vqs[VSOCK_VQ_TX]; 102 103 for (;;) { 104 struct virtio_vsock_pkt *pkt; 105 struct scatterlist hdr, buf, *sgs[2]; 106 int ret, in_sg = 0, out_sg = 0; 107 bool reply; 108 109 spin_lock_bh(&vsock->send_pkt_list_lock); 110 if (list_empty(&vsock->send_pkt_list)) { 111 spin_unlock_bh(&vsock->send_pkt_list_lock); 112 break; 113 } 114 115 pkt = list_first_entry(&vsock->send_pkt_list, 116 struct virtio_vsock_pkt, list); 117 list_del_init(&pkt->list); 118 spin_unlock_bh(&vsock->send_pkt_list_lock); 119 120 virtio_transport_deliver_tap_pkt(pkt); 121 122 reply = pkt->reply; 123 124 sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr)); 125 sgs[out_sg++] = &hdr; 126 if (pkt->buf) { 127 sg_init_one(&buf, pkt->buf, pkt->len); 128 sgs[out_sg++] = &buf; 129 } 130 131 ret = virtqueue_add_sgs(vq, sgs, out_sg, in_sg, pkt, GFP_KERNEL); 132 /* Usually this means that there is no more space available in 133 * the vq 134 */ 135 if (ret < 0) { 136 spin_lock_bh(&vsock->send_pkt_list_lock); 137 list_add(&pkt->list, &vsock->send_pkt_list); 138 spin_unlock_bh(&vsock->send_pkt_list_lock); 139 break; 140 } 141 142 if (reply) { 143 struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX]; 144 int val; 145 146 val = atomic_dec_return(&vsock->queued_replies); 147 148 /* Do we now have resources to resume rx processing? */ 149 if (val + 1 == virtqueue_get_vring_size(rx_vq)) 150 restart_rx = true; 151 } 152 153 added = true; 154 } 155 156 if (added) 157 virtqueue_kick(vq); 158 159 out: 160 mutex_unlock(&vsock->tx_lock); 161 162 if (restart_rx) 163 queue_work(virtio_vsock_workqueue, &vsock->rx_work); 164 } 165 166 static int 167 virtio_transport_send_pkt(struct virtio_vsock_pkt *pkt) 168 { 169 struct virtio_vsock *vsock; 170 int len = pkt->len; 171 172 rcu_read_lock(); 173 vsock = rcu_dereference(the_virtio_vsock); 174 if (!vsock) { 175 virtio_transport_free_pkt(pkt); 176 len = -ENODEV; 177 goto out_rcu; 178 } 179 180 if (le64_to_cpu(pkt->hdr.dst_cid) == vsock->guest_cid) { 181 virtio_transport_free_pkt(pkt); 182 len = -ENODEV; 183 goto out_rcu; 184 } 185 186 if (pkt->reply) 187 atomic_inc(&vsock->queued_replies); 188 189 spin_lock_bh(&vsock->send_pkt_list_lock); 190 list_add_tail(&pkt->list, &vsock->send_pkt_list); 191 spin_unlock_bh(&vsock->send_pkt_list_lock); 192 193 queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work); 194 195 out_rcu: 196 rcu_read_unlock(); 197 return len; 198 } 199 200 static int 201 virtio_transport_cancel_pkt(struct vsock_sock *vsk) 202 { 203 struct virtio_vsock *vsock; 204 struct virtio_vsock_pkt *pkt, *n; 205 int cnt = 0, ret; 206 LIST_HEAD(freeme); 207 208 rcu_read_lock(); 209 vsock = rcu_dereference(the_virtio_vsock); 210 if (!vsock) { 211 ret = -ENODEV; 212 goto out_rcu; 213 } 214 215 spin_lock_bh(&vsock->send_pkt_list_lock); 216 list_for_each_entry_safe(pkt, n, &vsock->send_pkt_list, list) { 217 if (pkt->vsk != vsk) 218 continue; 219 list_move(&pkt->list, &freeme); 220 } 221 spin_unlock_bh(&vsock->send_pkt_list_lock); 222 223 list_for_each_entry_safe(pkt, n, &freeme, list) { 224 if (pkt->reply) 225 cnt++; 226 list_del(&pkt->list); 227 virtio_transport_free_pkt(pkt); 228 } 229 230 if (cnt) { 231 struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX]; 232 int new_cnt; 233 234 new_cnt = atomic_sub_return(cnt, &vsock->queued_replies); 235 if (new_cnt + cnt >= virtqueue_get_vring_size(rx_vq) && 236 new_cnt < virtqueue_get_vring_size(rx_vq)) 237 queue_work(virtio_vsock_workqueue, &vsock->rx_work); 238 } 239 240 ret = 0; 241 242 out_rcu: 243 rcu_read_unlock(); 244 return ret; 245 } 246 247 static void virtio_vsock_rx_fill(struct virtio_vsock *vsock) 248 { 249 int buf_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE; 250 struct virtio_vsock_pkt *pkt; 251 struct scatterlist hdr, buf, *sgs[2]; 252 struct virtqueue *vq; 253 int ret; 254 255 vq = vsock->vqs[VSOCK_VQ_RX]; 256 257 do { 258 pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); 259 if (!pkt) 260 break; 261 262 pkt->buf = kmalloc(buf_len, GFP_KERNEL); 263 if (!pkt->buf) { 264 virtio_transport_free_pkt(pkt); 265 break; 266 } 267 268 pkt->buf_len = buf_len; 269 pkt->len = buf_len; 270 271 sg_init_one(&hdr, &pkt->hdr, sizeof(pkt->hdr)); 272 sgs[0] = &hdr; 273 274 sg_init_one(&buf, pkt->buf, buf_len); 275 sgs[1] = &buf; 276 ret = virtqueue_add_sgs(vq, sgs, 0, 2, pkt, GFP_KERNEL); 277 if (ret) { 278 virtio_transport_free_pkt(pkt); 279 break; 280 } 281 vsock->rx_buf_nr++; 282 } while (vq->num_free); 283 if (vsock->rx_buf_nr > vsock->rx_buf_max_nr) 284 vsock->rx_buf_max_nr = vsock->rx_buf_nr; 285 virtqueue_kick(vq); 286 } 287 288 static void virtio_transport_tx_work(struct work_struct *work) 289 { 290 struct virtio_vsock *vsock = 291 container_of(work, struct virtio_vsock, tx_work); 292 struct virtqueue *vq; 293 bool added = false; 294 295 vq = vsock->vqs[VSOCK_VQ_TX]; 296 mutex_lock(&vsock->tx_lock); 297 298 if (!vsock->tx_run) 299 goto out; 300 301 do { 302 struct virtio_vsock_pkt *pkt; 303 unsigned int len; 304 305 virtqueue_disable_cb(vq); 306 while ((pkt = virtqueue_get_buf(vq, &len)) != NULL) { 307 virtio_transport_free_pkt(pkt); 308 added = true; 309 } 310 } while (!virtqueue_enable_cb(vq)); 311 312 out: 313 mutex_unlock(&vsock->tx_lock); 314 315 if (added) 316 queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work); 317 } 318 319 /* Is there space left for replies to rx packets? */ 320 static bool virtio_transport_more_replies(struct virtio_vsock *vsock) 321 { 322 struct virtqueue *vq = vsock->vqs[VSOCK_VQ_RX]; 323 int val; 324 325 smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */ 326 val = atomic_read(&vsock->queued_replies); 327 328 return val < virtqueue_get_vring_size(vq); 329 } 330 331 /* event_lock must be held */ 332 static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock, 333 struct virtio_vsock_event *event) 334 { 335 struct scatterlist sg; 336 struct virtqueue *vq; 337 338 vq = vsock->vqs[VSOCK_VQ_EVENT]; 339 340 sg_init_one(&sg, event, sizeof(*event)); 341 342 return virtqueue_add_inbuf(vq, &sg, 1, event, GFP_KERNEL); 343 } 344 345 /* event_lock must be held */ 346 static void virtio_vsock_event_fill(struct virtio_vsock *vsock) 347 { 348 size_t i; 349 350 for (i = 0; i < ARRAY_SIZE(vsock->event_list); i++) { 351 struct virtio_vsock_event *event = &vsock->event_list[i]; 352 353 virtio_vsock_event_fill_one(vsock, event); 354 } 355 356 virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]); 357 } 358 359 static void virtio_vsock_reset_sock(struct sock *sk) 360 { 361 /* vmci_transport.c doesn't take sk_lock here either. At least we're 362 * under vsock_table_lock so the sock cannot disappear while we're 363 * executing. 364 */ 365 366 sk->sk_state = TCP_CLOSE; 367 sk->sk_err = ECONNRESET; 368 sk_error_report(sk); 369 } 370 371 static void virtio_vsock_update_guest_cid(struct virtio_vsock *vsock) 372 { 373 struct virtio_device *vdev = vsock->vdev; 374 __le64 guest_cid; 375 376 vdev->config->get(vdev, offsetof(struct virtio_vsock_config, guest_cid), 377 &guest_cid, sizeof(guest_cid)); 378 vsock->guest_cid = le64_to_cpu(guest_cid); 379 } 380 381 /* event_lock must be held */ 382 static void virtio_vsock_event_handle(struct virtio_vsock *vsock, 383 struct virtio_vsock_event *event) 384 { 385 switch (le32_to_cpu(event->id)) { 386 case VIRTIO_VSOCK_EVENT_TRANSPORT_RESET: 387 virtio_vsock_update_guest_cid(vsock); 388 vsock_for_each_connected_socket(&virtio_transport.transport, 389 virtio_vsock_reset_sock); 390 break; 391 } 392 } 393 394 static void virtio_transport_event_work(struct work_struct *work) 395 { 396 struct virtio_vsock *vsock = 397 container_of(work, struct virtio_vsock, event_work); 398 struct virtqueue *vq; 399 400 vq = vsock->vqs[VSOCK_VQ_EVENT]; 401 402 mutex_lock(&vsock->event_lock); 403 404 if (!vsock->event_run) 405 goto out; 406 407 do { 408 struct virtio_vsock_event *event; 409 unsigned int len; 410 411 virtqueue_disable_cb(vq); 412 while ((event = virtqueue_get_buf(vq, &len)) != NULL) { 413 if (len == sizeof(*event)) 414 virtio_vsock_event_handle(vsock, event); 415 416 virtio_vsock_event_fill_one(vsock, event); 417 } 418 } while (!virtqueue_enable_cb(vq)); 419 420 virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]); 421 out: 422 mutex_unlock(&vsock->event_lock); 423 } 424 425 static void virtio_vsock_event_done(struct virtqueue *vq) 426 { 427 struct virtio_vsock *vsock = vq->vdev->priv; 428 429 if (!vsock) 430 return; 431 queue_work(virtio_vsock_workqueue, &vsock->event_work); 432 } 433 434 static void virtio_vsock_tx_done(struct virtqueue *vq) 435 { 436 struct virtio_vsock *vsock = vq->vdev->priv; 437 438 if (!vsock) 439 return; 440 queue_work(virtio_vsock_workqueue, &vsock->tx_work); 441 } 442 443 static void virtio_vsock_rx_done(struct virtqueue *vq) 444 { 445 struct virtio_vsock *vsock = vq->vdev->priv; 446 447 if (!vsock) 448 return; 449 queue_work(virtio_vsock_workqueue, &vsock->rx_work); 450 } 451 452 static bool virtio_transport_seqpacket_allow(u32 remote_cid); 453 454 static struct virtio_transport virtio_transport = { 455 .transport = { 456 .module = THIS_MODULE, 457 458 .get_local_cid = virtio_transport_get_local_cid, 459 460 .init = virtio_transport_do_socket_init, 461 .destruct = virtio_transport_destruct, 462 .release = virtio_transport_release, 463 .connect = virtio_transport_connect, 464 .shutdown = virtio_transport_shutdown, 465 .cancel_pkt = virtio_transport_cancel_pkt, 466 467 .dgram_bind = virtio_transport_dgram_bind, 468 .dgram_dequeue = virtio_transport_dgram_dequeue, 469 .dgram_enqueue = virtio_transport_dgram_enqueue, 470 .dgram_allow = virtio_transport_dgram_allow, 471 472 .stream_dequeue = virtio_transport_stream_dequeue, 473 .stream_enqueue = virtio_transport_stream_enqueue, 474 .stream_has_data = virtio_transport_stream_has_data, 475 .stream_has_space = virtio_transport_stream_has_space, 476 .stream_rcvhiwat = virtio_transport_stream_rcvhiwat, 477 .stream_is_active = virtio_transport_stream_is_active, 478 .stream_allow = virtio_transport_stream_allow, 479 480 .seqpacket_dequeue = virtio_transport_seqpacket_dequeue, 481 .seqpacket_enqueue = virtio_transport_seqpacket_enqueue, 482 .seqpacket_allow = virtio_transport_seqpacket_allow, 483 .seqpacket_has_data = virtio_transport_seqpacket_has_data, 484 485 .notify_poll_in = virtio_transport_notify_poll_in, 486 .notify_poll_out = virtio_transport_notify_poll_out, 487 .notify_recv_init = virtio_transport_notify_recv_init, 488 .notify_recv_pre_block = virtio_transport_notify_recv_pre_block, 489 .notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue, 490 .notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue, 491 .notify_send_init = virtio_transport_notify_send_init, 492 .notify_send_pre_block = virtio_transport_notify_send_pre_block, 493 .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue, 494 .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue, 495 .notify_buffer_size = virtio_transport_notify_buffer_size, 496 }, 497 498 .send_pkt = virtio_transport_send_pkt, 499 }; 500 501 static bool virtio_transport_seqpacket_allow(u32 remote_cid) 502 { 503 struct virtio_vsock *vsock; 504 bool seqpacket_allow; 505 506 seqpacket_allow = false; 507 rcu_read_lock(); 508 vsock = rcu_dereference(the_virtio_vsock); 509 if (vsock) 510 seqpacket_allow = vsock->seqpacket_allow; 511 rcu_read_unlock(); 512 513 return seqpacket_allow; 514 } 515 516 static void virtio_transport_rx_work(struct work_struct *work) 517 { 518 struct virtio_vsock *vsock = 519 container_of(work, struct virtio_vsock, rx_work); 520 struct virtqueue *vq; 521 522 vq = vsock->vqs[VSOCK_VQ_RX]; 523 524 mutex_lock(&vsock->rx_lock); 525 526 if (!vsock->rx_run) 527 goto out; 528 529 do { 530 virtqueue_disable_cb(vq); 531 for (;;) { 532 struct virtio_vsock_pkt *pkt; 533 unsigned int len; 534 535 if (!virtio_transport_more_replies(vsock)) { 536 /* Stop rx until the device processes already 537 * pending replies. Leave rx virtqueue 538 * callbacks disabled. 539 */ 540 goto out; 541 } 542 543 pkt = virtqueue_get_buf(vq, &len); 544 if (!pkt) { 545 break; 546 } 547 548 vsock->rx_buf_nr--; 549 550 /* Drop short/long packets */ 551 if (unlikely(len < sizeof(pkt->hdr) || 552 len > sizeof(pkt->hdr) + pkt->len)) { 553 virtio_transport_free_pkt(pkt); 554 continue; 555 } 556 557 pkt->len = len - sizeof(pkt->hdr); 558 virtio_transport_deliver_tap_pkt(pkt); 559 virtio_transport_recv_pkt(&virtio_transport, pkt); 560 } 561 } while (!virtqueue_enable_cb(vq)); 562 563 out: 564 if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2) 565 virtio_vsock_rx_fill(vsock); 566 mutex_unlock(&vsock->rx_lock); 567 } 568 569 static int virtio_vsock_vqs_init(struct virtio_vsock *vsock) 570 { 571 struct virtio_device *vdev = vsock->vdev; 572 static const char * const names[] = { 573 "rx", 574 "tx", 575 "event", 576 }; 577 vq_callback_t *callbacks[] = { 578 virtio_vsock_rx_done, 579 virtio_vsock_tx_done, 580 virtio_vsock_event_done, 581 }; 582 int ret; 583 584 ret = virtio_find_vqs(vdev, VSOCK_VQ_MAX, vsock->vqs, callbacks, names, 585 NULL); 586 if (ret < 0) 587 return ret; 588 589 virtio_vsock_update_guest_cid(vsock); 590 591 virtio_device_ready(vdev); 592 593 mutex_lock(&vsock->tx_lock); 594 vsock->tx_run = true; 595 mutex_unlock(&vsock->tx_lock); 596 597 mutex_lock(&vsock->rx_lock); 598 virtio_vsock_rx_fill(vsock); 599 vsock->rx_run = true; 600 mutex_unlock(&vsock->rx_lock); 601 602 mutex_lock(&vsock->event_lock); 603 virtio_vsock_event_fill(vsock); 604 vsock->event_run = true; 605 mutex_unlock(&vsock->event_lock); 606 607 return 0; 608 } 609 610 static void virtio_vsock_vqs_del(struct virtio_vsock *vsock) 611 { 612 struct virtio_device *vdev = vsock->vdev; 613 struct virtio_vsock_pkt *pkt; 614 615 /* Reset all connected sockets when the VQs disappear */ 616 vsock_for_each_connected_socket(&virtio_transport.transport, 617 virtio_vsock_reset_sock); 618 619 /* Stop all work handlers to make sure no one is accessing the device, 620 * so we can safely call virtio_reset_device(). 621 */ 622 mutex_lock(&vsock->rx_lock); 623 vsock->rx_run = false; 624 mutex_unlock(&vsock->rx_lock); 625 626 mutex_lock(&vsock->tx_lock); 627 vsock->tx_run = false; 628 mutex_unlock(&vsock->tx_lock); 629 630 mutex_lock(&vsock->event_lock); 631 vsock->event_run = false; 632 mutex_unlock(&vsock->event_lock); 633 634 /* Flush all device writes and interrupts, device will not use any 635 * more buffers. 636 */ 637 virtio_reset_device(vdev); 638 639 mutex_lock(&vsock->rx_lock); 640 while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_RX]))) 641 virtio_transport_free_pkt(pkt); 642 mutex_unlock(&vsock->rx_lock); 643 644 mutex_lock(&vsock->tx_lock); 645 while ((pkt = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_TX]))) 646 virtio_transport_free_pkt(pkt); 647 mutex_unlock(&vsock->tx_lock); 648 649 spin_lock_bh(&vsock->send_pkt_list_lock); 650 while (!list_empty(&vsock->send_pkt_list)) { 651 pkt = list_first_entry(&vsock->send_pkt_list, 652 struct virtio_vsock_pkt, list); 653 list_del(&pkt->list); 654 virtio_transport_free_pkt(pkt); 655 } 656 spin_unlock_bh(&vsock->send_pkt_list_lock); 657 658 /* Delete virtqueues and flush outstanding callbacks if any */ 659 vdev->config->del_vqs(vdev); 660 } 661 662 static int virtio_vsock_probe(struct virtio_device *vdev) 663 { 664 struct virtio_vsock *vsock = NULL; 665 int ret; 666 667 ret = mutex_lock_interruptible(&the_virtio_vsock_mutex); 668 if (ret) 669 return ret; 670 671 /* Only one virtio-vsock device per guest is supported */ 672 if (rcu_dereference_protected(the_virtio_vsock, 673 lockdep_is_held(&the_virtio_vsock_mutex))) { 674 ret = -EBUSY; 675 goto out; 676 } 677 678 vsock = kzalloc(sizeof(*vsock), GFP_KERNEL); 679 if (!vsock) { 680 ret = -ENOMEM; 681 goto out; 682 } 683 684 vsock->vdev = vdev; 685 686 vsock->rx_buf_nr = 0; 687 vsock->rx_buf_max_nr = 0; 688 atomic_set(&vsock->queued_replies, 0); 689 690 mutex_init(&vsock->tx_lock); 691 mutex_init(&vsock->rx_lock); 692 mutex_init(&vsock->event_lock); 693 spin_lock_init(&vsock->send_pkt_list_lock); 694 INIT_LIST_HEAD(&vsock->send_pkt_list); 695 INIT_WORK(&vsock->rx_work, virtio_transport_rx_work); 696 INIT_WORK(&vsock->tx_work, virtio_transport_tx_work); 697 INIT_WORK(&vsock->event_work, virtio_transport_event_work); 698 INIT_WORK(&vsock->send_pkt_work, virtio_transport_send_pkt_work); 699 700 if (virtio_has_feature(vdev, VIRTIO_VSOCK_F_SEQPACKET)) 701 vsock->seqpacket_allow = true; 702 703 vdev->priv = vsock; 704 705 ret = virtio_vsock_vqs_init(vsock); 706 if (ret < 0) 707 goto out; 708 709 rcu_assign_pointer(the_virtio_vsock, vsock); 710 711 mutex_unlock(&the_virtio_vsock_mutex); 712 713 return 0; 714 715 out: 716 kfree(vsock); 717 mutex_unlock(&the_virtio_vsock_mutex); 718 return ret; 719 } 720 721 static void virtio_vsock_remove(struct virtio_device *vdev) 722 { 723 struct virtio_vsock *vsock = vdev->priv; 724 725 mutex_lock(&the_virtio_vsock_mutex); 726 727 vdev->priv = NULL; 728 rcu_assign_pointer(the_virtio_vsock, NULL); 729 synchronize_rcu(); 730 731 virtio_vsock_vqs_del(vsock); 732 733 /* Other works can be queued before 'config->del_vqs()', so we flush 734 * all works before to free the vsock object to avoid use after free. 735 */ 736 flush_work(&vsock->rx_work); 737 flush_work(&vsock->tx_work); 738 flush_work(&vsock->event_work); 739 flush_work(&vsock->send_pkt_work); 740 741 mutex_unlock(&the_virtio_vsock_mutex); 742 743 kfree(vsock); 744 } 745 746 #ifdef CONFIG_PM_SLEEP 747 static int virtio_vsock_freeze(struct virtio_device *vdev) 748 { 749 struct virtio_vsock *vsock = vdev->priv; 750 751 mutex_lock(&the_virtio_vsock_mutex); 752 753 rcu_assign_pointer(the_virtio_vsock, NULL); 754 synchronize_rcu(); 755 756 virtio_vsock_vqs_del(vsock); 757 758 mutex_unlock(&the_virtio_vsock_mutex); 759 760 return 0; 761 } 762 763 static int virtio_vsock_restore(struct virtio_device *vdev) 764 { 765 struct virtio_vsock *vsock = vdev->priv; 766 int ret; 767 768 mutex_lock(&the_virtio_vsock_mutex); 769 770 /* Only one virtio-vsock device per guest is supported */ 771 if (rcu_dereference_protected(the_virtio_vsock, 772 lockdep_is_held(&the_virtio_vsock_mutex))) { 773 ret = -EBUSY; 774 goto out; 775 } 776 777 ret = virtio_vsock_vqs_init(vsock); 778 if (ret < 0) 779 goto out; 780 781 rcu_assign_pointer(the_virtio_vsock, vsock); 782 783 out: 784 mutex_unlock(&the_virtio_vsock_mutex); 785 return ret; 786 } 787 #endif /* CONFIG_PM_SLEEP */ 788 789 static struct virtio_device_id id_table[] = { 790 { VIRTIO_ID_VSOCK, VIRTIO_DEV_ANY_ID }, 791 { 0 }, 792 }; 793 794 static unsigned int features[] = { 795 VIRTIO_VSOCK_F_SEQPACKET 796 }; 797 798 static struct virtio_driver virtio_vsock_driver = { 799 .feature_table = features, 800 .feature_table_size = ARRAY_SIZE(features), 801 .driver.name = KBUILD_MODNAME, 802 .driver.owner = THIS_MODULE, 803 .id_table = id_table, 804 .probe = virtio_vsock_probe, 805 .remove = virtio_vsock_remove, 806 #ifdef CONFIG_PM_SLEEP 807 .freeze = virtio_vsock_freeze, 808 .restore = virtio_vsock_restore, 809 #endif 810 }; 811 812 static int __init virtio_vsock_init(void) 813 { 814 int ret; 815 816 virtio_vsock_workqueue = alloc_workqueue("virtio_vsock", 0, 0); 817 if (!virtio_vsock_workqueue) 818 return -ENOMEM; 819 820 ret = vsock_core_register(&virtio_transport.transport, 821 VSOCK_TRANSPORT_F_G2H); 822 if (ret) 823 goto out_wq; 824 825 ret = register_virtio_driver(&virtio_vsock_driver); 826 if (ret) 827 goto out_vci; 828 829 return 0; 830 831 out_vci: 832 vsock_core_unregister(&virtio_transport.transport); 833 out_wq: 834 destroy_workqueue(virtio_vsock_workqueue); 835 return ret; 836 } 837 838 static void __exit virtio_vsock_exit(void) 839 { 840 unregister_virtio_driver(&virtio_vsock_driver); 841 vsock_core_unregister(&virtio_transport.transport); 842 destroy_workqueue(virtio_vsock_workqueue); 843 } 844 845 module_init(virtio_vsock_init); 846 module_exit(virtio_vsock_exit); 847 MODULE_LICENSE("GPL v2"); 848 MODULE_AUTHOR("Asias He"); 849 MODULE_DESCRIPTION("virtio transport for vsock"); 850 MODULE_DEVICE_TABLE(virtio, id_table); 851