xref: /openbmc/qemu/tests/vhost-user-bridge.c (revision 60e58bd9f08a3b91a35850f7501a0a1bcf912b6f)
1 /*
2  * Vhost User Bridge
3  *
4  * Copyright (c) 2015 Red Hat, Inc.
5  *
6  * Authors:
7  *  Victor Kaplansky <victork@redhat.com>
8  *
9  * This work is licensed under the terms of the GNU GPL, version 2 or
10  * later.  See the COPYING file in the top-level directory.
11  */
12 
13 /*
14  * TODO:
15  *     - main should get parameters from the command line.
16  *     - implement all request handlers. Still not implemented:
17  *          vubr_get_queue_num_exec()
18  *          vubr_send_rarp_exec()
19  *     - test for broken requests and virtqueue.
20  *     - implement features defined by Virtio 1.0 spec.
21  *     - support mergeable buffers and indirect descriptors.
22  *     - implement clean shutdown.
23  *     - implement non-blocking writes to UDP backend.
24  *     - implement polling strategy.
25  *     - implement clean starting/stopping of vq processing
26  *     - implement clean starting/stopping of used and buffers
27  *       dirty page logging.
28  */
29 
30 #define _FILE_OFFSET_BITS 64
31 
32 #include "qemu/osdep.h"
33 #include "qemu/iov.h"
34 #include "standard-headers/linux/virtio_net.h"
35 #include "contrib/libvhost-user/libvhost-user.h"
36 
/* Compile-time switch for DPRINT(): non-zero enables tracing, 0 silences it. */
#define VHOST_USER_BRIDGE_DEBUG 1

/*
 * printf()-style trace to stdout, emitted only while
 * VHOST_USER_BRIDGE_DEBUG is non-zero.  The do/while(0) wrapper makes
 * the macro usable as a single statement.
 */
#define DPRINT(...) \
    do { \
        if (!VHOST_USER_BRIDGE_DEBUG) { \
            break; \
        } \
        printf(__VA_ARGS__); \
    } while (0)
45 
/* Handler the dispatcher invokes when 'sock' is reported readable. */
typedef void (*CallbackFunc)(int sock, void *ctx);

/* One registered watch: the handler plus the opaque pointer passed to it. */
typedef struct Event {
    void *ctx;
    CallbackFunc callback;
} Event;

/*
 * State of the select()-based dispatcher.  'events' is indexed directly
 * by file descriptor, which is why it has FD_SETSIZE slots.
 */
typedef struct Dispatcher {
    int max_sock;               /* highest fd added so far, -1 after init */
    fd_set fdset;               /* fds currently being watched */
    Event events[FD_SETSIZE];   /* per-fd callback slots */
} Dispatcher;
58 
59 typedef struct VubrDev {
60     VuDev vudev;
61     Dispatcher dispatcher;
62     int backend_udp_sock;
63     struct sockaddr_in backend_udp_dest;
64     int hdrlen;
65     int sock;
66     int ready;
67     int quit;
68 } VubrDev;
69 
/*
 * Fatal-error exit: report 's' together with the current errno text via
 * perror() and terminate the process with status 1.  Never returns.
 */
static void
vubr_die(const char *s)
{
    const int failure_status = 1;

    perror(s);
    exit(failure_status);
}
76 
77 static int
78 dispatcher_init(Dispatcher *dispr)
79 {
80     FD_ZERO(&dispr->fdset);
81     dispr->max_sock = -1;
82     return 0;
83 }
84 
85 static int
86 dispatcher_add(Dispatcher *dispr, int sock, void *ctx, CallbackFunc cb)
87 {
88     if (sock >= FD_SETSIZE) {
89         fprintf(stderr,
90                 "Error: Failed to add new event. sock %d should be less than %d\n",
91                 sock, FD_SETSIZE);
92         return -1;
93     }
94 
95     dispr->events[sock].ctx = ctx;
96     dispr->events[sock].callback = cb;
97 
98     FD_SET(sock, &dispr->fdset);
99     if (sock > dispr->max_sock) {
100         dispr->max_sock = sock;
101     }
102     DPRINT("Added sock %d for watching. max_sock: %d\n",
103            sock, dispr->max_sock);
104     return 0;
105 }
106 
107 static int
108 dispatcher_remove(Dispatcher *dispr, int sock)
109 {
110     if (sock >= FD_SETSIZE) {
111         fprintf(stderr,
112                 "Error: Failed to remove event. sock %d should be less than %d\n",
113                 sock, FD_SETSIZE);
114         return -1;
115     }
116 
117     FD_CLR(sock, &dispr->fdset);
118     DPRINT("Sock %d removed from dispatcher watch.\n", sock);
119     return 0;
120 }
121 
122 /* timeout in us */
123 static int
124 dispatcher_wait(Dispatcher *dispr, uint32_t timeout)
125 {
126     struct timeval tv;
127     tv.tv_sec = timeout / 1000000;
128     tv.tv_usec = timeout % 1000000;
129 
130     fd_set fdset = dispr->fdset;
131 
132     /* wait until some of sockets become readable. */
133     int rc = select(dispr->max_sock + 1, &fdset, 0, 0, &tv);
134 
135     if (rc == -1) {
136         vubr_die("select");
137     }
138 
139     /* Timeout */
140     if (rc == 0) {
141         return 0;
142     }
143 
144     /* Now call callback for every ready socket. */
145 
146     int sock;
147     for (sock = 0; sock < dispr->max_sock + 1; sock++) {
148         /* The callback on a socket can remove other sockets from the
149          * dispatcher, thus we have to check that the socket is
150          * still not removed from dispatcher's list
151          */
152         if (FD_ISSET(sock, &fdset) && FD_ISSET(sock, &dispr->fdset)) {
153             Event *e = &dispr->events[sock];
154             e->callback(sock, e->ctx);
155         }
156     }
157 
158     return 0;
159 }
160 
161 static void
162 vubr_handle_tx(VuDev *dev, int qidx)
163 {
164     VuVirtq *vq = vu_get_queue(dev, qidx);
165     VubrDev *vubr = container_of(dev, VubrDev, vudev);
166     int hdrlen = vubr->hdrlen;
167     VuVirtqElement *elem = NULL;
168 
169     assert(qidx % 2);
170 
171     for (;;) {
172         ssize_t ret;
173         unsigned int out_num;
174         struct iovec sg[VIRTQUEUE_MAX_SIZE], *out_sg;
175 
176         elem = vu_queue_pop(dev, vq, sizeof(VuVirtqElement));
177         if (!elem) {
178             break;
179         }
180 
181         out_num = elem->out_num;
182         out_sg = elem->out_sg;
183         if (out_num < 1) {
184             fprintf(stderr, "virtio-net header not in first element\n");
185             break;
186         }
187         if (VHOST_USER_BRIDGE_DEBUG) {
188             iov_hexdump(out_sg, out_num, stderr, "TX:", 1024);
189         }
190 
191         if (hdrlen) {
192             unsigned sg_num = iov_copy(sg, ARRAY_SIZE(sg),
193                                        out_sg, out_num,
194                                        hdrlen, -1);
195             out_num = sg_num;
196             out_sg = sg;
197         }
198 
199         struct msghdr msg = {
200             .msg_name = (struct sockaddr *) &vubr->backend_udp_dest,
201             .msg_namelen = sizeof(struct sockaddr_in),
202             .msg_iov = out_sg,
203             .msg_iovlen = out_num,
204         };
205         do {
206             ret = sendmsg(vubr->backend_udp_sock, &msg, 0);
207         } while (ret == -1 && (errno == EAGAIN || errno == EINTR));
208 
209         if (ret == -1) {
210             vubr_die("sendmsg()");
211         }
212 
213         vu_queue_push(dev, vq, elem, 0);
214         vu_queue_notify(dev, vq);
215 
216         free(elem);
217         elem = NULL;
218     }
219 
220     free(elem);
221 }
222 
223 
/*
 * Reverse the effect of iov_discard_front().
 *
 * 'front' is the start of the original iovec array, 'iov' is the
 * element the discard left the caller pointing at, and 'bytes' is the
 * number of bytes that were shaved off.  Elements in [front, iov) were
 * skipped whole and are left untouched; whatever remains of 'bytes' is
 * given back to the element at 'iov'.
 *
 * If the skipped elements account for all of 'bytes', nothing is
 * written.  In that case 'iov' may legitimately point one past the end
 * of the array and must not be dereferenced.
 */
static void
iov_restore_front(struct iovec *front, struct iovec *iov, size_t bytes)
{
    struct iovec *cur;

    for (cur = front; cur != iov; cur++) {
        assert(bytes >= cur->iov_len);
        bytes -= cur->iov_len;
    }

    if (bytes == 0) {
        return;
    }

    /* Arithmetic on void * is a GNU extension; go through char *. */
    cur->iov_base = (char *)cur->iov_base - bytes;
    cur->iov_len += bytes;
}
241 
/*
 * Shorten the iovec array 'iov' (of 'iovc' elements) in place so that
 * it describes exactly 'bytes' bytes.
 *
 * The element containing the cut point is shortened and every element
 * after it is set to length 0.  Asking for exactly the current total
 * length is a no-op.  Asking for more bytes than the array holds is a
 * programming error and trips the assertion.
 */
static void
iov_truncate(struct iovec *iov, unsigned iovc, size_t bytes)
{
    unsigned i;

    for (i = 0; i < iovc; i++, iov++) {
        if (bytes < iov->iov_len) {
            /* Cut point, or a segment after it (bytes is 0 by then). */
            iov->iov_len = bytes;
            bytes = 0;
        } else {
            bytes -= iov->iov_len;
        }
    }

    assert(bytes == 0 && "couldn't truncate iov");
}
258 
259 static void
260 vubr_backend_recv_cb(int sock, void *ctx)
261 {
262     VubrDev *vubr = (VubrDev *) ctx;
263     VuDev *dev = &vubr->vudev;
264     VuVirtq *vq = vu_get_queue(dev, 0);
265     VuVirtqElement *elem = NULL;
266     struct iovec mhdr_sg[VIRTQUEUE_MAX_SIZE];
267     struct virtio_net_hdr_mrg_rxbuf mhdr;
268     unsigned mhdr_cnt = 0;
269     int hdrlen = vubr->hdrlen;
270     int i = 0;
271     struct virtio_net_hdr hdr = {
272         .flags = 0,
273         .gso_type = VIRTIO_NET_HDR_GSO_NONE
274     };
275 
276     DPRINT("\n\n   ***   IN UDP RECEIVE CALLBACK    ***\n\n");
277     DPRINT("    hdrlen = %d\n", hdrlen);
278 
279     if (!vu_queue_enabled(dev, vq) ||
280         !vu_queue_started(dev, vq) ||
281         !vu_queue_avail_bytes(dev, vq, hdrlen, 0)) {
282         DPRINT("Got UDP packet, but no available descriptors on RX virtq.\n");
283         return;
284     }
285 
286     while (1) {
287         struct iovec *sg;
288         ssize_t ret, total = 0;
289         unsigned int num;
290 
291         elem = vu_queue_pop(dev, vq, sizeof(VuVirtqElement));
292         if (!elem) {
293             break;
294         }
295 
296         if (elem->in_num < 1) {
297             fprintf(stderr, "virtio-net contains no in buffers\n");
298             break;
299         }
300 
301         sg = elem->in_sg;
302         num = elem->in_num;
303         if (i == 0) {
304             if (hdrlen == 12) {
305                 mhdr_cnt = iov_copy(mhdr_sg, ARRAY_SIZE(mhdr_sg),
306                                     sg, elem->in_num,
307                                     offsetof(typeof(mhdr), num_buffers),
308                                     sizeof(mhdr.num_buffers));
309             }
310             iov_from_buf(sg, elem->in_num, 0, &hdr, sizeof hdr);
311             total += hdrlen;
312             ret = iov_discard_front(&sg, &num, hdrlen);
313             assert(ret == hdrlen);
314         }
315 
316         struct msghdr msg = {
317             .msg_name = (struct sockaddr *) &vubr->backend_udp_dest,
318             .msg_namelen = sizeof(struct sockaddr_in),
319             .msg_iov = sg,
320             .msg_iovlen = elem->in_num,
321             .msg_flags = MSG_DONTWAIT,
322         };
323         do {
324             ret = recvmsg(vubr->backend_udp_sock, &msg, 0);
325         } while (ret == -1 && (errno == EINTR));
326 
327         if (i == 0) {
328             iov_restore_front(elem->in_sg, sg, hdrlen);
329         }
330 
331         if (ret == -1) {
332             if (errno == EWOULDBLOCK) {
333                 vu_queue_rewind(dev, vq, 1);
334                 break;
335             }
336 
337             vubr_die("recvmsg()");
338         }
339 
340         total += ret;
341         iov_truncate(elem->in_sg, elem->in_num, total);
342         vu_queue_fill(dev, vq, elem, total, i++);
343 
344         free(elem);
345         elem = NULL;
346 
347         break;        /* could loop if DONTWAIT worked? */
348     }
349 
350     if (mhdr_cnt) {
351         mhdr.num_buffers = i;
352         iov_from_buf(mhdr_sg, mhdr_cnt,
353                      0,
354                      &mhdr.num_buffers, sizeof mhdr.num_buffers);
355     }
356 
357     vu_queue_flush(dev, vq, i);
358     vu_queue_notify(dev, vq);
359 
360     free(elem);
361 }
362 
363 static void
364 vubr_receive_cb(int sock, void *ctx)
365 {
366     VubrDev *vubr = (VubrDev *)ctx;
367 
368     if (!vu_dispatch(&vubr->vudev)) {
369         fprintf(stderr, "Error while dispatching\n");
370     }
371 }
372 
373 typedef struct WatchData {
374     VuDev *dev;
375     vu_watch_cb cb;
376     void *data;
377 } WatchData;
378 
379 static void
380 watch_cb(int sock, void *ctx)
381 {
382     struct WatchData *wd = ctx;
383 
384     wd->cb(wd->dev, VU_WATCH_IN, wd->data);
385 }
386 
387 static void
388 vubr_set_watch(VuDev *dev, int fd, int condition,
389                vu_watch_cb cb, void *data)
390 {
391     VubrDev *vubr = container_of(dev, VubrDev, vudev);
392     static WatchData watches[FD_SETSIZE];
393     struct WatchData *wd = &watches[fd];
394 
395     wd->cb = cb;
396     wd->data = data;
397     wd->dev = dev;
398     dispatcher_add(&vubr->dispatcher, fd, wd, watch_cb);
399 }
400 
401 static void
402 vubr_remove_watch(VuDev *dev, int fd)
403 {
404     VubrDev *vubr = container_of(dev, VubrDev, vudev);
405 
406     dispatcher_remove(&vubr->dispatcher, fd);
407 }
408 
409 static int
410 vubr_send_rarp_exec(VuDev *dev, VhostUserMsg *vmsg)
411 {
412     DPRINT("Function %s() not implemented yet.\n", __func__);
413     return 0;
414 }
415 
416 static int
417 vubr_process_msg(VuDev *dev, VhostUserMsg *vmsg, int *do_reply)
418 {
419     switch (vmsg->request) {
420     case VHOST_USER_SEND_RARP:
421         *do_reply = vubr_send_rarp_exec(dev, vmsg);
422         return 1;
423     default:
424         /* let the library handle the rest */
425         return 0;
426     }
427 
428     return 0;
429 }
430 
431 static void
432 vubr_set_features(VuDev *dev, uint64_t features)
433 {
434     VubrDev *vubr = container_of(dev, VubrDev, vudev);
435 
436     if ((features & (1ULL << VIRTIO_F_VERSION_1)) ||
437         (features & (1ULL << VIRTIO_NET_F_MRG_RXBUF))) {
438         vubr->hdrlen = 12;
439     } else {
440         vubr->hdrlen = 10;
441     }
442 }
443 
444 static uint64_t
445 vubr_get_features(VuDev *dev)
446 {
447     return 1ULL << VIRTIO_NET_F_GUEST_ANNOUNCE |
448         1ULL << VIRTIO_NET_F_MRG_RXBUF;
449 }
450 
451 static void
452 vubr_queue_set_started(VuDev *dev, int qidx, bool started)
453 {
454     VuVirtq *vq = vu_get_queue(dev, qidx);
455 
456     if (qidx % 2 == 1) {
457         vu_set_queue_handler(dev, vq, started ? vubr_handle_tx : NULL);
458     }
459 }
460 
461 static void
462 vubr_panic(VuDev *dev, const char *msg)
463 {
464     VubrDev *vubr = container_of(dev, VubrDev, vudev);
465 
466     fprintf(stderr, "PANIC: %s\n", msg);
467 
468     dispatcher_remove(&vubr->dispatcher, dev->sock);
469     vubr->quit = 1;
470 }
471 
472 static bool
473 vubr_queue_is_processed_in_order(VuDev *dev, int qidx)
474 {
475     return true;
476 }
477 
478 static const VuDevIface vuiface = {
479     .get_features = vubr_get_features,
480     .set_features = vubr_set_features,
481     .process_msg = vubr_process_msg,
482     .queue_set_started = vubr_queue_set_started,
483     .queue_is_processed_in_order = vubr_queue_is_processed_in_order,
484 };
485 
486 static void
487 vubr_accept_cb(int sock, void *ctx)
488 {
489     VubrDev *dev = (VubrDev *)ctx;
490     int conn_fd;
491     struct sockaddr_un un;
492     socklen_t len = sizeof(un);
493 
494     conn_fd = accept(sock, (struct sockaddr *) &un, &len);
495     if (conn_fd == -1) {
496         vubr_die("accept()");
497     }
498     DPRINT("Got connection from remote peer on sock %d\n", conn_fd);
499 
500     vu_init(&dev->vudev,
501             conn_fd,
502             vubr_panic,
503             vubr_set_watch,
504             vubr_remove_watch,
505             &vuiface);
506 
507     dispatcher_add(&dev->dispatcher, conn_fd, ctx, vubr_receive_cb);
508     dispatcher_remove(&dev->dispatcher, sock);
509 }
510 
511 static VubrDev *
512 vubr_new(const char *path, bool client)
513 {
514     VubrDev *dev = (VubrDev *) calloc(1, sizeof(VubrDev));
515     struct sockaddr_un un;
516     CallbackFunc cb;
517     size_t len;
518 
519     /* Get a UNIX socket. */
520     dev->sock = socket(AF_UNIX, SOCK_STREAM, 0);
521     if (dev->sock == -1) {
522         vubr_die("socket");
523     }
524 
525     un.sun_family = AF_UNIX;
526     strcpy(un.sun_path, path);
527     len = sizeof(un.sun_family) + strlen(path);
528 
529     if (!client) {
530         unlink(path);
531 
532         if (bind(dev->sock, (struct sockaddr *) &un, len) == -1) {
533             vubr_die("bind");
534         }
535 
536         if (listen(dev->sock, 1) == -1) {
537             vubr_die("listen");
538         }
539         cb = vubr_accept_cb;
540 
541         DPRINT("Waiting for connections on UNIX socket %s ...\n", path);
542     } else {
543         if (connect(dev->sock, (struct sockaddr *)&un, len) == -1) {
544             vubr_die("connect");
545         }
546         vu_init(&dev->vudev,
547                 dev->sock,
548                 vubr_panic,
549                 vubr_set_watch,
550                 vubr_remove_watch,
551                 &vuiface);
552         cb = vubr_receive_cb;
553     }
554 
555     dispatcher_init(&dev->dispatcher);
556 
557     dispatcher_add(&dev->dispatcher, dev->sock, (void *)dev, cb);
558 
559     return dev;
560 }
561 
/*
 * Fill saddr->sin_addr from 'host'.
 *
 * A string starting with a digit is parsed as a numeric IPv4 address
 * with inet_aton(); anything else is resolved with gethostbyname() and
 * its first address is used.  Only sin_addr is written.  Any failure
 * prints a message to stderr and exits with status 1.
 */
static void
vubr_set_host(struct sockaddr_in *saddr, const char *host)
{
    /* <ctype.h> functions require a value representable as unsigned char. */
    if (isdigit((unsigned char)host[0])) {
        if (!inet_aton(host, &saddr->sin_addr)) {
            fprintf(stderr, "inet_aton() failed.\n");
            exit(1);
        }
    } else {
        struct hostent *he = gethostbyname(host);

        if (!he) {
            fprintf(stderr, "gethostbyname() failed.\n");
            exit(1);
        }
        /* Make sure the entry really holds an IPv4 address before copying. */
        if (he->h_addrtype != AF_INET ||
            he->h_length != (int)sizeof(saddr->sin_addr) ||
            !he->h_addr_list[0]) {
            fprintf(stderr, "gethostbyname() returned no IPv4 address.\n");
            exit(1);
        }
        /* memcpy avoids assuming the resolver buffer is suitably aligned. */
        memcpy(&saddr->sin_addr, he->h_addr_list[0], sizeof(saddr->sin_addr));
    }
}
580 
581 static void
582 vubr_backend_udp_setup(VubrDev *dev,
583                        const char *local_host,
584                        const char *local_port,
585                        const char *remote_host,
586                        const char *remote_port)
587 {
588     int sock;
589     const char *r;
590 
591     int lport, rport;
592 
593     lport = strtol(local_port, (char **)&r, 0);
594     if (r == local_port) {
595         fprintf(stderr, "lport parsing failed.\n");
596         exit(1);
597     }
598 
599     rport = strtol(remote_port, (char **)&r, 0);
600     if (r == remote_port) {
601         fprintf(stderr, "rport parsing failed.\n");
602         exit(1);
603     }
604 
605     struct sockaddr_in si_local = {
606         .sin_family = AF_INET,
607         .sin_port = htons(lport),
608     };
609 
610     vubr_set_host(&si_local, local_host);
611 
612     /* setup destination for sends */
613     dev->backend_udp_dest = (struct sockaddr_in) {
614         .sin_family = AF_INET,
615         .sin_port = htons(rport),
616     };
617     vubr_set_host(&dev->backend_udp_dest, remote_host);
618 
619     sock = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
620     if (sock == -1) {
621         vubr_die("socket");
622     }
623 
624     if (bind(sock, (struct sockaddr *)&si_local, sizeof(si_local)) == -1) {
625         vubr_die("bind");
626     }
627 
628     dev->backend_udp_sock = sock;
629     dispatcher_add(&dev->dispatcher, sock, dev, vubr_backend_recv_cb);
630     DPRINT("Waiting for data from udp backend on %s:%d...\n",
631            local_host, lport);
632 }
633 
634 static void
635 vubr_run(VubrDev *dev)
636 {
637     while (!dev->quit) {
638         /* timeout 200ms */
639         dispatcher_wait(&dev->dispatcher, 200000);
640         /* Here one can try polling strategy. */
641     }
642 }
643 
/*
 * Split "host:port" at the first ':' into two newly allocated strings.
 *
 * On success *host and *port point to heap copies owned by the caller
 * and 0 is returned.  'buf' itself is never modified.  Returns -1, with
 * *host and *port untouched, when 'buf' contains no ':' or memory
 * cannot be allocated.
 */
static int
vubr_parse_host_port(const char **host, const char **port, const char *buf)
{
    const char *colon = strchr(buf, ':');
    size_t host_len, port_len;
    char *h, *p;

    if (!colon) {
        return -1;
    }

    host_len = (size_t)(colon - buf);
    port_len = strlen(colon + 1);

    h = malloc(host_len + 1);
    p = malloc(port_len + 1);
    if (!h || !p) {
        free(h);
        free(p);
        return -1;
    }

    memcpy(h, buf, host_len);
    h[host_len] = '\0';
    /* port_len + 1 also copies the terminating NUL */
    memcpy(p, colon + 1, port_len + 1);

    *host = h;
    *port = p;
    return 0;
}
657 
/* Built-in defaults, used when the matching command-line option is absent. */
#define DEFAULT_UD_SOCKET "/tmp/vubr.sock"
#define DEFAULT_LHOST "127.0.0.1"
#define DEFAULT_LPORT "4444"
#define DEFAULT_RHOST "127.0.0.1"
#define DEFAULT_RPORT "5555"

/* Effective settings; main() overrides these from -u, -l and -r. */
static const char *ud_socket_path = DEFAULT_UD_SOCKET;  /* vhost-user UNIX socket */
static const char *lhost = DEFAULT_LHOST;               /* local UDP host */
static const char *lport = DEFAULT_LPORT;               /* local UDP port */
static const char *rhost = DEFAULT_RHOST;               /* remote UDP host */
static const char *rport = DEFAULT_RPORT;               /* remote UDP port */
669 
670 int
671 main(int argc, char *argv[])
672 {
673     VubrDev *dev;
674     int opt;
675     bool client = false;
676 
677     while ((opt = getopt(argc, argv, "l:r:u:c")) != -1) {
678 
679         switch (opt) {
680         case 'l':
681             if (vubr_parse_host_port(&lhost, &lport, optarg) < 0) {
682                 goto out;
683             }
684             break;
685         case 'r':
686             if (vubr_parse_host_port(&rhost, &rport, optarg) < 0) {
687                 goto out;
688             }
689             break;
690         case 'u':
691             ud_socket_path = strdup(optarg);
692             break;
693         case 'c':
694             client = true;
695             break;
696         default:
697             goto out;
698         }
699     }
700 
701     DPRINT("ud socket: %s (%s)\n", ud_socket_path,
702            client ? "client" : "server");
703     DPRINT("local:     %s:%s\n", lhost, lport);
704     DPRINT("remote:    %s:%s\n", rhost, rport);
705 
706     dev = vubr_new(ud_socket_path, client);
707     if (!dev) {
708         return 1;
709     }
710 
711     vubr_backend_udp_setup(dev, lhost, lport, rhost, rport);
712     vubr_run(dev);
713 
714     vu_deinit(&dev->vudev);
715 
716     return 0;
717 
718 out:
719     fprintf(stderr, "Usage: %s ", argv[0]);
720     fprintf(stderr, "[-c] [-u ud_socket_path] [-l lhost:lport] [-r rhost:rport]\n");
721     fprintf(stderr, "\t-u path to unix doman socket. default: %s\n",
722             DEFAULT_UD_SOCKET);
723     fprintf(stderr, "\t-l local host and port. default: %s:%s\n",
724             DEFAULT_LHOST, DEFAULT_LPORT);
725     fprintf(stderr, "\t-r remote host and port. default: %s:%s\n",
726             DEFAULT_RHOST, DEFAULT_RPORT);
727     fprintf(stderr, "\t-c client mode\n");
728 
729     return 1;
730 }
731