/*
 * (c) 2017 Stefano Stabellini <stefano@aporeto.com>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 */

#include <linux/module.h>
#include <linux/net.h>
#include <linux/socket.h>

#include <net/sock.h>

#include <xen/events.h>
#include <xen/grant_table.h>
#include <xen/xen.h>
#include <xen/xenbus.h>
#include <xen/interface/io/pvcalls.h>

#include "pvcalls-front.h"

#define PVCALLS_INVALID_ID UINT_MAX
#define PVCALLS_RING_ORDER XENBUS_MAX_RING_GRANT_ORDER
#define PVCALLS_NR_RSP_PER_RING __CONST_RING_SIZE(xen_pvcalls, XEN_PAGE_SIZE)
#define PVCALLS_FRONT_MAX_SPIN 5000

struct pvcalls_bedata {
	struct xen_pvcalls_front_ring ring;
	grant_ref_t ref;
	int irq;

	struct list_head socket_mappings;
	spinlock_t socket_lock;

	wait_queue_head_t inflight_req;
	struct xen_pvcalls_response rsp[PVCALLS_NR_RSP_PER_RING];
};
/* Only one front/back connection supported. */
static struct xenbus_device *pvcalls_front_dev;
static atomic_t pvcalls_refcount;

/* first increment refcount, then proceed */
#define pvcalls_enter() {		\
	atomic_inc(&pvcalls_refcount);	\
}

/* first complete other operations, then decrement refcount */
#define pvcalls_exit() {		\
	atomic_dec(&pvcalls_refcount);	\
}

struct sock_mapping {
	bool active_socket;
	struct list_head list;
	struct socket *sock;
	union {
		struct {
			int irq;
			grant_ref_t ref;
			struct pvcalls_data_intf *ring;
			struct pvcalls_data data;
			struct mutex in_mutex;
			struct mutex out_mutex;

			wait_queue_head_t inflight_conn_req;
		} active;
		struct {
		/* Socket status */
#define PVCALLS_STATUS_UNINITALIZED  0
#define PVCALLS_STATUS_BIND          1
#define PVCALLS_STATUS_LISTEN        2
			uint8_t status;
		/*
		 * Internal state-machine flags.
		 * Only one accept operation can be inflight for a socket.
		 * Only one poll operation can be inflight for a given socket.
		 */
#define PVCALLS_FLAG_ACCEPT_INFLIGHT 0
#define PVCALLS_FLAG_POLL_INFLIGHT   1
#define PVCALLS_FLAG_POLL_RET        2
			uint8_t flags;
			uint32_t inflight_req_id;
			struct sock_mapping *accept_map;
			wait_queue_head_t inflight_accept_req;
		} passive;
	};
};

static inline int get_request(struct pvcalls_bedata *bedata, int *req_id)
{
	*req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1);
	if (RING_FULL(&bedata->ring) ||
	    bedata->rsp[*req_id].req_id != PVCALLS_INVALID_ID)
		return -EAGAIN;
	return 0;
}
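/*
 * A request slot doubles as an index into bedata->rsp[]: the response
 * for request N lands in rsp[N & (ring_size - 1)]. As an illustration
 * (the numbers assume the usual 4 KiB shared page, which gives a
 * 64-entry command ring):
 *
 *	req_prod_pvt == 70  ->  *req_id == 70 & 63 == 6
 *
 * and slot 6 is only handed out again once the waiter that owned it has
 * reset rsp[6].req_id to PVCALLS_INVALID_ID, hence the second check
 * next to RING_FULL().
 */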
static bool pvcalls_front_write_todo(struct sock_mapping *map)
{
	struct pvcalls_data_intf *intf = map->active.ring;
	RING_IDX cons, prod, size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
	int32_t error;

	error = intf->out_error;
	if (error == -ENOTCONN)
		return false;
	if (error != 0)
		return true;

	cons = intf->out_cons;
	prod = intf->out_prod;
	return !!(size - pvcalls_queued(prod, cons, size));
}

static bool pvcalls_front_read_todo(struct sock_mapping *map)
{
	struct pvcalls_data_intf *intf = map->active.ring;
	RING_IDX cons, prod;
	int32_t error;

	cons = intf->in_cons;
	prod = intf->in_prod;
	error = intf->in_error;
	return (error != 0 ||
		pvcalls_queued(prod, cons,
			       XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER)) != 0);
}

static irqreturn_t pvcalls_front_event_handler(int irq, void *dev_id)
{
	struct xenbus_device *dev = dev_id;
	struct pvcalls_bedata *bedata;
	struct xen_pvcalls_response *rsp;
	uint8_t *src, *dst;
	int req_id = 0, more = 0, done = 0;

	if (dev == NULL)
		return IRQ_HANDLED;

	pvcalls_enter();
	bedata = dev_get_drvdata(&dev->dev);
	if (bedata == NULL) {
		pvcalls_exit();
		return IRQ_HANDLED;
	}

again:
	while (RING_HAS_UNCONSUMED_RESPONSES(&bedata->ring)) {
		rsp = RING_GET_RESPONSE(&bedata->ring, bedata->ring.rsp_cons);

		req_id = rsp->req_id;
		if (rsp->cmd == PVCALLS_POLL) {
			struct sock_mapping *map = (struct sock_mapping *)(uintptr_t)
						   rsp->u.poll.id;

			clear_bit(PVCALLS_FLAG_POLL_INFLIGHT,
				  (void *)&map->passive.flags);
			/*
			 * clear INFLIGHT, then set RET. It pairs with
			 * the checks at the beginning of
			 * pvcalls_front_poll_passive.
			 */
			smp_wmb();
			set_bit(PVCALLS_FLAG_POLL_RET,
				(void *)&map->passive.flags);
		} else {
			dst = (uint8_t *)&bedata->rsp[req_id] +
			      sizeof(rsp->req_id);
			src = (uint8_t *)rsp + sizeof(rsp->req_id);
			memcpy(dst, src, sizeof(*rsp) - sizeof(rsp->req_id));
			/*
			 * First copy the rest of the data, then req_id. It is
			 * paired with the barrier when accessing bedata->rsp.
			 */
			smp_wmb();
			bedata->rsp[req_id].req_id = req_id;
		}

		done = 1;
		bedata->ring.rsp_cons++;
	}

	RING_FINAL_CHECK_FOR_RESPONSES(&bedata->ring, more);
	if (more)
		goto again;
	if (done)
		wake_up(&bedata->inflight_req);
	pvcalls_exit();
	return IRQ_HANDLED;
}
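/*
 * The smp_wmb() above pairs with the smp_rmb() that each waiter issues
 * after its wait_event() on bedata->inflight_req succeeds. Writing
 * req_id last publishes the copied response: a waiter that observes
 * rsp[req_id].req_id == req_id is therefore guaranteed to also see the
 * payload bytes the handler copied in before it, without taking any
 * lock on the response array.
 */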
static void pvcalls_front_free_map(struct pvcalls_bedata *bedata,
				   struct sock_mapping *map)
{
	int i;

	unbind_from_irqhandler(map->active.irq, map);

	spin_lock(&bedata->socket_lock);
	if (!list_empty(&map->list))
		list_del_init(&map->list);
	spin_unlock(&bedata->socket_lock);

	for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
		gnttab_end_foreign_access(map->active.ring->ref[i], 0, 0);
	gnttab_end_foreign_access(map->active.ref, 0, 0);
	free_page((unsigned long)map->active.ring);

	kfree(map);
}

static irqreturn_t pvcalls_front_conn_handler(int irq, void *sock_map)
{
	struct sock_mapping *map = sock_map;

	if (map == NULL)
		return IRQ_HANDLED;

	wake_up_interruptible(&map->active.inflight_conn_req);

	return IRQ_HANDLED;
}

int pvcalls_front_socket(struct socket *sock)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret;

	/*
	 * PVCalls only supports domain AF_INET,
	 * type SOCK_STREAM and protocol 0 sockets for now.
	 *
	 * Check socket type here, AF_INET and protocol checks are done
	 * by the caller.
	 */
	if (sock->type != SOCK_STREAM)
		return -EOPNOTSUPP;

	pvcalls_enter();
	if (!pvcalls_front_dev) {
		pvcalls_exit();
		return -EACCES;
	}
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	map = kzalloc(sizeof(*map), GFP_KERNEL);
	if (map == NULL) {
		pvcalls_exit();
		return -ENOMEM;
	}

	spin_lock(&bedata->socket_lock);

	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		kfree(map);
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit();
		return ret;
	}

	/*
	 * sock->sk->sk_send_head is not used for ip sockets: reuse the
	 * field to store a pointer to the struct sock_mapping
	 * corresponding to the socket. This way, we can easily get the
	 * struct sock_mapping from the struct socket.
	 */
	sock->sk->sk_send_head = (void *)map;
	list_add_tail(&map->list, &bedata->socket_mappings);

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_SOCKET;
	req->u.socket.id = (uintptr_t) map;
	req->u.socket.domain = AF_INET;
	req->u.socket.type = SOCK_STREAM;
	req->u.socket.protocol = IPPROTO_IP;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

	pvcalls_exit();
	return ret;
}

static int create_active(struct sock_mapping *map, int *evtchn)
{
	void *bytes;
	int ret = -ENOMEM, irq = -1, i;

	*evtchn = -1;
	init_waitqueue_head(&map->active.inflight_conn_req);

	map->active.ring = (struct pvcalls_data_intf *)
		__get_free_page(GFP_KERNEL | __GFP_ZERO);
	if (map->active.ring == NULL)
		goto out_error;
	map->active.ring->ring_order = PVCALLS_RING_ORDER;
	bytes = (void *)__get_free_pages(GFP_KERNEL | __GFP_ZERO,
					 PVCALLS_RING_ORDER);
	if (bytes == NULL)
		goto out_error;
	for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
		map->active.ring->ref[i] = gnttab_grant_foreign_access(
			pvcalls_front_dev->otherend_id,
			pfn_to_gfn(virt_to_pfn(bytes) + i), 0);

	map->active.ref = gnttab_grant_foreign_access(
		pvcalls_front_dev->otherend_id,
		pfn_to_gfn(virt_to_pfn((void *)map->active.ring)), 0);

	map->active.data.in = bytes;
	map->active.data.out = bytes +
		XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);

	ret = xenbus_alloc_evtchn(pvcalls_front_dev, evtchn);
	if (ret)
		goto out_error;
	irq = bind_evtchn_to_irqhandler(*evtchn, pvcalls_front_conn_handler,
					0, "pvcalls-frontend", map);
	if (irq < 0) {
		ret = irq;
		goto out_error;
	}

	map->active.irq = irq;
	map->active_socket = true;
	mutex_init(&map->active.in_mutex);
	mutex_init(&map->active.out_mutex);

	return 0;

out_error:
	if (*evtchn >= 0)
		xenbus_free_evtchn(pvcalls_front_dev, *evtchn);
	/*
	 * These buffers come from __get_free_page(s), so they must be
	 * returned with free_page(s), not kfree().
	 */
	free_pages((unsigned long)map->active.data.in, PVCALLS_RING_ORDER);
	free_page((unsigned long)map->active.ring);
	return ret;
}
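/*
 * Layout produced by create_active(), for reference: one page holds the
 * indices and the grant refs of the data pages (pvcalls_data_intf), and
 * the 2^PVCALLS_RING_ORDER data pages are split evenly between the two
 * directions. Assuming 4 KiB pages and the usual
 * XENBUS_MAX_RING_GRANT_ORDER of 4, that is 16 data pages: a 32 KiB
 * "in" buffer followed by a 32 KiB "out" buffer, since
 * XEN_FLEX_RING_SIZE(order) == 1UL << (order + XEN_PAGE_SHIFT - 1).
 */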
int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
			  int addr_len, int flags)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret, evtchn;

	if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
		return -EOPNOTSUPP;

	pvcalls_enter();
	if (!pvcalls_front_dev) {
		pvcalls_exit();
		return -ENOTCONN;
	}

	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	map = (struct sock_mapping *)sock->sk->sk_send_head;
	if (!map) {
		pvcalls_exit();
		return -ENOTSOCK;
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit();
		return ret;
	}
	ret = create_active(map, &evtchn);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit();
		return ret;
	}

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_CONNECT;
	req->u.connect.id = (uintptr_t)map;
	req->u.connect.len = addr_len;
	req->u.connect.flags = flags;
	req->u.connect.ref = map->active.ref;
	req->u.connect.evtchn = evtchn;
	memcpy(req->u.connect.addr, addr, sizeof(*addr));

	map->sock = sock;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);

	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
	pvcalls_exit();
	return ret;
}

static int __write_ring(struct pvcalls_data_intf *intf,
			struct pvcalls_data *data,
			struct iov_iter *msg_iter,
			int len)
{
	RING_IDX cons, prod, size, masked_prod, masked_cons;
	RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
	int32_t error;

	error = intf->out_error;
	if (error < 0)
		return error;
	cons = intf->out_cons;
	prod = intf->out_prod;
	/* read indexes before continuing */
	virt_mb();

	size = pvcalls_queued(prod, cons, array_size);
	if (size >= array_size)
		return -EINVAL;
	if (len > array_size - size)
		len = array_size - size;

	masked_prod = pvcalls_mask(prod, array_size);
	masked_cons = pvcalls_mask(cons, array_size);

	if (masked_prod < masked_cons) {
		len = copy_from_iter(data->out + masked_prod, len, msg_iter);
	} else {
		if (len > array_size - masked_prod) {
			int ret = copy_from_iter(data->out + masked_prod,
				       array_size - masked_prod, msg_iter);
			if (ret != array_size - masked_prod) {
				len = ret;
				goto out;
			}
			len = ret + copy_from_iter(data->out, len - ret,
						   msg_iter);
		} else {
			len = copy_from_iter(data->out + masked_prod, len,
					     msg_iter);
		}
	}
out:
	/* write to ring before updating pointer */
	virt_wmb();
	intf->out_prod += len;

	return len;
}
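/*
 * Worked example of the wraparound logic above (numbers are purely
 * illustrative): with array_size = 32768, prod = 33000 and cons = 32000,
 * pvcalls_queued() reports 1000 bytes in flight, masked_prod is
 * 33000 & 32767 == 232 and masked_cons is 32000. Since
 * masked_prod < masked_cons, the free area is the single contiguous
 * range [232, 32000) and one copy_from_iter() suffices. When
 * masked_prod >= masked_cons instead, the free space wraps at
 * array_size and the copy may have to be split in two.
 */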
int pvcalls_front_sendmsg(struct socket *sock, struct msghdr *msg,
			  size_t len)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	int sent, tot_sent = 0;
	int count = 0, flags;

	flags = msg->msg_flags;
	if (flags & (MSG_CONFIRM|MSG_DONTROUTE|MSG_EOR|MSG_OOB))
		return -EOPNOTSUPP;

	pvcalls_enter();
	if (!pvcalls_front_dev) {
		pvcalls_exit();
		return -ENOTCONN;
	}
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	map = (struct sock_mapping *) sock->sk->sk_send_head;
	if (!map) {
		pvcalls_exit();
		return -ENOTSOCK;
	}

	mutex_lock(&map->active.out_mutex);
	if ((flags & MSG_DONTWAIT) && !pvcalls_front_write_todo(map)) {
		mutex_unlock(&map->active.out_mutex);
		pvcalls_exit();
		return -EAGAIN;
	}
	if (len > INT_MAX)
		len = INT_MAX;

again:
	count++;
	sent = __write_ring(map->active.ring,
			    &map->active.data, &msg->msg_iter,
			    len);
	if (sent > 0) {
		len -= sent;
		tot_sent += sent;
		notify_remote_via_irq(map->active.irq);
	}
	if (sent >= 0 && len > 0 && count < PVCALLS_FRONT_MAX_SPIN)
		goto again;
	if (sent < 0)
		tot_sent = sent;

	mutex_unlock(&map->active.out_mutex);
	pvcalls_exit();
	return tot_sent;
}
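/*
 * Note on the "again" loop: rather than sleeping until the backend
 * consumes data, sendmsg retries the copy up to PVCALLS_FRONT_MAX_SPIN
 * (5000) times. The idea, presumably, is that the backend usually
 * drains the ring quickly, so a bounded busy retry is cheaper than a
 * sleep/wakeup cycle. A partial write is still possible once the limit
 * is hit; the short count is returned to the caller, which is
 * acceptable for SOCK_STREAM semantics.
 */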
static int __read_ring(struct pvcalls_data_intf *intf,
		       struct pvcalls_data *data,
		       struct iov_iter *msg_iter,
		       size_t len, int flags)
{
	RING_IDX cons, prod, size, masked_prod, masked_cons;
	RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
	int32_t error;

	cons = intf->in_cons;
	prod = intf->in_prod;
	error = intf->in_error;
	/* get pointers before reading from the ring */
	virt_rmb();
	if (error < 0)
		return error;

	size = pvcalls_queued(prod, cons, array_size);
	masked_prod = pvcalls_mask(prod, array_size);
	masked_cons = pvcalls_mask(cons, array_size);

	if (size == 0)
		return 0;

	if (len > size)
		len = size;

	if (masked_prod > masked_cons) {
		len = copy_to_iter(data->in + masked_cons, len, msg_iter);
	} else {
		if (len > (array_size - masked_cons)) {
			int ret = copy_to_iter(data->in + masked_cons,
				     array_size - masked_cons, msg_iter);
			if (ret != array_size - masked_cons) {
				len = ret;
				goto out;
			}
			len = ret + copy_to_iter(data->in, len - ret,
						 msg_iter);
		} else {
			len = copy_to_iter(data->in + masked_cons, len,
					   msg_iter);
		}
	}
out:
	/* read data from the ring before increasing the index */
	virt_mb();
	if (!(flags & MSG_PEEK))
		intf->in_cons += len;

	return len;
}
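/*
 * MSG_PEEK is implemented above simply by not advancing in_cons: the
 * bytes stay in the ring and the next read returns them again. A peeked
 * read only becomes visible to the backend, and only frees ring space,
 * once a subsequent non-peek read moves the consumer index.
 */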
int pvcalls_front_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
			  int flags)
{
	struct pvcalls_bedata *bedata;
	int ret;
	struct sock_mapping *map;

	if (flags & (MSG_CMSG_CLOEXEC|MSG_ERRQUEUE|MSG_OOB|MSG_TRUNC))
		return -EOPNOTSUPP;

	pvcalls_enter();
	if (!pvcalls_front_dev) {
		pvcalls_exit();
		return -ENOTCONN;
	}
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	map = (struct sock_mapping *) sock->sk->sk_send_head;
	if (!map) {
		pvcalls_exit();
		return -ENOTSOCK;
	}

	mutex_lock(&map->active.in_mutex);
	if (len > XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER))
		len = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);

	while (!(flags & MSG_DONTWAIT) && !pvcalls_front_read_todo(map)) {
		wait_event_interruptible(map->active.inflight_conn_req,
					 pvcalls_front_read_todo(map));
	}
	ret = __read_ring(map->active.ring, &map->active.data,
			  &msg->msg_iter, len, flags);

	if (ret > 0)
		notify_remote_via_irq(map->active.irq);
	if (ret == 0)
		ret = (flags & MSG_DONTWAIT) ? -EAGAIN : 0;
	if (ret == -ENOTCONN)
		ret = 0;

	mutex_unlock(&map->active.in_mutex);
	pvcalls_exit();
	return ret;
}

int pvcalls_front_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret;

	if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
		return -EOPNOTSUPP;

	pvcalls_enter();
	if (!pvcalls_front_dev) {
		pvcalls_exit();
		return -ENOTCONN;
	}
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	map = (struct sock_mapping *) sock->sk->sk_send_head;
	if (map == NULL) {
		pvcalls_exit();
		return -ENOTSOCK;
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit();
		return ret;
	}
	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	map->sock = sock;
	req->cmd = PVCALLS_BIND;
	req->u.bind.id = (uintptr_t)map;
	memcpy(req->u.bind.addr, addr, sizeof(*addr));
	req->u.bind.len = addr_len;

	init_waitqueue_head(&map->passive.inflight_accept_req);

	map->active_socket = false;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

	map->passive.status = PVCALLS_STATUS_BIND;
	pvcalls_exit();
	return 0;
}
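/*
 * Passive sockets follow a small local state machine:
 * PVCALLS_STATUS_UNINITALIZED, then PVCALLS_STATUS_BIND after a
 * successful bind, then PVCALLS_STATUS_LISTEN after listen. listen()
 * refuses sockets that are not in BIND state and accept() refuses
 * sockets that are not in LISTEN state, so the usual
 * bind-listen-accept ordering is enforced in the frontend without a
 * round trip to the backend.
 */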
int pvcalls_front_listen(struct socket *sock, int backlog)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret;

	pvcalls_enter();
	if (!pvcalls_front_dev) {
		pvcalls_exit();
		return -ENOTCONN;
	}
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	map = (struct sock_mapping *) sock->sk->sk_send_head;
	if (!map) {
		pvcalls_exit();
		return -ENOTSOCK;
	}

	if (map->passive.status != PVCALLS_STATUS_BIND) {
		pvcalls_exit();
		return -EOPNOTSUPP;
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit();
		return ret;
	}
	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_LISTEN;
	req->u.listen.id = (uintptr_t) map;
	req->u.listen.backlog = backlog;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

	map->passive.status = PVCALLS_STATUS_LISTEN;
	pvcalls_exit();
	return ret;
}

int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	struct sock_mapping *map2 = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret, evtchn, nonblock;

	pvcalls_enter();
	if (!pvcalls_front_dev) {
		pvcalls_exit();
		return -ENOTCONN;
	}
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	map = (struct sock_mapping *) sock->sk->sk_send_head;
	if (!map) {
		pvcalls_exit();
		return -ENOTSOCK;
	}

	if (map->passive.status != PVCALLS_STATUS_LISTEN) {
		pvcalls_exit();
		return -EINVAL;
	}

	nonblock = flags & SOCK_NONBLOCK;
	/*
	 * Backend only supports 1 inflight accept request, will return
	 * errors for the others
	 */
	if (test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			     (void *)&map->passive.flags)) {
		req_id = READ_ONCE(map->passive.inflight_req_id);
		if (req_id != PVCALLS_INVALID_ID &&
		    READ_ONCE(bedata->rsp[req_id].req_id) == req_id) {
			map2 = map->passive.accept_map;
			goto received;
		}
		if (nonblock) {
			pvcalls_exit();
			return -EAGAIN;
		}
		if (wait_event_interruptible(map->passive.inflight_accept_req,
			!test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
					  (void *)&map->passive.flags))) {
			pvcalls_exit();
			return -EINTR;
		}
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit();
		return ret;
	}
	map2 = kzalloc(sizeof(*map2), GFP_ATOMIC);
	if (map2 == NULL) {
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit();
		return -ENOMEM;
	}
	ret = create_active(map2, &evtchn);
	if (ret < 0) {
		kfree(map2);
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit();
		return ret;
	}
	list_add_tail(&map2->list, &bedata->socket_mappings);

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_ACCEPT;
	req->u.accept.id = (uintptr_t) map;
	req->u.accept.ref = map2->active.ref;
	req->u.accept.id_new = (uintptr_t) map2;
	req->u.accept.evtchn = evtchn;
	map->passive.accept_map = map2;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);
	/* We could check if we have received a response before returning. */
	if (nonblock) {
		WRITE_ONCE(map->passive.inflight_req_id, req_id);
		pvcalls_exit();
		return -EAGAIN;
	}

	if (wait_event_interruptible(bedata->inflight_req,
		READ_ONCE(bedata->rsp[req_id].req_id) == req_id)) {
		pvcalls_exit();
		return -EINTR;
	}
	/* read req_id, then the content */
	smp_rmb();

received:
	map2->sock = newsock;
	newsock->sk = kzalloc(sizeof(*newsock->sk), GFP_KERNEL);
	if (!newsock->sk) {
		bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
		map->passive.inflight_req_id = PVCALLS_INVALID_ID;
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		pvcalls_front_free_map(bedata, map2);
		pvcalls_exit();
		return -ENOMEM;
	}
	newsock->sk->sk_send_head = (void *)map2;

	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
	map->passive.inflight_req_id = PVCALLS_INVALID_ID;

	clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags);
	wake_up(&map->passive.inflight_accept_req);

	pvcalls_exit();
	return ret;
}
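/*
 * A nonblocking accept that returns -EAGAIN above leaves the request
 * inflight and records its id in map->passive.inflight_req_id. The next
 * accept call on the same socket finds ACCEPT_INFLIGHT already set,
 * reads the saved req_id, and if the backend has answered in the
 * meantime jumps straight to "received" to finish building the new
 * socket, so no request is ever issued twice for one connection.
 */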
static __poll_t pvcalls_front_poll_passive(struct file *file,
					   struct pvcalls_bedata *bedata,
					   struct sock_mapping *map,
					   poll_table *wait)
{
	int notify, req_id, ret;
	struct xen_pvcalls_request *req;

	if (test_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
		     (void *)&map->passive.flags)) {
		uint32_t req_id = READ_ONCE(map->passive.inflight_req_id);

		if (req_id != PVCALLS_INVALID_ID &&
		    READ_ONCE(bedata->rsp[req_id].req_id) == req_id)
			return EPOLLIN | EPOLLRDNORM;

		poll_wait(file, &map->passive.inflight_accept_req, wait);
		return 0;
	}

	if (test_and_clear_bit(PVCALLS_FLAG_POLL_RET,
			       (void *)&map->passive.flags))
		return EPOLLIN | EPOLLRDNORM;

	/*
	 * First check RET, then INFLIGHT. No barriers necessary to
	 * ensure execution ordering because of the conditional
	 * instructions creating control dependencies.
	 */

	if (test_and_set_bit(PVCALLS_FLAG_POLL_INFLIGHT,
			     (void *)&map->passive.flags)) {
		poll_wait(file, &bedata->inflight_req, wait);
		return 0;
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		return ret;
	}
	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_POLL;
	req->u.poll.id = (uintptr_t) map;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	poll_wait(file, &bedata->inflight_req, wait);
	return 0;
}
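/*
 * The POLL_INFLIGHT/POLL_RET pair forms a small handshake with
 * pvcalls_front_event_handler(): the poller sets INFLIGHT before
 * issuing PVCALLS_POLL, and the handler clears INFLIGHT and then sets
 * RET (with an smp_wmb() in between) when the response arrives. A later
 * poll call therefore either consumes RET and reports EPOLLIN, or finds
 * INFLIGHT still set and just sleeps on the waitqueue, so at most one
 * poll request per passive socket is ever outstanding.
 */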
static __poll_t pvcalls_front_poll_active(struct file *file,
					  struct pvcalls_bedata *bedata,
					  struct sock_mapping *map,
					  poll_table *wait)
{
	__poll_t mask = 0;
	int32_t in_error, out_error;
	struct pvcalls_data_intf *intf = map->active.ring;

	out_error = intf->out_error;
	in_error = intf->in_error;

	poll_wait(file, &map->active.inflight_conn_req, wait);
	if (pvcalls_front_write_todo(map))
		mask |= EPOLLOUT | EPOLLWRNORM;
	if (pvcalls_front_read_todo(map))
		mask |= EPOLLIN | EPOLLRDNORM;
	if (in_error != 0 || out_error != 0)
		mask |= EPOLLERR;

	return mask;
}

__poll_t pvcalls_front_poll(struct file *file, struct socket *sock,
			    poll_table *wait)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	__poll_t ret;

	pvcalls_enter();
	if (!pvcalls_front_dev) {
		pvcalls_exit();
		return EPOLLNVAL;
	}
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	map = (struct sock_mapping *) sock->sk->sk_send_head;
	if (!map) {
		pvcalls_exit();
		return EPOLLNVAL;
	}
	if (map->active_socket)
		ret = pvcalls_front_poll_active(file, bedata, map, wait);
	else
		ret = pvcalls_front_poll_passive(file, bedata, map, wait);
	pvcalls_exit();
	return ret;
}

int pvcalls_front_release(struct socket *sock)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	int req_id, notify, ret;
	struct xen_pvcalls_request *req;

	if (sock->sk == NULL)
		return 0;

	pvcalls_enter();
	if (!pvcalls_front_dev) {
		pvcalls_exit();
		return -EIO;
	}

	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	map = (struct sock_mapping *) sock->sk->sk_send_head;
	if (map == NULL) {
		pvcalls_exit();
		return 0;
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit();
		return ret;
	}
	sock->sk->sk_send_head = NULL;

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_RELEASE;
	req->u.release.id = (uintptr_t)map;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	if (map->active_socket) {
		/*
		 * Set in_error and wake up inflight_conn_req to force
		 * recvmsg waiters to exit.
		 */
		map->active.ring->in_error = -EBADF;
		wake_up_interruptible(&map->active.inflight_conn_req);

		/*
		 * We need to make sure that sendmsg/recvmsg on this socket
		 * have not started before we've cleared sk_send_head here.
		 * The easiest (though not optimal) way to guarantee this is
		 * to see that no pvcall (other than us) is in progress.
		 */
		while (atomic_read(&pvcalls_refcount) > 1)
			cpu_relax();

		pvcalls_front_free_map(bedata, map);
	} else {
		spin_lock(&bedata->socket_lock);
		list_del(&map->list);
		spin_unlock(&bedata->socket_lock);
		if (READ_ONCE(map->passive.inflight_req_id) !=
		    PVCALLS_INVALID_ID) {
			pvcalls_front_free_map(bedata,
					       map->passive.accept_map);
		}
		kfree(map);
	}
	WRITE_ONCE(bedata->rsp[req_id].req_id, PVCALLS_INVALID_ID);

	pvcalls_exit();
	return 0;
}
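/*
 * The refcount spin above is the point of the pvcalls_enter()/
 * pvcalls_exit() pair: every entry point holds a reference on
 * pvcalls_refcount for the duration of the call. Once release has
 * cleared sk_send_head, waiting for the count to drop to 1 (just this
 * caller) ensures no sendmsg/recvmsg that grabbed the old mapping is
 * still running when the rings and grants are torn down in
 * pvcalls_front_free_map(); any later call sees the NULL sk_send_head
 * and fails with -ENOTSOCK.
 */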
static const struct xenbus_device_id pvcalls_front_ids[] = {
	{ "pvcalls" },
	{ "" }
};

static int pvcalls_front_remove(struct xenbus_device *dev)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL, *n;

	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
	dev_set_drvdata(&dev->dev, NULL);
	pvcalls_front_dev = NULL;
	if (bedata->irq >= 0)
		unbind_from_irqhandler(bedata->irq, dev);

	list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
		map->sock->sk->sk_send_head = NULL;
		if (map->active_socket) {
			map->active.ring->in_error = -EBADF;
			wake_up_interruptible(&map->active.inflight_conn_req);
		}
	}

	smp_mb();
	while (atomic_read(&pvcalls_refcount) > 0)
		cpu_relax();
	list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
		if (map->active_socket) {
			/* No need to lock, refcount is 0 */
			pvcalls_front_free_map(bedata, map);
		} else {
			list_del(&map->list);
			kfree(map);
		}
	}
	if (bedata->ref != -1)
		gnttab_end_foreign_access(bedata->ref, 0, 0);
	/* The shared ring came from __get_free_page(), not kmalloc(). */
	free_page((unsigned long)bedata->ring.sring);
	kfree(bedata);
	xenbus_switch_state(dev, XenbusStateClosed);
	return 0;
}

static int pvcalls_front_probe(struct xenbus_device *dev,
			       const struct xenbus_device_id *id)
{
	int ret = -ENOMEM, evtchn, i;
	unsigned int max_page_order, function_calls, len;
	char *versions;
	grant_ref_t gref_head = 0;
	struct xenbus_transaction xbt;
	struct pvcalls_bedata *bedata = NULL;
	struct xen_pvcalls_sring *sring;

	if (pvcalls_front_dev != NULL) {
		dev_err(&dev->dev, "only one PV Calls connection supported\n");
		return -EINVAL;
	}

	versions = xenbus_read(XBT_NIL, dev->otherend, "versions", &len);
	if (IS_ERR(versions))
		return PTR_ERR(versions);
	if (!len)
		return -EINVAL;
	if (strcmp(versions, "1")) {
		kfree(versions);
		return -EINVAL;
	}
	kfree(versions);
	max_page_order = xenbus_read_unsigned(dev->otherend,
					      "max-page-order", 0);
	if (max_page_order < PVCALLS_RING_ORDER)
		return -ENODEV;
	function_calls = xenbus_read_unsigned(dev->otherend,
					      "function-calls", 0);
	/* See XENBUS_FUNCTIONS_CALLS in pvcalls.h */
	if (function_calls != 1)
		return -ENODEV;
	pr_info("%s max-page-order is %u\n", __func__, max_page_order);

	bedata = kzalloc(sizeof(struct pvcalls_bedata), GFP_KERNEL);
	if (!bedata)
		return -ENOMEM;

	dev_set_drvdata(&dev->dev, bedata);
	pvcalls_front_dev = dev;
	init_waitqueue_head(&bedata->inflight_req);
	INIT_LIST_HEAD(&bedata->socket_mappings);
	spin_lock_init(&bedata->socket_lock);
	bedata->irq = -1;
	bedata->ref = -1;

	for (i = 0; i < PVCALLS_NR_RSP_PER_RING; i++)
		bedata->rsp[i].req_id = PVCALLS_INVALID_ID;

	sring = (struct xen_pvcalls_sring *) __get_free_page(GFP_KERNEL |
							     __GFP_ZERO);
	if (!sring)
		goto error;
	SHARED_RING_INIT(sring);
	FRONT_RING_INIT(&bedata->ring, sring, XEN_PAGE_SIZE);

	ret = xenbus_alloc_evtchn(dev, &evtchn);
	if (ret)
		goto error;

	bedata->irq = bind_evtchn_to_irqhandler(evtchn,
						pvcalls_front_event_handler,
						0, "pvcalls-frontend", dev);
	if (bedata->irq < 0) {
		ret = bedata->irq;
		goto error;
	}

	ret = gnttab_alloc_grant_references(1, &gref_head);
	if (ret < 0)
		goto error;
	ret = gnttab_claim_grant_reference(&gref_head);
	if (ret < 0)
		goto error;
	bedata->ref = ret;
	gnttab_grant_foreign_access_ref(bedata->ref, dev->otherend_id,
					virt_to_gfn((void *)sring), 0);

again:
	ret = xenbus_transaction_start(&xbt);
	if (ret) {
		xenbus_dev_fatal(dev, ret, "starting transaction");
		goto error;
	}
	ret = xenbus_printf(xbt, dev->nodename, "version", "%u", 1);
	if (ret)
		goto error_xenbus;
	ret = xenbus_printf(xbt, dev->nodename, "ring-ref", "%d", bedata->ref);
	if (ret)
		goto error_xenbus;
	ret = xenbus_printf(xbt, dev->nodename, "port", "%u",
			    evtchn);
	if (ret)
		goto error_xenbus;
	ret = xenbus_transaction_end(xbt, 0);
	if (ret) {
		if (ret == -EAGAIN)
			goto again;
		xenbus_dev_fatal(dev, ret, "completing transaction");
		goto error;
	}
	xenbus_switch_state(dev, XenbusStateInitialised);

	return 0;

error_xenbus:
	xenbus_transaction_end(xbt, 1);
	xenbus_dev_fatal(dev, ret, "writing xenstore");
error:
	pvcalls_front_remove(dev);
	return ret;
}
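/*
 * For reference, after a successful probe the frontend's xenstore area
 * contains roughly the following nodes (the exact path depends on the
 * domid and device id; /local/domain/1/device/pvcalls/0 is purely an
 * illustrative example):
 *
 *	version  = "1"
 *	ring-ref = "<grant ref of the command ring page>"
 *	port     = "<event channel number>"
 *
 * while the backend advertises "versions", "max-page-order" and
 * "function-calls" in its own area, which probe checked above.
 */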
static void pvcalls_front_changed(struct xenbus_device *dev,
				  enum xenbus_state backend_state)
{
	switch (backend_state) {
	case XenbusStateReconfiguring:
	case XenbusStateReconfigured:
	case XenbusStateInitialising:
	case XenbusStateInitialised:
	case XenbusStateUnknown:
		break;

	case XenbusStateInitWait:
		break;

	case XenbusStateConnected:
		xenbus_switch_state(dev, XenbusStateConnected);
		break;

	case XenbusStateClosed:
		if (dev->state == XenbusStateClosed)
			break;
		/* Missed the backend's CLOSING state */
		/* fall through */
	case XenbusStateClosing:
		xenbus_frontend_closed(dev);
		break;
	}
}

static struct xenbus_driver pvcalls_front_driver = {
	.ids = pvcalls_front_ids,
	.probe = pvcalls_front_probe,
	.remove = pvcalls_front_remove,
	.otherend_changed = pvcalls_front_changed,
};

static int __init pvcalls_frontend_init(void)
{
	if (!xen_domain())
		return -ENODEV;

	pr_info("Initialising Xen pvcalls frontend driver\n");

	return xenbus_register_frontend(&pvcalls_front_driver);
}

module_init(pvcalls_frontend_init);

MODULE_DESCRIPTION("Xen PV Calls frontend driver");
MODULE_AUTHOR("Stefano Stabellini <sstabellini@kernel.org>");
MODULE_LICENSE("GPL");