// SPDX-License-Identifier: GPL-2.0-only
/*
 * virtio transport for vsock
 *
 * Copyright (C) 2013-2015 Red Hat, Inc.
 * Author: Asias He <asias@redhat.com>
 *         Stefan Hajnoczi <stefanha@redhat.com>
 *
 * Some of the code is taken from Gerd Hoffmann <kraxel@redhat.com>'s
 * early virtio-vsock proof-of-concept bits.
 */
#include <linux/spinlock.h>
#include <linux/module.h>
#include <linux/list.h>
#include <linux/atomic.h>
#include <linux/virtio.h>
#include <linux/virtio_ids.h>
#include <linux/virtio_config.h>
#include <linux/virtio_vsock.h>
#include <net/sock.h>
#include <linux/mutex.h>
#include <net/af_vsock.h>

static struct workqueue_struct *virtio_vsock_workqueue;
static struct virtio_vsock __rcu *the_virtio_vsock;
static DEFINE_MUTEX(the_virtio_vsock_mutex); /* protects the_virtio_vsock */
static struct virtio_transport virtio_transport; /* forward declaration */

struct virtio_vsock {
	struct virtio_device *vdev;
	struct virtqueue *vqs[VSOCK_VQ_MAX];

	/* Virtqueue processing is deferred to a workqueue */
	struct work_struct tx_work;
	struct work_struct rx_work;
	struct work_struct event_work;

	/* The following fields are protected by tx_lock. vqs[VSOCK_VQ_TX]
	 * must be accessed with tx_lock held.
	 */
	struct mutex tx_lock;
	bool tx_run;

	struct work_struct send_pkt_work;
	struct sk_buff_head send_pkt_queue;

	atomic_t queued_replies;

	/* The following fields are protected by rx_lock. vqs[VSOCK_VQ_RX]
	 * must be accessed with rx_lock held.
	 */
	struct mutex rx_lock;
	bool rx_run;
	int rx_buf_nr;
	int rx_buf_max_nr;

	/* The following fields are protected by event_lock.
	 * vqs[VSOCK_VQ_EVENT] must be accessed with event_lock held.
	 */
	struct mutex event_lock;
	bool event_run;
	struct virtio_vsock_event event_list[8];

	u32 guest_cid;
	bool seqpacket_allow;
};

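/* Return the CID assigned to this guest, or VMADDR_CID_ANY if no virtio-vsock
 * device is currently bound.
 */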
static u32 virtio_transport_get_local_cid(void)
{
	struct virtio_vsock *vsock;
	u32 ret;

	rcu_read_lock();
	vsock = rcu_dereference(the_virtio_vsock);
	if (!vsock) {
		ret = VMADDR_CID_ANY;
		goto out_rcu;
	}

	ret = vsock->guest_cid;
out_rcu:
	rcu_read_unlock();
	return ret;
}

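/* Worker that moves packets from send_pkt_queue into the TX virtqueue.  It
 * stops as soon as the virtqueue has no more room and is re-queued by the
 * TX done worker once the device has consumed some buffers.
 */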
static void
virtio_transport_send_pkt_work(struct work_struct *work)
{
	struct virtio_vsock *vsock =
		container_of(work, struct virtio_vsock, send_pkt_work);
	struct virtqueue *vq;
	bool added = false;
	bool restart_rx = false;

	mutex_lock(&vsock->tx_lock);

	if (!vsock->tx_run)
		goto out;

	vq = vsock->vqs[VSOCK_VQ_TX];

	for (;;) {
		struct scatterlist hdr, buf, *sgs[2];
		int ret, in_sg = 0, out_sg = 0;
		struct sk_buff *skb;
		bool reply;

		skb = virtio_vsock_skb_dequeue(&vsock->send_pkt_queue);
		if (!skb)
			break;

		reply = virtio_vsock_skb_reply(skb);

		sg_init_one(&hdr, virtio_vsock_hdr(skb), sizeof(*virtio_vsock_hdr(skb)));
		sgs[out_sg++] = &hdr;
		if (skb->len > 0) {
			sg_init_one(&buf, skb->data, skb->len);
			sgs[out_sg++] = &buf;
		}

		ret = virtqueue_add_sgs(vq, sgs, out_sg, in_sg, skb, GFP_KERNEL);
		/* Usually this means that there is no more space available in
		 * the vq
		 */
		if (ret < 0) {
			virtio_vsock_skb_queue_head(&vsock->send_pkt_queue, skb);
			break;
		}

		virtio_transport_deliver_tap_pkt(skb);

		if (reply) {
			struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX];
			int val;

			val = atomic_dec_return(&vsock->queued_replies);

			/* Do we now have resources to resume rx processing? */
			if (val + 1 == virtqueue_get_vring_size(rx_vq))
				restart_rx = true;
		}

		added = true;
	}

	if (added)
		virtqueue_kick(vq);

out:
	mutex_unlock(&vsock->tx_lock);

	if (restart_rx)
		queue_work(virtio_vsock_workqueue, &vsock->rx_work);
}

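/* Queue a packet for transmission and kick the send worker.  Returns the
 * number of bytes queued, or -ENODEV if no device is bound or the packet
 * is addressed to our own CID.
 */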
static int
virtio_transport_send_pkt(struct sk_buff *skb)
{
	struct virtio_vsock_hdr *hdr;
	struct virtio_vsock *vsock;
	int len = skb->len;

	hdr = virtio_vsock_hdr(skb);

	rcu_read_lock();
	vsock = rcu_dereference(the_virtio_vsock);
	if (!vsock) {
		kfree_skb(skb);
		len = -ENODEV;
		goto out_rcu;
	}

	if (le64_to_cpu(hdr->dst_cid) == vsock->guest_cid) {
		kfree_skb(skb);
		len = -ENODEV;
		goto out_rcu;
	}

	if (virtio_vsock_skb_reply(skb))
		atomic_inc(&vsock->queued_replies);

	virtio_vsock_skb_queue_tail(&vsock->send_pkt_queue, skb);
	queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);

out_rcu:
	rcu_read_unlock();
	return len;
}

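/* Drop all packets queued for @vsk from send_pkt_queue.  If dropping replies
 * frees up the resources the RX path was waiting for, restart RX processing.
 */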
static int
virtio_transport_cancel_pkt(struct vsock_sock *vsk)
{
	struct virtio_vsock *vsock;
	int cnt = 0, ret;

	rcu_read_lock();
	vsock = rcu_dereference(the_virtio_vsock);
	if (!vsock) {
		ret = -ENODEV;
		goto out_rcu;
	}

	cnt = virtio_transport_purge_skbs(vsk, &vsock->send_pkt_queue);

	if (cnt) {
		struct virtqueue *rx_vq = vsock->vqs[VSOCK_VQ_RX];
		int new_cnt;

		new_cnt = atomic_sub_return(cnt, &vsock->queued_replies);
		if (new_cnt + cnt >= virtqueue_get_vring_size(rx_vq) &&
		    new_cnt < virtqueue_get_vring_size(rx_vq))
			queue_work(virtio_vsock_workqueue, &vsock->rx_work);
	}

	ret = 0;

out_rcu:
	rcu_read_unlock();
	return ret;
}

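/* Fill the RX virtqueue with freshly allocated buffers until it is full or
 * allocation fails.  Called with rx_lock held.
 */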
static void virtio_vsock_rx_fill(struct virtio_vsock *vsock)
{
	int total_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE + VIRTIO_VSOCK_SKB_HEADROOM;
	struct scatterlist pkt, *p;
	struct virtqueue *vq;
	struct sk_buff *skb;
	int ret;

	vq = vsock->vqs[VSOCK_VQ_RX];

	do {
		skb = virtio_vsock_alloc_skb(total_len, GFP_KERNEL);
		if (!skb)
			break;

		memset(skb->head, 0, VIRTIO_VSOCK_SKB_HEADROOM);
		sg_init_one(&pkt, virtio_vsock_hdr(skb), total_len);
		p = &pkt;
		ret = virtqueue_add_sgs(vq, &p, 0, 1, skb, GFP_KERNEL);
		if (ret < 0) {
			kfree_skb(skb);
			break;
		}

		vsock->rx_buf_nr++;
	} while (vq->num_free);
	if (vsock->rx_buf_nr > vsock->rx_buf_max_nr)
		vsock->rx_buf_max_nr = vsock->rx_buf_nr;
	virtqueue_kick(vq);
}

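/* TX done worker: reclaim buffers the device has finished with and, if any
 * were reclaimed, re-queue the send worker since the virtqueue now has room.
 */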
static void virtio_transport_tx_work(struct work_struct *work)
{
	struct virtio_vsock *vsock =
		container_of(work, struct virtio_vsock, tx_work);
	struct virtqueue *vq;
	bool added = false;

	vq = vsock->vqs[VSOCK_VQ_TX];
	mutex_lock(&vsock->tx_lock);

	if (!vsock->tx_run)
		goto out;

	do {
		struct sk_buff *skb;
		unsigned int len;

		virtqueue_disable_cb(vq);
		while ((skb = virtqueue_get_buf(vq, &len)) != NULL) {
			consume_skb(skb);
			added = true;
		}
	} while (!virtqueue_enable_cb(vq));

out:
	mutex_unlock(&vsock->tx_lock);

	if (added)
		queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
}

/* Is there space left for replies to rx packets? */
static bool virtio_transport_more_replies(struct virtio_vsock *vsock)
{
	struct virtqueue *vq = vsock->vqs[VSOCK_VQ_RX];
	int val;

	smp_rmb(); /* paired with atomic_inc() and atomic_dec_return() */
	val = atomic_read(&vsock->queued_replies);

	return val < virtqueue_get_vring_size(vq);
}

/* event_lock must be held */
static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock,
				       struct virtio_vsock_event *event)
{
	struct scatterlist sg;
	struct virtqueue *vq;

	vq = vsock->vqs[VSOCK_VQ_EVENT];

	sg_init_one(&sg, event, sizeof(*event));

	return virtqueue_add_inbuf(vq, &sg, 1, event, GFP_KERNEL);
}

/* event_lock must be held */
static void virtio_vsock_event_fill(struct virtio_vsock *vsock)
{
	size_t i;

	for (i = 0; i < ARRAY_SIZE(vsock->event_list); i++) {
		struct virtio_vsock_event *event = &vsock->event_list[i];

		virtio_vsock_event_fill_one(vsock, event);
	}

	virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);
}

static void virtio_vsock_reset_sock(struct sock *sk)
{
	/* vmci_transport.c doesn't take sk_lock here either. At least we're
	 * under vsock_table_lock so the sock cannot disappear while we're
	 * executing.
	 */

	sk->sk_state = TCP_CLOSE;
	sk->sk_err = ECONNRESET;
	sk_error_report(sk);
}

static void virtio_vsock_update_guest_cid(struct virtio_vsock *vsock)
{
	struct virtio_device *vdev = vsock->vdev;
	__le64 guest_cid;

	vdev->config->get(vdev, offsetof(struct virtio_vsock_config, guest_cid),
			  &guest_cid, sizeof(guest_cid));
	vsock->guest_cid = le64_to_cpu(guest_cid);
}

/* event_lock must be held */
static void virtio_vsock_event_handle(struct virtio_vsock *vsock,
				      struct virtio_vsock_event *event)
{
	switch (le32_to_cpu(event->id)) {
	case VIRTIO_VSOCK_EVENT_TRANSPORT_RESET:
		virtio_vsock_update_guest_cid(vsock);
		vsock_for_each_connected_socket(&virtio_transport.transport,
						virtio_vsock_reset_sock);
		break;
	}
}

static void virtio_transport_event_work(struct work_struct *work)
{
	struct virtio_vsock *vsock =
		container_of(work, struct virtio_vsock, event_work);
	struct virtqueue *vq;

	vq = vsock->vqs[VSOCK_VQ_EVENT];

	mutex_lock(&vsock->event_lock);

	if (!vsock->event_run)
		goto out;

	do {
		struct virtio_vsock_event *event;
		unsigned int len;

		virtqueue_disable_cb(vq);
		while ((event = virtqueue_get_buf(vq, &len)) != NULL) {
			if (len == sizeof(*event))
				virtio_vsock_event_handle(vsock, event);

			virtio_vsock_event_fill_one(vsock, event);
		}
	} while (!virtqueue_enable_cb(vq));

	virtqueue_kick(vsock->vqs[VSOCK_VQ_EVENT]);
out:
	mutex_unlock(&vsock->event_lock);
}

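/* Virtqueue interrupt callbacks.  They run in atomic context, so they only
 * defer the real processing to the corresponding work item.
 */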
static void virtio_vsock_event_done(struct virtqueue *vq)
{
	struct virtio_vsock *vsock = vq->vdev->priv;

	if (!vsock)
		return;
	queue_work(virtio_vsock_workqueue, &vsock->event_work);
}

static void virtio_vsock_tx_done(struct virtqueue *vq)
{
	struct virtio_vsock *vsock = vq->vdev->priv;

	if (!vsock)
		return;
	queue_work(virtio_vsock_workqueue, &vsock->tx_work);
}

static void virtio_vsock_rx_done(struct virtqueue *vq)
{
	struct virtio_vsock *vsock = vq->vdev->priv;

	if (!vsock)
		return;
	queue_work(virtio_vsock_workqueue, &vsock->rx_work);
}

static bool virtio_transport_seqpacket_allow(u32 remote_cid);

static struct virtio_transport virtio_transport = {
	.transport = {
		.module = THIS_MODULE,

		.get_local_cid = virtio_transport_get_local_cid,

		.init = virtio_transport_do_socket_init,
		.destruct = virtio_transport_destruct,
		.release = virtio_transport_release,
		.connect = virtio_transport_connect,
		.shutdown = virtio_transport_shutdown,
		.cancel_pkt = virtio_transport_cancel_pkt,

		.dgram_bind = virtio_transport_dgram_bind,
		.dgram_dequeue = virtio_transport_dgram_dequeue,
		.dgram_enqueue = virtio_transport_dgram_enqueue,
		.dgram_allow = virtio_transport_dgram_allow,

		.stream_dequeue = virtio_transport_stream_dequeue,
		.stream_enqueue = virtio_transport_stream_enqueue,
		.stream_has_data = virtio_transport_stream_has_data,
		.stream_has_space = virtio_transport_stream_has_space,
		.stream_rcvhiwat = virtio_transport_stream_rcvhiwat,
		.stream_is_active = virtio_transport_stream_is_active,
		.stream_allow = virtio_transport_stream_allow,

		.seqpacket_dequeue = virtio_transport_seqpacket_dequeue,
		.seqpacket_enqueue = virtio_transport_seqpacket_enqueue,
		.seqpacket_allow = virtio_transport_seqpacket_allow,
		.seqpacket_has_data = virtio_transport_seqpacket_has_data,

		.notify_poll_in = virtio_transport_notify_poll_in,
		.notify_poll_out = virtio_transport_notify_poll_out,
		.notify_recv_init = virtio_transport_notify_recv_init,
		.notify_recv_pre_block = virtio_transport_notify_recv_pre_block,
		.notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue,
		.notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
		.notify_send_init = virtio_transport_notify_send_init,
		.notify_send_pre_block = virtio_transport_notify_send_pre_block,
		.notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue,
		.notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,
		.notify_buffer_size = virtio_transport_notify_buffer_size,
		.notify_set_rcvlowat = virtio_transport_notify_set_rcvlowat,

		.read_skb = virtio_transport_read_skb,
	},

	.send_pkt = virtio_transport_send_pkt,
};

static bool virtio_transport_seqpacket_allow(u32 remote_cid)
{
	struct virtio_vsock *vsock;
	bool seqpacket_allow;

	seqpacket_allow = false;
	rcu_read_lock();
	vsock = rcu_dereference(the_virtio_vsock);
	if (vsock)
		seqpacket_allow = vsock->seqpacket_allow;
	rcu_read_unlock();

	return seqpacket_allow;
}

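/* RX worker: pop received buffers from the RX virtqueue, hand them to the
 * common virtio_transport receive path, and refill the virtqueue when it
 * runs low on buffers.
 */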
static void virtio_transport_rx_work(struct work_struct *work)
{
	struct virtio_vsock *vsock =
		container_of(work, struct virtio_vsock, rx_work);
	struct virtqueue *vq;

	vq = vsock->vqs[VSOCK_VQ_RX];

	mutex_lock(&vsock->rx_lock);

	if (!vsock->rx_run)
		goto out;

	do {
		virtqueue_disable_cb(vq);
		for (;;) {
			struct sk_buff *skb;
			unsigned int len;

			if (!virtio_transport_more_replies(vsock)) {
				/* Stop rx until the device processes already
				 * pending replies. Leave rx virtqueue
				 * callbacks disabled.
				 */
				goto out;
			}

			skb = virtqueue_get_buf(vq, &len);
			if (!skb)
				break;

			vsock->rx_buf_nr--;

			/* Drop short/long packets */
			if (unlikely(len < sizeof(struct virtio_vsock_hdr) ||
				     len > virtio_vsock_skb_len(skb))) {
				kfree_skb(skb);
				continue;
			}

			virtio_vsock_skb_rx_put(skb);
			virtio_transport_deliver_tap_pkt(skb);
			virtio_transport_recv_pkt(&virtio_transport, skb);
		}
	} while (!virtqueue_enable_cb(vq));

out:
	if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2)
		virtio_vsock_rx_fill(vsock);
	mutex_unlock(&vsock->rx_lock);
}

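/* Find the virtqueues, read the guest CID from the device config space and
 * mark the device ready.  Called from probe and restore with
 * the_virtio_vsock_mutex held.
 */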
static int virtio_vsock_vqs_init(struct virtio_vsock *vsock)
{
	struct virtio_device *vdev = vsock->vdev;
	static const char * const names[] = {
		"rx",
		"tx",
		"event",
	};
	vq_callback_t *callbacks[] = {
		virtio_vsock_rx_done,
		virtio_vsock_tx_done,
		virtio_vsock_event_done,
	};
	int ret;

	ret = virtio_find_vqs(vdev, VSOCK_VQ_MAX, vsock->vqs, callbacks, names,
			      NULL);
	if (ret < 0)
		return ret;

	virtio_vsock_update_guest_cid(vsock);

	virtio_device_ready(vdev);

	return 0;
}

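/* Mark all workers as runnable and prime the RX and event virtqueues with
 * buffers.  See the comment below for why only send_pkt_work is kicked here.
 */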
static void virtio_vsock_vqs_start(struct virtio_vsock *vsock)
{
	mutex_lock(&vsock->tx_lock);
	vsock->tx_run = true;
	mutex_unlock(&vsock->tx_lock);

	mutex_lock(&vsock->rx_lock);
	virtio_vsock_rx_fill(vsock);
	vsock->rx_run = true;
	mutex_unlock(&vsock->rx_lock);

	mutex_lock(&vsock->event_lock);
	virtio_vsock_event_fill(vsock);
	vsock->event_run = true;
	mutex_unlock(&vsock->event_lock);

	/* virtio_transport_send_pkt() can queue packets once
	 * the_virtio_vsock is set, but they won't be processed until
	 * vsock->tx_run is set to true. We queue vsock->send_pkt_work
	 * when initialization finishes to send those packets queued
	 * earlier.
	 * We don't need to queue the other workers (rx, event) because
	 * as long as we don't fill the queues with empty buffers, the
	 * host can't send us any notification.
	 */
	queue_work(virtio_vsock_workqueue, &vsock->send_pkt_work);
}

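/* Tear down the virtqueues: stop the workers, reset the device, free any
 * buffers still owned by the device and purge the send queue.  Called with
 * the_virtio_vsock_mutex held, after the_virtio_vsock has been cleared.
 */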
static void virtio_vsock_vqs_del(struct virtio_vsock *vsock)
{
	struct virtio_device *vdev = vsock->vdev;
	struct sk_buff *skb;

	/* Reset all connected sockets when the VQs disappear */
	vsock_for_each_connected_socket(&virtio_transport.transport,
					virtio_vsock_reset_sock);

	/* Stop all work handlers to make sure no one is accessing the device,
	 * so we can safely call virtio_reset_device().
	 */
	mutex_lock(&vsock->rx_lock);
	vsock->rx_run = false;
	mutex_unlock(&vsock->rx_lock);

	mutex_lock(&vsock->tx_lock);
	vsock->tx_run = false;
	mutex_unlock(&vsock->tx_lock);

	mutex_lock(&vsock->event_lock);
	vsock->event_run = false;
	mutex_unlock(&vsock->event_lock);

	/* Flush all device writes and interrupts, device will not use any
	 * more buffers.
	 */
	virtio_reset_device(vdev);

	mutex_lock(&vsock->rx_lock);
	while ((skb = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_RX])))
		kfree_skb(skb);
	mutex_unlock(&vsock->rx_lock);

	mutex_lock(&vsock->tx_lock);
	while ((skb = virtqueue_detach_unused_buf(vsock->vqs[VSOCK_VQ_TX])))
		kfree_skb(skb);
	mutex_unlock(&vsock->tx_lock);

	virtio_vsock_skb_queue_purge(&vsock->send_pkt_queue);

	/* Delete virtqueues and flush outstanding callbacks if any */
	vdev->config->del_vqs(vdev);
}

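/* Probe a virtio-vsock device: allocate the per-device state, set up the
 * virtqueues and publish the device as the single global transport instance.
 */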
static int virtio_vsock_probe(struct virtio_device *vdev)
{
	struct virtio_vsock *vsock = NULL;
	int ret;

	ret = mutex_lock_interruptible(&the_virtio_vsock_mutex);
	if (ret)
		return ret;

	/* Only one virtio-vsock device per guest is supported */
	if (rcu_dereference_protected(the_virtio_vsock,
				      lockdep_is_held(&the_virtio_vsock_mutex))) {
		ret = -EBUSY;
		goto out;
	}

	vsock = kzalloc(sizeof(*vsock), GFP_KERNEL);
	if (!vsock) {
		ret = -ENOMEM;
		goto out;
	}

	vsock->vdev = vdev;

	vsock->rx_buf_nr = 0;
	vsock->rx_buf_max_nr = 0;
	atomic_set(&vsock->queued_replies, 0);

	mutex_init(&vsock->tx_lock);
	mutex_init(&vsock->rx_lock);
	mutex_init(&vsock->event_lock);
	skb_queue_head_init(&vsock->send_pkt_queue);
	INIT_WORK(&vsock->rx_work, virtio_transport_rx_work);
	INIT_WORK(&vsock->tx_work, virtio_transport_tx_work);
	INIT_WORK(&vsock->event_work, virtio_transport_event_work);
	INIT_WORK(&vsock->send_pkt_work, virtio_transport_send_pkt_work);

	if (virtio_has_feature(vdev, VIRTIO_VSOCK_F_SEQPACKET))
		vsock->seqpacket_allow = true;

	vdev->priv = vsock;

	ret = virtio_vsock_vqs_init(vsock);
	if (ret < 0)
		goto out;

	rcu_assign_pointer(the_virtio_vsock, vsock);
	virtio_vsock_vqs_start(vsock);

	mutex_unlock(&the_virtio_vsock_mutex);

	return 0;

out:
	kfree(vsock);
	mutex_unlock(&the_virtio_vsock_mutex);
	return ret;
}

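/* Unpublish the device, wait for RCU readers, tear down the virtqueues and
 * flush all work items before freeing the per-device state.
 */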
static void virtio_vsock_remove(struct virtio_device *vdev)
{
	struct virtio_vsock *vsock = vdev->priv;

	mutex_lock(&the_virtio_vsock_mutex);

	vdev->priv = NULL;
	rcu_assign_pointer(the_virtio_vsock, NULL);
	synchronize_rcu();

	virtio_vsock_vqs_del(vsock);

	/* Other work items can be queued before 'config->del_vqs()', so we
	 * flush all of them before freeing the vsock object to avoid
	 * use after free.
	 */
	flush_work(&vsock->rx_work);
	flush_work(&vsock->tx_work);
	flush_work(&vsock->event_work);
	flush_work(&vsock->send_pkt_work);

	mutex_unlock(&the_virtio_vsock_mutex);

	kfree(vsock);
}

#ifdef CONFIG_PM_SLEEP
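/* On suspend, unpublish the device and tear down the virtqueues; the
 * per-device state is kept so that restore can bring the device back up.
 */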
static int virtio_vsock_freeze(struct virtio_device *vdev)
{
	struct virtio_vsock *vsock = vdev->priv;

	mutex_lock(&the_virtio_vsock_mutex);

	rcu_assign_pointer(the_virtio_vsock, NULL);
	synchronize_rcu();

	virtio_vsock_vqs_del(vsock);

	mutex_unlock(&the_virtio_vsock_mutex);

	return 0;
}

static int virtio_vsock_restore(struct virtio_device *vdev)
{
	struct virtio_vsock *vsock = vdev->priv;
	int ret;

	mutex_lock(&the_virtio_vsock_mutex);

	/* Only one virtio-vsock device per guest is supported */
	if (rcu_dereference_protected(the_virtio_vsock,
				      lockdep_is_held(&the_virtio_vsock_mutex))) {
		ret = -EBUSY;
		goto out;
	}

	ret = virtio_vsock_vqs_init(vsock);
	if (ret < 0)
		goto out;

	rcu_assign_pointer(the_virtio_vsock, vsock);
	virtio_vsock_vqs_start(vsock);

out:
	mutex_unlock(&the_virtio_vsock_mutex);
	return ret;
}
#endif /* CONFIG_PM_SLEEP */

static struct virtio_device_id id_table[] = {
	{ VIRTIO_ID_VSOCK, VIRTIO_DEV_ANY_ID },
	{ 0 },
};

static unsigned int features[] = {
	VIRTIO_VSOCK_F_SEQPACKET
};

static struct virtio_driver virtio_vsock_driver = {
	.feature_table = features,
	.feature_table_size = ARRAY_SIZE(features),
	.driver.name = KBUILD_MODNAME,
	.driver.owner = THIS_MODULE,
	.id_table = id_table,
	.probe = virtio_vsock_probe,
	.remove = virtio_vsock_remove,
#ifdef CONFIG_PM_SLEEP
	.freeze = virtio_vsock_freeze,
	.restore = virtio_vsock_restore,
#endif
};

static int __init virtio_vsock_init(void)
{
	int ret;

	virtio_vsock_workqueue = alloc_workqueue("virtio_vsock", 0, 0);
	if (!virtio_vsock_workqueue)
		return -ENOMEM;

	ret = vsock_core_register(&virtio_transport.transport,
				  VSOCK_TRANSPORT_F_G2H);
	if (ret)
		goto out_wq;

	ret = register_virtio_driver(&virtio_vsock_driver);
	if (ret)
		goto out_vci;

	return 0;

out_vci:
	vsock_core_unregister(&virtio_transport.transport);
out_wq:
	destroy_workqueue(virtio_vsock_workqueue);
	return ret;
}

static void __exit virtio_vsock_exit(void)
{
	unregister_virtio_driver(&virtio_vsock_driver);
	vsock_core_unregister(&virtio_transport.transport);
	destroy_workqueue(virtio_vsock_workqueue);
}

module_init(virtio_vsock_init);
module_exit(virtio_vsock_exit);
MODULE_LICENSE("GPL v2");
MODULE_AUTHOR("Asias He");
MODULE_DESCRIPTION("virtio transport for vsock");
MODULE_DEVICE_TABLE(virtio, id_table);