xref: /openbmc/qemu/nbd/server.c (revision 09230cb8676bd4f18f919afe52dab32063adff5f)
1 /*
2  *  Copyright (C) 2005  Anthony Liguori <anthony@codemonkey.ws>
3  *
4  *  Network Block Device Server Side
5  *
6  *  This program is free software; you can redistribute it and/or modify
7  *  it under the terms of the GNU General Public License as published by
8  *  the Free Software Foundation; under version 2 of the License.
9  *
10  *  This program is distributed in the hope that it will be useful,
11  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
12  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13  *  GNU General Public License for more details.
14  *
15  *  You should have received a copy of the GNU General Public License
16  *  along with this program; if not, see <http://www.gnu.org/licenses/>.
17  */
18 
19 #include "qemu/osdep.h"
20 #include "qapi/error.h"
21 #include "nbd-internal.h"
22 
23 static int system_errno_to_nbd_errno(int err)
24 {
25     switch (err) {
26     case 0:
27         return NBD_SUCCESS;
28     case EPERM:
29     case EROFS:
30         return NBD_EPERM;
31     case EIO:
32         return NBD_EIO;
33     case ENOMEM:
34         return NBD_ENOMEM;
35 #ifdef EDQUOT
36     case EDQUOT:
37 #endif
38     case EFBIG:
39     case ENOSPC:
40         return NBD_ENOSPC;
41     case EINVAL:
42     default:
43         return NBD_EINVAL;
44     }
45 }
46 
47 /* Definitions for opaque data types */
48 
49 typedef struct NBDRequest NBDRequest;
50 
51 struct NBDRequest {
52     QSIMPLEQ_ENTRY(NBDRequest) entry;
53     NBDClient *client;
54     uint8_t *data;
55 };
56 
57 struct NBDExport {
58     int refcount;
59     void (*close)(NBDExport *exp);
60 
61     BlockBackend *blk;
62     char *name;
63     off_t dev_offset;
64     off_t size;
65     uint32_t nbdflags;
66     QTAILQ_HEAD(, NBDClient) clients;
67     QTAILQ_ENTRY(NBDExport) next;
68 
69     AioContext *ctx;
70 
71     Notifier eject_notifier;
72 };
73 
74 static QTAILQ_HEAD(, NBDExport) exports = QTAILQ_HEAD_INITIALIZER(exports);
75 
76 struct NBDClient {
77     int refcount;
78     void (*close)(NBDClient *client);
79 
80     NBDExport *exp;
81     QCryptoTLSCreds *tlscreds;
82     char *tlsaclname;
83     QIOChannelSocket *sioc; /* The underlying data channel */
84     QIOChannel *ioc; /* The current I/O channel which may differ (eg TLS) */
85 
86     Coroutine *recv_coroutine;
87 
88     CoMutex send_lock;
89     Coroutine *send_coroutine;
90 
91     bool can_read;
92 
93     QTAILQ_ENTRY(NBDClient) next;
94     int nb_requests;
95     bool closing;
96 };
97 
98 /* That's all folks */
99 
100 static void nbd_set_handlers(NBDClient *client);
101 static void nbd_unset_handlers(NBDClient *client);
102 static void nbd_update_can_read(NBDClient *client);
103 
104 static gboolean nbd_negotiate_continue(QIOChannel *ioc,
105                                        GIOCondition condition,
106                                        void *opaque)
107 {
108     qemu_coroutine_enter(opaque, NULL);
109     return TRUE;
110 }
111 
112 static ssize_t nbd_negotiate_read(QIOChannel *ioc, void *buffer, size_t size)
113 {
114     ssize_t ret;
115     guint watch;
116 
117     assert(qemu_in_coroutine());
118     /* Negotiation are always in main loop. */
119     watch = qio_channel_add_watch(ioc,
120                                   G_IO_IN,
121                                   nbd_negotiate_continue,
122                                   qemu_coroutine_self(),
123                                   NULL);
124     ret = read_sync(ioc, buffer, size);
125     g_source_remove(watch);
126     return ret;
127 
128 }
129 
130 static ssize_t nbd_negotiate_write(QIOChannel *ioc, void *buffer, size_t size)
131 {
132     ssize_t ret;
133     guint watch;
134 
135     assert(qemu_in_coroutine());
136     /* Negotiation are always in main loop. */
137     watch = qio_channel_add_watch(ioc,
138                                   G_IO_OUT,
139                                   nbd_negotiate_continue,
140                                   qemu_coroutine_self(),
141                                   NULL);
142     ret = write_sync(ioc, buffer, size);
143     g_source_remove(watch);
144     return ret;
145 }
146 
147 static ssize_t nbd_negotiate_drop_sync(QIOChannel *ioc, size_t size)
148 {
149     ssize_t ret, dropped = size;
150     uint8_t *buffer = g_malloc(MIN(65536, size));
151 
152     while (size > 0) {
153         ret = nbd_negotiate_read(ioc, buffer, MIN(65536, size));
154         if (ret < 0) {
155             g_free(buffer);
156             return ret;
157         }
158 
159         assert(ret <= size);
160         size -= ret;
161     }
162 
163     g_free(buffer);
164     return dropped;
165 }
166 
167 /* Basic flow for negotiation
168 
169    Server         Client
170    Negotiate
171 
172    or
173 
174    Server         Client
175    Negotiate #1
176                   Option
177    Negotiate #2
178 
179    ----
180 
181    followed by
182 
183    Server         Client
184                   Request
185    Response
186                   Request
187    Response
188                   ...
189    ...
190                   Request (type == 2)
191 
192 */
193 
194 static int nbd_negotiate_send_rep(QIOChannel *ioc, uint32_t type, uint32_t opt)
195 {
196     uint64_t magic;
197     uint32_t len;
198 
199     TRACE("Reply opt=%x type=%x", type, opt);
200 
201     magic = cpu_to_be64(NBD_REP_MAGIC);
202     if (nbd_negotiate_write(ioc, &magic, sizeof(magic)) != sizeof(magic)) {
203         LOG("write failed (rep magic)");
204         return -EINVAL;
205     }
206     opt = cpu_to_be32(opt);
207     if (nbd_negotiate_write(ioc, &opt, sizeof(opt)) != sizeof(opt)) {
208         LOG("write failed (rep opt)");
209         return -EINVAL;
210     }
211     type = cpu_to_be32(type);
212     if (nbd_negotiate_write(ioc, &type, sizeof(type)) != sizeof(type)) {
213         LOG("write failed (rep type)");
214         return -EINVAL;
215     }
216     len = cpu_to_be32(0);
217     if (nbd_negotiate_write(ioc, &len, sizeof(len)) != sizeof(len)) {
218         LOG("write failed (rep data length)");
219         return -EINVAL;
220     }
221     return 0;
222 }
223 
224 static int nbd_negotiate_send_rep_list(QIOChannel *ioc, NBDExport *exp)
225 {
226     uint64_t magic, name_len;
227     uint32_t opt, type, len;
228 
229     TRACE("Advertizing export name '%s'", exp->name ? exp->name : "");
230     name_len = strlen(exp->name);
231     magic = cpu_to_be64(NBD_REP_MAGIC);
232     if (nbd_negotiate_write(ioc, &magic, sizeof(magic)) != sizeof(magic)) {
233         LOG("write failed (magic)");
234         return -EINVAL;
235      }
236     opt = cpu_to_be32(NBD_OPT_LIST);
237     if (nbd_negotiate_write(ioc, &opt, sizeof(opt)) != sizeof(opt)) {
238         LOG("write failed (opt)");
239         return -EINVAL;
240     }
241     type = cpu_to_be32(NBD_REP_SERVER);
242     if (nbd_negotiate_write(ioc, &type, sizeof(type)) != sizeof(type)) {
243         LOG("write failed (reply type)");
244         return -EINVAL;
245     }
246     len = cpu_to_be32(name_len + sizeof(len));
247     if (nbd_negotiate_write(ioc, &len, sizeof(len)) != sizeof(len)) {
248         LOG("write failed (length)");
249         return -EINVAL;
250     }
251     len = cpu_to_be32(name_len);
252     if (nbd_negotiate_write(ioc, &len, sizeof(len)) != sizeof(len)) {
253         LOG("write failed (length)");
254         return -EINVAL;
255     }
256     if (nbd_negotiate_write(ioc, exp->name, name_len) != name_len) {
257         LOG("write failed (buffer)");
258         return -EINVAL;
259     }
260     return 0;
261 }
262 
263 static int nbd_negotiate_handle_list(NBDClient *client, uint32_t length)
264 {
265     NBDExport *exp;
266 
267     if (length) {
268         if (nbd_negotiate_drop_sync(client->ioc, length) != length) {
269             return -EIO;
270         }
271         return nbd_negotiate_send_rep(client->ioc,
272                                       NBD_REP_ERR_INVALID, NBD_OPT_LIST);
273     }
274 
275     /* For each export, send a NBD_REP_SERVER reply. */
276     QTAILQ_FOREACH(exp, &exports, next) {
277         if (nbd_negotiate_send_rep_list(client->ioc, exp)) {
278             return -EINVAL;
279         }
280     }
281     /* Finish with a NBD_REP_ACK. */
282     return nbd_negotiate_send_rep(client->ioc, NBD_REP_ACK, NBD_OPT_LIST);
283 }
284 
285 static int nbd_negotiate_handle_export_name(NBDClient *client, uint32_t length)
286 {
287     int rc = -EINVAL;
288     char name[256];
289 
290     /* Client sends:
291         [20 ..  xx]   export name (length bytes)
292      */
293     TRACE("Checking length");
294     if (length > 255) {
295         LOG("Bad length received");
296         goto fail;
297     }
298     if (nbd_negotiate_read(client->ioc, name, length) != length) {
299         LOG("read failed");
300         goto fail;
301     }
302     name[length] = '\0';
303 
304     TRACE("Client requested export '%s'", name);
305 
306     client->exp = nbd_export_find(name);
307     if (!client->exp) {
308         LOG("export not found");
309         goto fail;
310     }
311 
312     QTAILQ_INSERT_TAIL(&client->exp->clients, client, next);
313     nbd_export_get(client->exp);
314     rc = 0;
315 fail:
316     return rc;
317 }
318 
319 
320 static QIOChannel *nbd_negotiate_handle_starttls(NBDClient *client,
321                                                  uint32_t length)
322 {
323     QIOChannel *ioc;
324     QIOChannelTLS *tioc;
325     struct NBDTLSHandshakeData data = { 0 };
326 
327     TRACE("Setting up TLS");
328     ioc = client->ioc;
329     if (length) {
330         if (nbd_negotiate_drop_sync(ioc, length) != length) {
331             return NULL;
332         }
333         nbd_negotiate_send_rep(ioc, NBD_REP_ERR_INVALID, NBD_OPT_STARTTLS);
334         return NULL;
335     }
336 
337     nbd_negotiate_send_rep(client->ioc, NBD_REP_ACK, NBD_OPT_STARTTLS);
338 
339     tioc = qio_channel_tls_new_server(ioc,
340                                       client->tlscreds,
341                                       client->tlsaclname,
342                                       NULL);
343     if (!tioc) {
344         return NULL;
345     }
346 
347     TRACE("Starting TLS handshake");
348     data.loop = g_main_loop_new(g_main_context_default(), FALSE);
349     qio_channel_tls_handshake(tioc,
350                               nbd_tls_handshake,
351                               &data,
352                               NULL);
353 
354     if (!data.complete) {
355         g_main_loop_run(data.loop);
356     }
357     g_main_loop_unref(data.loop);
358     if (data.error) {
359         object_unref(OBJECT(tioc));
360         error_free(data.error);
361         return NULL;
362     }
363 
364     return QIO_CHANNEL(tioc);
365 }
366 
367 
368 static int nbd_negotiate_options(NBDClient *client)
369 {
370     uint32_t flags;
371     bool fixedNewstyle = false;
372 
373     /* Client sends:
374         [ 0 ..   3]   client flags
375 
376         [ 0 ..   7]   NBD_OPTS_MAGIC
377         [ 8 ..  11]   NBD option
378         [12 ..  15]   Data length
379         ...           Rest of request
380 
381         [ 0 ..   7]   NBD_OPTS_MAGIC
382         [ 8 ..  11]   Second NBD option
383         [12 ..  15]   Data length
384         ...           Rest of request
385     */
386 
387     if (nbd_negotiate_read(client->ioc, &flags, sizeof(flags)) !=
388         sizeof(flags)) {
389         LOG("read failed");
390         return -EIO;
391     }
392     TRACE("Checking client flags");
393     be32_to_cpus(&flags);
394     if (flags & NBD_FLAG_C_FIXED_NEWSTYLE) {
395         TRACE("Support supports fixed newstyle handshake");
396         fixedNewstyle = true;
397         flags &= ~NBD_FLAG_C_FIXED_NEWSTYLE;
398     }
399     if (flags != 0) {
400         TRACE("Unknown client flags 0x%x received", flags);
401         return -EIO;
402     }
403 
404     while (1) {
405         int ret;
406         uint32_t clientflags, length;
407         uint64_t magic;
408 
409         if (nbd_negotiate_read(client->ioc, &magic, sizeof(magic)) !=
410             sizeof(magic)) {
411             LOG("read failed");
412             return -EINVAL;
413         }
414         TRACE("Checking opts magic");
415         if (magic != be64_to_cpu(NBD_OPTS_MAGIC)) {
416             LOG("Bad magic received");
417             return -EINVAL;
418         }
419 
420         if (nbd_negotiate_read(client->ioc, &clientflags,
421                                sizeof(clientflags)) != sizeof(clientflags)) {
422             LOG("read failed");
423             return -EINVAL;
424         }
425         clientflags = be32_to_cpu(clientflags);
426 
427         if (nbd_negotiate_read(client->ioc, &length, sizeof(length)) !=
428             sizeof(length)) {
429             LOG("read failed");
430             return -EINVAL;
431         }
432         length = be32_to_cpu(length);
433 
434         TRACE("Checking option 0x%x", clientflags);
435         if (client->tlscreds &&
436             client->ioc == (QIOChannel *)client->sioc) {
437             QIOChannel *tioc;
438             if (!fixedNewstyle) {
439                 TRACE("Unsupported option 0x%x", clientflags);
440                 return -EINVAL;
441             }
442             switch (clientflags) {
443             case NBD_OPT_STARTTLS:
444                 tioc = nbd_negotiate_handle_starttls(client, length);
445                 if (!tioc) {
446                     return -EIO;
447                 }
448                 object_unref(OBJECT(client->ioc));
449                 client->ioc = QIO_CHANNEL(tioc);
450                 break;
451 
452             default:
453                 TRACE("Option 0x%x not permitted before TLS", clientflags);
454                 nbd_negotiate_send_rep(client->ioc, NBD_REP_ERR_TLS_REQD,
455                                        clientflags);
456                 return -EINVAL;
457             }
458         } else if (fixedNewstyle) {
459             switch (clientflags) {
460             case NBD_OPT_LIST:
461                 ret = nbd_negotiate_handle_list(client, length);
462                 if (ret < 0) {
463                     return ret;
464                 }
465                 break;
466 
467             case NBD_OPT_ABORT:
468                 return -EINVAL;
469 
470             case NBD_OPT_EXPORT_NAME:
471                 return nbd_negotiate_handle_export_name(client, length);
472 
473             case NBD_OPT_STARTTLS:
474                 if (client->tlscreds) {
475                     TRACE("TLS already enabled");
476                     nbd_negotiate_send_rep(client->ioc, NBD_REP_ERR_INVALID,
477                                            clientflags);
478                 } else {
479                     TRACE("TLS not configured");
480                     nbd_negotiate_send_rep(client->ioc, NBD_REP_ERR_POLICY,
481                                            clientflags);
482                 }
483                 return -EINVAL;
484             default:
485                 TRACE("Unsupported option 0x%x", clientflags);
486                 if (nbd_negotiate_drop_sync(client->ioc, length) != length) {
487                     return -EIO;
488                 }
489                 nbd_negotiate_send_rep(client->ioc, NBD_REP_ERR_UNSUP,
490                                        clientflags);
491                 break;
492             }
493         } else {
494             /*
495              * If broken new-style we should drop the connection
496              * for anything except NBD_OPT_EXPORT_NAME
497              */
498             switch (clientflags) {
499             case NBD_OPT_EXPORT_NAME:
500                 return nbd_negotiate_handle_export_name(client, length);
501 
502             default:
503                 TRACE("Unsupported option 0x%x", clientflags);
504                 return -EINVAL;
505             }
506         }
507     }
508 }
509 
510 typedef struct {
511     NBDClient *client;
512     Coroutine *co;
513 } NBDClientNewData;
514 
515 static coroutine_fn int nbd_negotiate(NBDClientNewData *data)
516 {
517     NBDClient *client = data->client;
518     char buf[8 + 8 + 8 + 128];
519     int rc;
520     const int myflags = (NBD_FLAG_HAS_FLAGS | NBD_FLAG_SEND_TRIM |
521                          NBD_FLAG_SEND_FLUSH | NBD_FLAG_SEND_FUA);
522     bool oldStyle;
523 
524     /* Old style negotiation header without options
525         [ 0 ..   7]   passwd       ("NBDMAGIC")
526         [ 8 ..  15]   magic        (NBD_CLIENT_MAGIC)
527         [16 ..  23]   size
528         [24 ..  25]   server flags (0)
529         [26 ..  27]   export flags
530         [28 .. 151]   reserved     (0)
531 
532        New style negotiation header with options
533         [ 0 ..   7]   passwd       ("NBDMAGIC")
534         [ 8 ..  15]   magic        (NBD_OPTS_MAGIC)
535         [16 ..  17]   server flags (0)
536         ....options sent....
537         [18 ..  25]   size
538         [26 ..  27]   export flags
539         [28 .. 151]   reserved     (0)
540      */
541 
542     qio_channel_set_blocking(client->ioc, false, NULL);
543     rc = -EINVAL;
544 
545     TRACE("Beginning negotiation.");
546     memset(buf, 0, sizeof(buf));
547     memcpy(buf, "NBDMAGIC", 8);
548 
549     oldStyle = client->exp != NULL && !client->tlscreds;
550     if (oldStyle) {
551         assert ((client->exp->nbdflags & ~65535) == 0);
552         stq_be_p(buf + 8, NBD_CLIENT_MAGIC);
553         stq_be_p(buf + 16, client->exp->size);
554         stw_be_p(buf + 26, client->exp->nbdflags | myflags);
555     } else {
556         stq_be_p(buf + 8, NBD_OPTS_MAGIC);
557         stw_be_p(buf + 16, NBD_FLAG_FIXED_NEWSTYLE);
558     }
559 
560     if (oldStyle) {
561         if (client->tlscreds) {
562             TRACE("TLS cannot be enabled with oldstyle protocol");
563             goto fail;
564         }
565         if (nbd_negotiate_write(client->ioc, buf, sizeof(buf)) != sizeof(buf)) {
566             LOG("write failed");
567             goto fail;
568         }
569     } else {
570         if (nbd_negotiate_write(client->ioc, buf, 18) != 18) {
571             LOG("write failed");
572             goto fail;
573         }
574         rc = nbd_negotiate_options(client);
575         if (rc != 0) {
576             LOG("option negotiation failed");
577             goto fail;
578         }
579 
580         assert ((client->exp->nbdflags & ~65535) == 0);
581         stq_be_p(buf + 18, client->exp->size);
582         stw_be_p(buf + 26, client->exp->nbdflags | myflags);
583         if (nbd_negotiate_write(client->ioc, buf + 18, sizeof(buf) - 18) !=
584             sizeof(buf) - 18) {
585             LOG("write failed");
586             goto fail;
587         }
588     }
589 
590     TRACE("Negotiation succeeded.");
591     rc = 0;
592 fail:
593     return rc;
594 }
595 
596 #ifdef __linux__
597 
598 int nbd_disconnect(int fd)
599 {
600     ioctl(fd, NBD_CLEAR_QUE);
601     ioctl(fd, NBD_DISCONNECT);
602     ioctl(fd, NBD_CLEAR_SOCK);
603     return 0;
604 }
605 
606 #else
607 
608 int nbd_disconnect(int fd)
609 {
610     return -ENOTSUP;
611 }
612 #endif
613 
614 static ssize_t nbd_receive_request(QIOChannel *ioc, struct nbd_request *request)
615 {
616     uint8_t buf[NBD_REQUEST_SIZE];
617     uint32_t magic;
618     ssize_t ret;
619 
620     ret = read_sync(ioc, buf, sizeof(buf));
621     if (ret < 0) {
622         return ret;
623     }
624 
625     if (ret != sizeof(buf)) {
626         LOG("read failed");
627         return -EINVAL;
628     }
629 
630     /* Request
631        [ 0 ..  3]   magic   (NBD_REQUEST_MAGIC)
632        [ 4 ..  7]   type    (0 == READ, 1 == WRITE)
633        [ 8 .. 15]   handle
634        [16 .. 23]   from
635        [24 .. 27]   len
636      */
637 
638     magic = be32_to_cpup((uint32_t*)buf);
639     request->type  = be32_to_cpup((uint32_t*)(buf + 4));
640     request->handle = be64_to_cpup((uint64_t*)(buf + 8));
641     request->from  = be64_to_cpup((uint64_t*)(buf + 16));
642     request->len   = be32_to_cpup((uint32_t*)(buf + 24));
643 
644     TRACE("Got request: "
645           "{ magic = 0x%x, .type = %d, from = %" PRIu64" , len = %u }",
646           magic, request->type, request->from, request->len);
647 
648     if (magic != NBD_REQUEST_MAGIC) {
649         LOG("invalid magic (got 0x%x)", magic);
650         return -EINVAL;
651     }
652     return 0;
653 }
654 
655 static ssize_t nbd_send_reply(QIOChannel *ioc, struct nbd_reply *reply)
656 {
657     uint8_t buf[NBD_REPLY_SIZE];
658     ssize_t ret;
659 
660     reply->error = system_errno_to_nbd_errno(reply->error);
661 
662     TRACE("Sending response to client: { .error = %d, handle = %" PRIu64 " }",
663           reply->error, reply->handle);
664 
665     /* Reply
666        [ 0 ..  3]    magic   (NBD_REPLY_MAGIC)
667        [ 4 ..  7]    error   (0 == no error)
668        [ 7 .. 15]    handle
669      */
670     stl_be_p(buf, NBD_REPLY_MAGIC);
671     stl_be_p(buf + 4, reply->error);
672     stq_be_p(buf + 8, reply->handle);
673 
674     ret = write_sync(ioc, buf, sizeof(buf));
675     if (ret < 0) {
676         return ret;
677     }
678 
679     if (ret != sizeof(buf)) {
680         LOG("writing to socket failed");
681         return -EINVAL;
682     }
683     return 0;
684 }
685 
686 #define MAX_NBD_REQUESTS 16
687 
688 void nbd_client_get(NBDClient *client)
689 {
690     client->refcount++;
691 }
692 
693 void nbd_client_put(NBDClient *client)
694 {
695     if (--client->refcount == 0) {
696         /* The last reference should be dropped by client->close,
697          * which is called by client_close.
698          */
699         assert(client->closing);
700 
701         nbd_unset_handlers(client);
702         object_unref(OBJECT(client->sioc));
703         object_unref(OBJECT(client->ioc));
704         if (client->tlscreds) {
705             object_unref(OBJECT(client->tlscreds));
706         }
707         g_free(client->tlsaclname);
708         if (client->exp) {
709             QTAILQ_REMOVE(&client->exp->clients, client, next);
710             nbd_export_put(client->exp);
711         }
712         g_free(client);
713     }
714 }
715 
716 static void client_close(NBDClient *client)
717 {
718     if (client->closing) {
719         return;
720     }
721 
722     client->closing = true;
723 
724     /* Force requests to finish.  They will drop their own references,
725      * then we'll close the socket and free the NBDClient.
726      */
727     qio_channel_shutdown(client->ioc, QIO_CHANNEL_SHUTDOWN_BOTH,
728                          NULL);
729 
730     /* Also tell the client, so that they release their reference.  */
731     if (client->close) {
732         client->close(client);
733     }
734 }
735 
736 static NBDRequest *nbd_request_get(NBDClient *client)
737 {
738     NBDRequest *req;
739 
740     assert(client->nb_requests <= MAX_NBD_REQUESTS - 1);
741     client->nb_requests++;
742     nbd_update_can_read(client);
743 
744     req = g_new0(NBDRequest, 1);
745     nbd_client_get(client);
746     req->client = client;
747     return req;
748 }
749 
750 static void nbd_request_put(NBDRequest *req)
751 {
752     NBDClient *client = req->client;
753 
754     if (req->data) {
755         qemu_vfree(req->data);
756     }
757     g_free(req);
758 
759     client->nb_requests--;
760     nbd_update_can_read(client);
761     nbd_client_put(client);
762 }
763 
764 static void blk_aio_attached(AioContext *ctx, void *opaque)
765 {
766     NBDExport *exp = opaque;
767     NBDClient *client;
768 
769     TRACE("Export %s: Attaching clients to AIO context %p\n", exp->name, ctx);
770 
771     exp->ctx = ctx;
772 
773     QTAILQ_FOREACH(client, &exp->clients, next) {
774         nbd_set_handlers(client);
775     }
776 }
777 
778 static void blk_aio_detach(void *opaque)
779 {
780     NBDExport *exp = opaque;
781     NBDClient *client;
782 
783     TRACE("Export %s: Detaching clients from AIO context %p\n", exp->name, exp->ctx);
784 
785     QTAILQ_FOREACH(client, &exp->clients, next) {
786         nbd_unset_handlers(client);
787     }
788 
789     exp->ctx = NULL;
790 }
791 
792 static void nbd_eject_notifier(Notifier *n, void *data)
793 {
794     NBDExport *exp = container_of(n, NBDExport, eject_notifier);
795     nbd_export_close(exp);
796 }
797 
798 NBDExport *nbd_export_new(BlockBackend *blk, off_t dev_offset, off_t size,
799                           uint32_t nbdflags, void (*close)(NBDExport *),
800                           Error **errp)
801 {
802     NBDExport *exp = g_malloc0(sizeof(NBDExport));
803     exp->refcount = 1;
804     QTAILQ_INIT(&exp->clients);
805     exp->blk = blk;
806     exp->dev_offset = dev_offset;
807     exp->nbdflags = nbdflags;
808     exp->size = size < 0 ? blk_getlength(blk) : size;
809     if (exp->size < 0) {
810         error_setg_errno(errp, -exp->size,
811                          "Failed to determine the NBD export's length");
812         goto fail;
813     }
814     exp->size -= exp->size % BDRV_SECTOR_SIZE;
815 
816     exp->close = close;
817     exp->ctx = blk_get_aio_context(blk);
818     blk_ref(blk);
819     blk_add_aio_context_notifier(blk, blk_aio_attached, blk_aio_detach, exp);
820 
821     exp->eject_notifier.notify = nbd_eject_notifier;
822     blk_add_remove_bs_notifier(blk, &exp->eject_notifier);
823 
824     /*
825      * NBD exports are used for non-shared storage migration.  Make sure
826      * that BDRV_O_INACTIVE is cleared and the image is ready for write
827      * access since the export could be available before migration handover.
828      */
829     aio_context_acquire(exp->ctx);
830     blk_invalidate_cache(blk, NULL);
831     aio_context_release(exp->ctx);
832     return exp;
833 
834 fail:
835     g_free(exp);
836     return NULL;
837 }
838 
839 NBDExport *nbd_export_find(const char *name)
840 {
841     NBDExport *exp;
842     QTAILQ_FOREACH(exp, &exports, next) {
843         if (strcmp(name, exp->name) == 0) {
844             return exp;
845         }
846     }
847 
848     return NULL;
849 }
850 
851 void nbd_export_set_name(NBDExport *exp, const char *name)
852 {
853     if (exp->name == name) {
854         return;
855     }
856 
857     nbd_export_get(exp);
858     if (exp->name != NULL) {
859         g_free(exp->name);
860         exp->name = NULL;
861         QTAILQ_REMOVE(&exports, exp, next);
862         nbd_export_put(exp);
863     }
864     if (name != NULL) {
865         nbd_export_get(exp);
866         exp->name = g_strdup(name);
867         QTAILQ_INSERT_TAIL(&exports, exp, next);
868     }
869     nbd_export_put(exp);
870 }
871 
872 void nbd_export_close(NBDExport *exp)
873 {
874     NBDClient *client, *next;
875 
876     nbd_export_get(exp);
877     QTAILQ_FOREACH_SAFE(client, &exp->clients, next, next) {
878         client_close(client);
879     }
880     nbd_export_set_name(exp, NULL);
881     nbd_export_put(exp);
882 }
883 
884 void nbd_export_get(NBDExport *exp)
885 {
886     assert(exp->refcount > 0);
887     exp->refcount++;
888 }
889 
890 void nbd_export_put(NBDExport *exp)
891 {
892     assert(exp->refcount > 0);
893     if (exp->refcount == 1) {
894         nbd_export_close(exp);
895     }
896 
897     if (--exp->refcount == 0) {
898         assert(exp->name == NULL);
899 
900         if (exp->close) {
901             exp->close(exp);
902         }
903 
904         if (exp->blk) {
905             notifier_remove(&exp->eject_notifier);
906             blk_remove_aio_context_notifier(exp->blk, blk_aio_attached,
907                                             blk_aio_detach, exp);
908             blk_unref(exp->blk);
909             exp->blk = NULL;
910         }
911 
912         g_free(exp);
913     }
914 }
915 
916 BlockBackend *nbd_export_get_blockdev(NBDExport *exp)
917 {
918     return exp->blk;
919 }
920 
921 void nbd_export_close_all(void)
922 {
923     NBDExport *exp, *next;
924 
925     QTAILQ_FOREACH_SAFE(exp, &exports, next, next) {
926         nbd_export_close(exp);
927     }
928 }
929 
930 static ssize_t nbd_co_send_reply(NBDRequest *req, struct nbd_reply *reply,
931                                  int len)
932 {
933     NBDClient *client = req->client;
934     ssize_t rc, ret;
935 
936     g_assert(qemu_in_coroutine());
937     qemu_co_mutex_lock(&client->send_lock);
938     client->send_coroutine = qemu_coroutine_self();
939     nbd_set_handlers(client);
940 
941     if (!len) {
942         rc = nbd_send_reply(client->ioc, reply);
943     } else {
944         qio_channel_set_cork(client->ioc, true);
945         rc = nbd_send_reply(client->ioc, reply);
946         if (rc >= 0) {
947             ret = write_sync(client->ioc, req->data, len);
948             if (ret != len) {
949                 rc = -EIO;
950             }
951         }
952         qio_channel_set_cork(client->ioc, false);
953     }
954 
955     client->send_coroutine = NULL;
956     nbd_set_handlers(client);
957     qemu_co_mutex_unlock(&client->send_lock);
958     return rc;
959 }
960 
961 static ssize_t nbd_co_receive_request(NBDRequest *req, struct nbd_request *request)
962 {
963     NBDClient *client = req->client;
964     uint32_t command;
965     ssize_t rc;
966 
967     g_assert(qemu_in_coroutine());
968     client->recv_coroutine = qemu_coroutine_self();
969     nbd_update_can_read(client);
970 
971     rc = nbd_receive_request(client->ioc, request);
972     if (rc < 0) {
973         if (rc != -EAGAIN) {
974             rc = -EIO;
975         }
976         goto out;
977     }
978 
979     if ((request->from + request->len) < request->from) {
980         LOG("integer overflow detected! "
981             "you're probably being attacked");
982         rc = -EINVAL;
983         goto out;
984     }
985 
986     TRACE("Decoding type");
987 
988     command = request->type & NBD_CMD_MASK_COMMAND;
989     if (command == NBD_CMD_READ || command == NBD_CMD_WRITE) {
990         if (request->len > NBD_MAX_BUFFER_SIZE) {
991             LOG("len (%u) is larger than max len (%u)",
992                 request->len, NBD_MAX_BUFFER_SIZE);
993             rc = -EINVAL;
994             goto out;
995         }
996 
997         req->data = blk_try_blockalign(client->exp->blk, request->len);
998         if (req->data == NULL) {
999             rc = -ENOMEM;
1000             goto out;
1001         }
1002     }
1003     if (command == NBD_CMD_WRITE) {
1004         TRACE("Reading %u byte(s)", request->len);
1005 
1006         if (read_sync(client->ioc, req->data, request->len) != request->len) {
1007             LOG("reading from socket failed");
1008             rc = -EIO;
1009             goto out;
1010         }
1011     }
1012     rc = 0;
1013 
1014 out:
1015     client->recv_coroutine = NULL;
1016     nbd_update_can_read(client);
1017 
1018     return rc;
1019 }
1020 
1021 static void nbd_trip(void *opaque)
1022 {
1023     NBDClient *client = opaque;
1024     NBDExport *exp = client->exp;
1025     NBDRequest *req;
1026     struct nbd_request request;
1027     struct nbd_reply reply;
1028     ssize_t ret;
1029     uint32_t command;
1030 
1031     TRACE("Reading request.");
1032     if (client->closing) {
1033         return;
1034     }
1035 
1036     req = nbd_request_get(client);
1037     ret = nbd_co_receive_request(req, &request);
1038     if (ret == -EAGAIN) {
1039         goto done;
1040     }
1041     if (ret == -EIO) {
1042         goto out;
1043     }
1044 
1045     reply.handle = request.handle;
1046     reply.error = 0;
1047 
1048     if (ret < 0) {
1049         reply.error = -ret;
1050         goto error_reply;
1051     }
1052     command = request.type & NBD_CMD_MASK_COMMAND;
1053     if (command != NBD_CMD_DISC && (request.from + request.len) > exp->size) {
1054             LOG("From: %" PRIu64 ", Len: %u, Size: %" PRIu64
1055             ", Offset: %" PRIu64 "\n",
1056                     request.from, request.len,
1057                     (uint64_t)exp->size, (uint64_t)exp->dev_offset);
1058         LOG("requested operation past EOF--bad client?");
1059         goto invalid_request;
1060     }
1061 
1062     if (client->closing) {
1063         /*
1064          * The client may be closed when we are blocked in
1065          * nbd_co_receive_request()
1066          */
1067         goto done;
1068     }
1069 
1070     switch (command) {
1071     case NBD_CMD_READ:
1072         TRACE("Request type is READ");
1073 
1074         if (request.type & NBD_CMD_FLAG_FUA) {
1075             ret = blk_co_flush(exp->blk);
1076             if (ret < 0) {
1077                 LOG("flush failed");
1078                 reply.error = -ret;
1079                 goto error_reply;
1080             }
1081         }
1082 
1083         ret = blk_read(exp->blk,
1084                        (request.from + exp->dev_offset) / BDRV_SECTOR_SIZE,
1085                        req->data, request.len / BDRV_SECTOR_SIZE);
1086         if (ret < 0) {
1087             LOG("reading from file failed");
1088             reply.error = -ret;
1089             goto error_reply;
1090         }
1091 
1092         TRACE("Read %u byte(s)", request.len);
1093         if (nbd_co_send_reply(req, &reply, request.len) < 0)
1094             goto out;
1095         break;
1096     case NBD_CMD_WRITE:
1097         TRACE("Request type is WRITE");
1098 
1099         if (exp->nbdflags & NBD_FLAG_READ_ONLY) {
1100             TRACE("Server is read-only, return error");
1101             reply.error = EROFS;
1102             goto error_reply;
1103         }
1104 
1105         TRACE("Writing to device");
1106 
1107         ret = blk_write(exp->blk,
1108                         (request.from + exp->dev_offset) / BDRV_SECTOR_SIZE,
1109                         req->data, request.len / BDRV_SECTOR_SIZE);
1110         if (ret < 0) {
1111             LOG("writing to file failed");
1112             reply.error = -ret;
1113             goto error_reply;
1114         }
1115 
1116         if (request.type & NBD_CMD_FLAG_FUA) {
1117             ret = blk_co_flush(exp->blk);
1118             if (ret < 0) {
1119                 LOG("flush failed");
1120                 reply.error = -ret;
1121                 goto error_reply;
1122             }
1123         }
1124 
1125         if (nbd_co_send_reply(req, &reply, 0) < 0) {
1126             goto out;
1127         }
1128         break;
1129     case NBD_CMD_DISC:
1130         TRACE("Request type is DISCONNECT");
1131         errno = 0;
1132         goto out;
1133     case NBD_CMD_FLUSH:
1134         TRACE("Request type is FLUSH");
1135 
1136         ret = blk_co_flush(exp->blk);
1137         if (ret < 0) {
1138             LOG("flush failed");
1139             reply.error = -ret;
1140         }
1141         if (nbd_co_send_reply(req, &reply, 0) < 0) {
1142             goto out;
1143         }
1144         break;
1145     case NBD_CMD_TRIM:
1146         TRACE("Request type is TRIM");
1147         ret = blk_co_discard(exp->blk, (request.from + exp->dev_offset)
1148                                        / BDRV_SECTOR_SIZE,
1149                              request.len / BDRV_SECTOR_SIZE);
1150         if (ret < 0) {
1151             LOG("discard failed");
1152             reply.error = -ret;
1153         }
1154         if (nbd_co_send_reply(req, &reply, 0) < 0) {
1155             goto out;
1156         }
1157         break;
1158     default:
1159         LOG("invalid request type (%u) received", request.type);
1160     invalid_request:
1161         reply.error = EINVAL;
1162     error_reply:
1163         if (nbd_co_send_reply(req, &reply, 0) < 0) {
1164             goto out;
1165         }
1166         break;
1167     }
1168 
1169     TRACE("Request/Reply complete");
1170 
1171 done:
1172     nbd_request_put(req);
1173     return;
1174 
1175 out:
1176     nbd_request_put(req);
1177     client_close(client);
1178 }
1179 
1180 static void nbd_read(void *opaque)
1181 {
1182     NBDClient *client = opaque;
1183 
1184     if (client->recv_coroutine) {
1185         qemu_coroutine_enter(client->recv_coroutine, NULL);
1186     } else {
1187         qemu_coroutine_enter(qemu_coroutine_create(nbd_trip), client);
1188     }
1189 }
1190 
1191 static void nbd_restart_write(void *opaque)
1192 {
1193     NBDClient *client = opaque;
1194 
1195     qemu_coroutine_enter(client->send_coroutine, NULL);
1196 }
1197 
1198 static void nbd_set_handlers(NBDClient *client)
1199 {
1200     if (client->exp && client->exp->ctx) {
1201         aio_set_fd_handler(client->exp->ctx, client->sioc->fd,
1202                            true,
1203                            client->can_read ? nbd_read : NULL,
1204                            client->send_coroutine ? nbd_restart_write : NULL,
1205                            client);
1206     }
1207 }
1208 
1209 static void nbd_unset_handlers(NBDClient *client)
1210 {
1211     if (client->exp && client->exp->ctx) {
1212         aio_set_fd_handler(client->exp->ctx, client->sioc->fd,
1213                            true, NULL, NULL, NULL);
1214     }
1215 }
1216 
1217 static void nbd_update_can_read(NBDClient *client)
1218 {
1219     bool can_read = client->recv_coroutine ||
1220                     client->nb_requests < MAX_NBD_REQUESTS;
1221 
1222     if (can_read != client->can_read) {
1223         client->can_read = can_read;
1224         nbd_set_handlers(client);
1225 
1226         /* There is no need to invoke aio_notify(), since aio_set_fd_handler()
1227          * in nbd_set_handlers() will have taken care of that */
1228     }
1229 }
1230 
1231 static coroutine_fn void nbd_co_client_start(void *opaque)
1232 {
1233     NBDClientNewData *data = opaque;
1234     NBDClient *client = data->client;
1235     NBDExport *exp = client->exp;
1236 
1237     if (exp) {
1238         nbd_export_get(exp);
1239     }
1240     if (nbd_negotiate(data)) {
1241         client_close(client);
1242         goto out;
1243     }
1244     qemu_co_mutex_init(&client->send_lock);
1245     nbd_set_handlers(client);
1246 
1247     if (exp) {
1248         QTAILQ_INSERT_TAIL(&exp->clients, client, next);
1249     }
1250 out:
1251     g_free(data);
1252 }
1253 
1254 void nbd_client_new(NBDExport *exp,
1255                     QIOChannelSocket *sioc,
1256                     QCryptoTLSCreds *tlscreds,
1257                     const char *tlsaclname,
1258                     void (*close_fn)(NBDClient *))
1259 {
1260     NBDClient *client;
1261     NBDClientNewData *data = g_new(NBDClientNewData, 1);
1262 
1263     client = g_malloc0(sizeof(NBDClient));
1264     client->refcount = 1;
1265     client->exp = exp;
1266     client->tlscreds = tlscreds;
1267     if (tlscreds) {
1268         object_ref(OBJECT(client->tlscreds));
1269     }
1270     client->tlsaclname = g_strdup(tlsaclname);
1271     client->sioc = sioc;
1272     object_ref(OBJECT(client->sioc));
1273     client->ioc = QIO_CHANNEL(sioc);
1274     object_ref(OBJECT(client->ioc));
1275     client->can_read = true;
1276     client->close = close_fn;
1277 
1278     data->client = client;
1279     data->co = qemu_coroutine_create(nbd_co_client_start);
1280     qemu_coroutine_enter(data->co, data);
1281 }
1282