xref: /openbmc/qemu/net/l2tpv3.c (revision 8e6fe6b8bab4716b4adf99a9ab52eaa82464b37e)
1 /*
2  * QEMU System Emulator
3  *
4  * Copyright (c) 2003-2008 Fabrice Bellard
5  * Copyright (c) 2012-2014 Cisco Systems
6  *
7  * Permission is hereby granted, free of charge, to any person obtaining a copy
8  * of this software and associated documentation files (the "Software"), to deal
9  * in the Software without restriction, including without limitation the rights
10  * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11  * copies of the Software, and to permit persons to whom the Software is
12  * furnished to do so, subject to the following conditions:
13  *
14  * The above copyright notice and this permission notice shall be included in
15  * all copies or substantial portions of the Software.
16  *
17  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
20  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23  * THE SOFTWARE.
24  */
25 
26 #include "qemu/osdep.h"
27 #include <linux/ip.h>
28 #include <netdb.h>
29 #include "net/net.h"
30 #include "clients.h"
31 #include "qapi/error.h"
32 #include "qemu/error-report.h"
33 #include "qemu/option.h"
34 #include "qemu/sockets.h"
35 #include "qemu/iov.h"
36 #include "qemu/main-loop.h"
37 
38 
39 /* The buffer size needs to be investigated for optimum numbers and
40  * optimum means of paging in on different systems. This size is
41  * chosen to be sufficient to accommodate one packet with some headers
42  */
43 
44 #define BUFFER_ALIGN sysconf(_SC_PAGESIZE)
45 #define BUFFER_SIZE 2048
46 #define IOVSIZE 2
47 #define MAX_L2TPV3_MSGCNT 64
48 #define MAX_L2TPV3_IOVCNT (MAX_L2TPV3_MSGCNT * IOVSIZE)
49 
50 /* Header set to 0x30000 signifies a data packet */
51 
52 #define L2TPV3_DATA_PACKET 0x30000
53 
54 /* IANA-assigned IP protocol ID for L2TPv3 */
55 
56 #ifndef IPPROTO_L2TP
57 #define IPPROTO_L2TP 0x73
58 #endif
59 
60 typedef struct NetL2TPV3State {
61     NetClientState nc;
62     int fd;
63 
64     /*
65      * these are used for xmit - that happens packet a time
66      * and for first sign of life packet (easier to parse that once)
67      */
68 
69     uint8_t *header_buf;
70     struct iovec *vec;
71 
72     /*
73      * these are used for receive - try to "eat" up to 32 packets at a time
74      */
75 
76     struct mmsghdr *msgvec;
77 
78     /*
79      * peer address
80      */
81 
82     struct sockaddr_storage *dgram_dst;
83     uint32_t dst_size;
84 
85     /*
86      * L2TPv3 parameters
87      */
88 
89     uint64_t rx_cookie;
90     uint64_t tx_cookie;
91     uint32_t rx_session;
92     uint32_t tx_session;
93     uint32_t header_size;
94     uint32_t counter;
95 
96     /*
97     * DOS avoidance in error handling
98     */
99 
100     bool header_mismatch;
101 
102     /*
103      * Ring buffer handling
104      */
105 
106     int queue_head;
107     int queue_tail;
108     int queue_depth;
109 
110     /*
111      * Precomputed offsets
112      */
113 
114     uint32_t offset;
115     uint32_t cookie_offset;
116     uint32_t counter_offset;
117     uint32_t session_offset;
118 
119     /* Poll Control */
120 
121     bool read_poll;
122     bool write_poll;
123 
124     /* Flags */
125 
126     bool ipv6;
127     bool udp;
128     bool has_counter;
129     bool pin_counter;
130     bool cookie;
131     bool cookie_is_64;
132 
133 } NetL2TPV3State;
134 
135 static void net_l2tpv3_send(void *opaque);
136 static void l2tpv3_writable(void *opaque);
137 
138 static void l2tpv3_update_fd_handler(NetL2TPV3State *s)
139 {
140     qemu_set_fd_handler(s->fd,
141                         s->read_poll ? net_l2tpv3_send : NULL,
142                         s->write_poll ? l2tpv3_writable : NULL,
143                         s);
144 }
145 
146 static void l2tpv3_read_poll(NetL2TPV3State *s, bool enable)
147 {
148     if (s->read_poll != enable) {
149         s->read_poll = enable;
150         l2tpv3_update_fd_handler(s);
151     }
152 }
153 
154 static void l2tpv3_write_poll(NetL2TPV3State *s, bool enable)
155 {
156     if (s->write_poll != enable) {
157         s->write_poll = enable;
158         l2tpv3_update_fd_handler(s);
159     }
160 }
161 
162 static void l2tpv3_writable(void *opaque)
163 {
164     NetL2TPV3State *s = opaque;
165     l2tpv3_write_poll(s, false);
166     qemu_flush_queued_packets(&s->nc);
167 }
168 
169 static void l2tpv3_send_completed(NetClientState *nc, ssize_t len)
170 {
171     NetL2TPV3State *s = DO_UPCAST(NetL2TPV3State, nc, nc);
172     l2tpv3_read_poll(s, true);
173 }
174 
175 static void l2tpv3_poll(NetClientState *nc, bool enable)
176 {
177     NetL2TPV3State *s = DO_UPCAST(NetL2TPV3State, nc, nc);
178     l2tpv3_write_poll(s, enable);
179     l2tpv3_read_poll(s, enable);
180 }
181 
182 static void l2tpv3_form_header(NetL2TPV3State *s)
183 {
184     uint32_t *counter;
185 
186     if (s->udp) {
187         stl_be_p((uint32_t *) s->header_buf, L2TPV3_DATA_PACKET);
188     }
189     stl_be_p(
190             (uint32_t *) (s->header_buf + s->session_offset),
191             s->tx_session
192         );
193     if (s->cookie) {
194         if (s->cookie_is_64) {
195             stq_be_p(
196                 (uint64_t *)(s->header_buf + s->cookie_offset),
197                 s->tx_cookie
198             );
199         } else {
200             stl_be_p(
201                 (uint32_t *) (s->header_buf + s->cookie_offset),
202                 s->tx_cookie
203             );
204         }
205     }
206     if (s->has_counter) {
207         counter = (uint32_t *)(s->header_buf + s->counter_offset);
208         if (s->pin_counter) {
209             *counter = 0;
210         } else {
211             stl_be_p(counter, ++s->counter);
212         }
213     }
214 }
215 
216 static ssize_t net_l2tpv3_receive_dgram_iov(NetClientState *nc,
217                     const struct iovec *iov,
218                     int iovcnt)
219 {
220     NetL2TPV3State *s = DO_UPCAST(NetL2TPV3State, nc, nc);
221 
222     struct msghdr message;
223     int ret;
224 
225     if (iovcnt > MAX_L2TPV3_IOVCNT - 1) {
226         error_report(
227             "iovec too long %d > %d, change l2tpv3.h",
228             iovcnt, MAX_L2TPV3_IOVCNT
229         );
230         return -1;
231     }
232     l2tpv3_form_header(s);
233     memcpy(s->vec + 1, iov, iovcnt * sizeof(struct iovec));
234     s->vec->iov_base = s->header_buf;
235     s->vec->iov_len = s->offset;
236     message.msg_name = s->dgram_dst;
237     message.msg_namelen = s->dst_size;
238     message.msg_iov = s->vec;
239     message.msg_iovlen = iovcnt + 1;
240     message.msg_control = NULL;
241     message.msg_controllen = 0;
242     message.msg_flags = 0;
243     do {
244         ret = sendmsg(s->fd, &message, 0);
245     } while ((ret == -1) && (errno == EINTR));
246     if (ret > 0) {
247         ret -= s->offset;
248     } else if (ret == 0) {
249         /* belt and braces - should not occur on DGRAM
250         * we should get an error and never a 0 send
251         */
252         ret = iov_size(iov, iovcnt);
253     } else {
254         /* signal upper layer that socket buffer is full */
255         ret = -errno;
256         if (ret == -EAGAIN || ret == -ENOBUFS) {
257             l2tpv3_write_poll(s, true);
258             ret = 0;
259         }
260     }
261     return ret;
262 }
263 
264 static ssize_t net_l2tpv3_receive_dgram(NetClientState *nc,
265                     const uint8_t *buf,
266                     size_t size)
267 {
268     NetL2TPV3State *s = DO_UPCAST(NetL2TPV3State, nc, nc);
269 
270     struct iovec *vec;
271     struct msghdr message;
272     ssize_t ret = 0;
273 
274     l2tpv3_form_header(s);
275     vec = s->vec;
276     vec->iov_base = s->header_buf;
277     vec->iov_len = s->offset;
278     vec++;
279     vec->iov_base = (void *) buf;
280     vec->iov_len = size;
281     message.msg_name = s->dgram_dst;
282     message.msg_namelen = s->dst_size;
283     message.msg_iov = s->vec;
284     message.msg_iovlen = 2;
285     message.msg_control = NULL;
286     message.msg_controllen = 0;
287     message.msg_flags = 0;
288     do {
289         ret = sendmsg(s->fd, &message, 0);
290     } while ((ret == -1) && (errno == EINTR));
291     if (ret > 0) {
292         ret -= s->offset;
293     } else if (ret == 0) {
294         /* belt and braces - should not occur on DGRAM
295         * we should get an error and never a 0 send
296         */
297         ret = size;
298     } else {
299         ret = -errno;
300         if (ret == -EAGAIN || ret == -ENOBUFS) {
301             /* signal upper layer that socket buffer is full */
302             l2tpv3_write_poll(s, true);
303             ret = 0;
304         }
305     }
306     return ret;
307 }
308 
309 static int l2tpv3_verify_header(NetL2TPV3State *s, uint8_t *buf)
310 {
311 
312     uint32_t *session;
313     uint64_t cookie;
314 
315     if ((!s->udp) && (!s->ipv6)) {
316         buf += sizeof(struct iphdr) /* fix for ipv4 raw */;
317     }
318 
319     /* we do not do a strict check for "data" packets as per
320     * the RFC spec because the pure IP spec does not have
321     * that anyway.
322     */
323 
324     if (s->cookie) {
325         if (s->cookie_is_64) {
326             cookie = ldq_be_p(buf + s->cookie_offset);
327         } else {
328             cookie = ldl_be_p(buf + s->cookie_offset) & 0xffffffffULL;
329         }
330         if (cookie != s->rx_cookie) {
331             if (!s->header_mismatch) {
332                 error_report("unknown cookie id");
333             }
334             return -1;
335         }
336     }
337     session = (uint32_t *) (buf + s->session_offset);
338     if (ldl_be_p(session) != s->rx_session) {
339         if (!s->header_mismatch) {
340             error_report("session mismatch");
341         }
342         return -1;
343     }
344     return 0;
345 }
346 
347 static void net_l2tpv3_process_queue(NetL2TPV3State *s)
348 {
349     int size = 0;
350     struct iovec *vec;
351     bool bad_read;
352     int data_size;
353     struct mmsghdr *msgvec;
354 
355     /* go into ring mode only if there is a "pending" tail */
356     if (s->queue_depth > 0) {
357         do {
358             msgvec = s->msgvec + s->queue_tail;
359             if (msgvec->msg_len > 0) {
360                 data_size = msgvec->msg_len - s->header_size;
361                 vec = msgvec->msg_hdr.msg_iov;
362                 if ((data_size > 0) &&
363                     (l2tpv3_verify_header(s, vec->iov_base) == 0)) {
364                     vec++;
365                     /* Use the legacy delivery for now, we will
366                      * switch to using our own ring as a queueing mechanism
367                      * at a later date
368                      */
369                     size = qemu_send_packet_async(
370                             &s->nc,
371                             vec->iov_base,
372                             data_size,
373                             l2tpv3_send_completed
374                         );
375                     if (size == 0) {
376                         l2tpv3_read_poll(s, false);
377                     }
378                     bad_read = false;
379                 } else {
380                     bad_read = true;
381                     if (!s->header_mismatch) {
382                         /* report error only once */
383                         error_report("l2tpv3 header verification failed");
384                         s->header_mismatch = true;
385                     }
386                 }
387             } else {
388                 bad_read = true;
389             }
390             s->queue_tail = (s->queue_tail + 1) % MAX_L2TPV3_MSGCNT;
391             s->queue_depth--;
392         } while (
393                 (s->queue_depth > 0) &&
394                  qemu_can_send_packet(&s->nc) &&
395                 ((size > 0) || bad_read)
396             );
397     }
398 }
399 
400 static void net_l2tpv3_send(void *opaque)
401 {
402     NetL2TPV3State *s = opaque;
403     int target_count, count;
404     struct mmsghdr *msgvec;
405 
406     /* go into ring mode only if there is a "pending" tail */
407 
408     if (s->queue_depth) {
409 
410         /* The ring buffer we use has variable intake
411          * count of how much we can read varies - adjust accordingly
412          */
413 
414         target_count = MAX_L2TPV3_MSGCNT - s->queue_depth;
415 
416         /* Ensure we do not overrun the ring when we have
417          * a lot of enqueued packets
418          */
419 
420         if (s->queue_head + target_count > MAX_L2TPV3_MSGCNT) {
421             target_count = MAX_L2TPV3_MSGCNT - s->queue_head;
422         }
423     } else {
424 
425         /* we do not have any pending packets - we can use
426         * the whole message vector linearly instead of using
427         * it as a ring
428         */
429 
430         s->queue_head = 0;
431         s->queue_tail = 0;
432         target_count = MAX_L2TPV3_MSGCNT;
433     }
434 
435     msgvec = s->msgvec + s->queue_head;
436     if (target_count > 0) {
437         do {
438             count = recvmmsg(
439                 s->fd,
440                 msgvec,
441                 target_count, MSG_DONTWAIT, NULL);
442         } while ((count == -1) && (errno == EINTR));
443         if (count < 0) {
444             /* Recv error - we still need to flush packets here,
445              * (re)set queue head to current position
446              */
447             count = 0;
448         }
449         s->queue_head = (s->queue_head + count) % MAX_L2TPV3_MSGCNT;
450         s->queue_depth += count;
451     }
452     net_l2tpv3_process_queue(s);
453 }
454 
455 static void destroy_vector(struct mmsghdr *msgvec, int count, int iovcount)
456 {
457     int i, j;
458     struct iovec *iov;
459     struct mmsghdr *cleanup = msgvec;
460     if (cleanup) {
461         for (i = 0; i < count; i++) {
462             if (cleanup->msg_hdr.msg_iov) {
463                 iov = cleanup->msg_hdr.msg_iov;
464                 for (j = 0; j < iovcount; j++) {
465                     g_free(iov->iov_base);
466                     iov++;
467                 }
468                 g_free(cleanup->msg_hdr.msg_iov);
469             }
470             cleanup++;
471         }
472         g_free(msgvec);
473     }
474 }
475 
476 static struct mmsghdr *build_l2tpv3_vector(NetL2TPV3State *s, int count)
477 {
478     int i;
479     struct iovec *iov;
480     struct mmsghdr *msgvec, *result;
481 
482     msgvec = g_new(struct mmsghdr, count);
483     result = msgvec;
484     for (i = 0; i < count ; i++) {
485         msgvec->msg_hdr.msg_name = NULL;
486         msgvec->msg_hdr.msg_namelen = 0;
487         iov =  g_new(struct iovec, IOVSIZE);
488         msgvec->msg_hdr.msg_iov = iov;
489         iov->iov_base = g_malloc(s->header_size);
490         iov->iov_len = s->header_size;
491         iov++ ;
492         iov->iov_base = qemu_memalign(BUFFER_ALIGN, BUFFER_SIZE);
493         iov->iov_len = BUFFER_SIZE;
494         msgvec->msg_hdr.msg_iovlen = 2;
495         msgvec->msg_hdr.msg_control = NULL;
496         msgvec->msg_hdr.msg_controllen = 0;
497         msgvec->msg_hdr.msg_flags = 0;
498         msgvec++;
499     }
500     return result;
501 }
502 
503 static void net_l2tpv3_cleanup(NetClientState *nc)
504 {
505     NetL2TPV3State *s = DO_UPCAST(NetL2TPV3State, nc, nc);
506     qemu_purge_queued_packets(nc);
507     l2tpv3_read_poll(s, false);
508     l2tpv3_write_poll(s, false);
509     if (s->fd >= 0) {
510         close(s->fd);
511     }
512     destroy_vector(s->msgvec, MAX_L2TPV3_MSGCNT, IOVSIZE);
513     g_free(s->vec);
514     g_free(s->header_buf);
515     g_free(s->dgram_dst);
516 }
517 
518 static NetClientInfo net_l2tpv3_info = {
519     .type = NET_CLIENT_DRIVER_L2TPV3,
520     .size = sizeof(NetL2TPV3State),
521     .receive = net_l2tpv3_receive_dgram,
522     .receive_iov = net_l2tpv3_receive_dgram_iov,
523     .poll = l2tpv3_poll,
524     .cleanup = net_l2tpv3_cleanup,
525 };
526 
527 int net_init_l2tpv3(const Netdev *netdev,
528                     const char *name,
529                     NetClientState *peer, Error **errp)
530 {
531     const NetdevL2TPv3Options *l2tpv3;
532     NetL2TPV3State *s;
533     NetClientState *nc;
534     int fd = -1, gairet;
535     struct addrinfo hints;
536     struct addrinfo *result = NULL;
537     char *srcport, *dstport;
538 
539     nc = qemu_new_net_client(&net_l2tpv3_info, peer, "l2tpv3", name);
540 
541     s = DO_UPCAST(NetL2TPV3State, nc, nc);
542 
543     s->queue_head = 0;
544     s->queue_tail = 0;
545     s->header_mismatch = false;
546 
547     assert(netdev->type == NET_CLIENT_DRIVER_L2TPV3);
548     l2tpv3 = &netdev->u.l2tpv3;
549 
550     if (l2tpv3->has_ipv6 && l2tpv3->ipv6) {
551         s->ipv6 = l2tpv3->ipv6;
552     } else {
553         s->ipv6 = false;
554     }
555 
556     if ((l2tpv3->has_offset) && (l2tpv3->offset > 256)) {
557         error_setg(errp, "offset must be less than 256 bytes");
558         goto outerr;
559     }
560 
561     if (l2tpv3->has_rxcookie || l2tpv3->has_txcookie) {
562         if (l2tpv3->has_rxcookie && l2tpv3->has_txcookie) {
563             s->cookie = true;
564         } else {
565             error_setg(errp,
566                        "require both 'rxcookie' and 'txcookie' or neither");
567             goto outerr;
568         }
569     } else {
570         s->cookie = false;
571     }
572 
573     if (l2tpv3->has_cookie64 || l2tpv3->cookie64) {
574         s->cookie_is_64  = true;
575     } else {
576         s->cookie_is_64  = false;
577     }
578 
579     if (l2tpv3->has_udp && l2tpv3->udp) {
580         s->udp = true;
581         if (!(l2tpv3->has_srcport && l2tpv3->has_dstport)) {
582             error_setg(errp, "need both src and dst port for udp");
583             goto outerr;
584         } else {
585             srcport = l2tpv3->srcport;
586             dstport = l2tpv3->dstport;
587         }
588     } else {
589         s->udp = false;
590         srcport = NULL;
591         dstport = NULL;
592     }
593 
594 
595     s->offset = 4;
596     s->session_offset = 0;
597     s->cookie_offset = 4;
598     s->counter_offset = 4;
599 
600     s->tx_session = l2tpv3->txsession;
601     if (l2tpv3->has_rxsession) {
602         s->rx_session = l2tpv3->rxsession;
603     } else {
604         s->rx_session = s->tx_session;
605     }
606 
607     if (s->cookie) {
608         s->rx_cookie = l2tpv3->rxcookie;
609         s->tx_cookie = l2tpv3->txcookie;
610         if (s->cookie_is_64 == true) {
611             /* 64 bit cookie */
612             s->offset += 8;
613             s->counter_offset += 8;
614         } else {
615             /* 32 bit cookie */
616             s->offset += 4;
617             s->counter_offset += 4;
618         }
619     }
620 
621     memset(&hints, 0, sizeof(hints));
622 
623     if (s->ipv6) {
624         hints.ai_family = AF_INET6;
625     } else {
626         hints.ai_family = AF_INET;
627     }
628     if (s->udp) {
629         hints.ai_socktype = SOCK_DGRAM;
630         hints.ai_protocol = 0;
631         s->offset += 4;
632         s->counter_offset += 4;
633         s->session_offset += 4;
634         s->cookie_offset += 4;
635     } else {
636         hints.ai_socktype = SOCK_RAW;
637         hints.ai_protocol = IPPROTO_L2TP;
638     }
639 
640     gairet = getaddrinfo(l2tpv3->src, srcport, &hints, &result);
641 
642     if ((gairet != 0) || (result == NULL)) {
643         error_setg(errp, "could not resolve src, errno = %s",
644                    gai_strerror(gairet));
645         goto outerr;
646     }
647     fd = socket(result->ai_family, result->ai_socktype, result->ai_protocol);
648     if (fd == -1) {
649         fd = -errno;
650         error_setg(errp, "socket creation failed, errno = %d",
651                    -fd);
652         goto outerr;
653     }
654     if (bind(fd, (struct sockaddr *) result->ai_addr, result->ai_addrlen)) {
655         error_setg(errp, "could not bind socket err=%i", errno);
656         goto outerr;
657     }
658     if (result) {
659         freeaddrinfo(result);
660     }
661 
662     memset(&hints, 0, sizeof(hints));
663 
664     if (s->ipv6) {
665         hints.ai_family = AF_INET6;
666     } else {
667         hints.ai_family = AF_INET;
668     }
669     if (s->udp) {
670         hints.ai_socktype = SOCK_DGRAM;
671         hints.ai_protocol = 0;
672     } else {
673         hints.ai_socktype = SOCK_RAW;
674         hints.ai_protocol = IPPROTO_L2TP;
675     }
676 
677     result = NULL;
678     gairet = getaddrinfo(l2tpv3->dst, dstport, &hints, &result);
679     if ((gairet != 0) || (result == NULL)) {
680         error_setg(errp, "could not resolve dst, error = %s",
681                    gai_strerror(gairet));
682         goto outerr;
683     }
684 
685     s->dgram_dst = g_new0(struct sockaddr_storage, 1);
686     memcpy(s->dgram_dst, result->ai_addr, result->ai_addrlen);
687     s->dst_size = result->ai_addrlen;
688 
689     if (result) {
690         freeaddrinfo(result);
691     }
692 
693     if (l2tpv3->has_counter && l2tpv3->counter) {
694         s->has_counter = true;
695         s->offset += 4;
696     } else {
697         s->has_counter = false;
698     }
699 
700     if (l2tpv3->has_pincounter && l2tpv3->pincounter) {
701         s->has_counter = true;  /* pin counter implies that there is counter */
702         s->pin_counter = true;
703     } else {
704         s->pin_counter = false;
705     }
706 
707     if (l2tpv3->has_offset) {
708         /* extra offset */
709         s->offset += l2tpv3->offset;
710     }
711 
712     if ((s->ipv6) || (s->udp)) {
713         s->header_size = s->offset;
714     } else {
715         s->header_size = s->offset + sizeof(struct iphdr);
716     }
717 
718     s->msgvec = build_l2tpv3_vector(s, MAX_L2TPV3_MSGCNT);
719     s->vec = g_new(struct iovec, MAX_L2TPV3_IOVCNT);
720     s->header_buf = g_malloc(s->header_size);
721 
722     qemu_set_nonblock(fd);
723 
724     s->fd = fd;
725     s->counter = 0;
726 
727     l2tpv3_read_poll(s, true);
728 
729     snprintf(s->nc.info_str, sizeof(s->nc.info_str),
730              "l2tpv3: connected");
731     return 0;
732 outerr:
733     qemu_del_net_client(nc);
734     if (fd >= 0) {
735         close(fd);
736     }
737     if (result) {
738         freeaddrinfo(result);
739     }
740     return -1;
741 }
742 
743