/*
 * (c) 2017 Stefano Stabellini <stefano@aporeto.com>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 */

#include <linux/module.h>
#include <linux/net.h>
#include <linux/socket.h>

#include <net/sock.h>

#include <xen/events.h>
#include <xen/grant_table.h>
#include <xen/xen.h>
#include <xen/xenbus.h>
#include <xen/interface/io/pvcalls.h>

#include "pvcalls-front.h"

#define PVCALLS_INVALID_ID UINT_MAX
#define PVCALLS_RING_ORDER XENBUS_MAX_RING_GRANT_ORDER
#define PVCALLS_NR_RSP_PER_RING __CONST_RING_SIZE(xen_pvcalls, XEN_PAGE_SIZE)
#define PVCALLS_FRONT_MAX_SPIN 5000

struct pvcalls_bedata {
	struct xen_pvcalls_front_ring ring;
	grant_ref_t ref;
	int irq;

	struct list_head socket_mappings;
	spinlock_t socket_lock;

	wait_queue_head_t inflight_req;
	struct xen_pvcalls_response rsp[PVCALLS_NR_RSP_PER_RING];
};
/* Only one front/back connection supported. */
static struct xenbus_device *pvcalls_front_dev;
static atomic_t pvcalls_refcount;

/* first increment refcount, then proceed */
#define pvcalls_enter() {			\
	atomic_inc(&pvcalls_refcount);		\
}

/* first complete other operations, then decrement refcount */
#define pvcalls_exit() {			\
	atomic_dec(&pvcalls_refcount);		\
}

struct sock_mapping {
	bool active_socket;
	struct list_head list;
	struct socket *sock;
	atomic_t refcount;
	union {
		struct {
			int irq;
			grant_ref_t ref;
			struct pvcalls_data_intf *ring;
			struct pvcalls_data data;
			struct mutex in_mutex;
			struct mutex out_mutex;

			wait_queue_head_t inflight_conn_req;
		} active;
		struct {
		/* Socket status */
#define PVCALLS_STATUS_UNINITALIZED  0
#define PVCALLS_STATUS_BIND          1
#define PVCALLS_STATUS_LISTEN        2
			uint8_t status;
		/*
		 * Internal state-machine flags.
		 * Only one accept operation can be inflight for a socket.
		 * Only one poll operation can be inflight for a given socket.
		 */
#define PVCALLS_FLAG_ACCEPT_INFLIGHT 0
#define PVCALLS_FLAG_POLL_INFLIGHT   1
#define PVCALLS_FLAG_POLL_RET        2
			uint8_t flags;
			uint32_t inflight_req_id;
			struct sock_mapping *accept_map;
			wait_queue_head_t inflight_accept_req;
		} passive;
	};
};
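/*
 * Refcounting scheme: every entry point below takes the global
 * pvcalls_refcount via pvcalls_enter(), and per-socket operations also
 * take map->refcount via pvcalls_enter_sock().  The teardown paths
 * (pvcalls_front_release() and pvcalls_front_remove()) spin until the
 * relevant counter has drained before freeing anything.  Typical call
 * pattern of the socket operations:
 *
 *	map = pvcalls_enter_sock(sock);    (+1 global, +1 per-socket)
 *	if (IS_ERR(map))
 *		return PTR_ERR(map);
 *	... operate on map ...
 *	pvcalls_exit_sock(sock);           (-1 per-socket, -1 global)
 */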
static inline struct sock_mapping *pvcalls_enter_sock(struct socket *sock)
{
	struct sock_mapping *map;

	if (!pvcalls_front_dev ||
	    dev_get_drvdata(&pvcalls_front_dev->dev) == NULL)
		return ERR_PTR(-ENOTCONN);

	map = (struct sock_mapping *)sock->sk->sk_send_head;
	if (map == NULL)
		return ERR_PTR(-ENOTSOCK);

	pvcalls_enter();
	atomic_inc(&map->refcount);
	return map;
}

static inline void pvcalls_exit_sock(struct socket *sock)
{
	struct sock_mapping *map;

	map = (struct sock_mapping *)sock->sk->sk_send_head;
	atomic_dec(&map->refcount);
	pvcalls_exit();
}

static inline int get_request(struct pvcalls_bedata *bedata, int *req_id)
{
	*req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1);
	if (RING_FULL(&bedata->ring) ||
	    bedata->rsp[*req_id].req_id != PVCALLS_INVALID_ID)
		return -EAGAIN;
	return 0;
}

static bool pvcalls_front_write_todo(struct sock_mapping *map)
{
	struct pvcalls_data_intf *intf = map->active.ring;
	RING_IDX cons, prod, size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
	int32_t error;

	error = intf->out_error;
	if (error == -ENOTCONN)
		return false;
	if (error != 0)
		return true;

	cons = intf->out_cons;
	prod = intf->out_prod;
	return !!(size - pvcalls_queued(prod, cons, size));
}

static bool pvcalls_front_read_todo(struct sock_mapping *map)
{
	struct pvcalls_data_intf *intf = map->active.ring;
	RING_IDX cons, prod;
	int32_t error;

	cons = intf->in_cons;
	prod = intf->in_prod;
	error = intf->in_error;
	return (error != 0 ||
		pvcalls_queued(prod, cons,
			       XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER)) != 0);
}

static irqreturn_t pvcalls_front_event_handler(int irq, void *dev_id)
{
	struct xenbus_device *dev = dev_id;
	struct pvcalls_bedata *bedata;
	struct xen_pvcalls_response *rsp;
	uint8_t *src, *dst;
	int req_id = 0, more = 0, done = 0;

	if (dev == NULL)
		return IRQ_HANDLED;

	pvcalls_enter();
	bedata = dev_get_drvdata(&dev->dev);
	if (bedata == NULL) {
		pvcalls_exit();
		return IRQ_HANDLED;
	}

again:
	while (RING_HAS_UNCONSUMED_RESPONSES(&bedata->ring)) {
		rsp = RING_GET_RESPONSE(&bedata->ring, bedata->ring.rsp_cons);

		req_id = rsp->req_id;
		if (rsp->cmd == PVCALLS_POLL) {
			struct sock_mapping *map = (struct sock_mapping *)(uintptr_t)
						   rsp->u.poll.id;

			clear_bit(PVCALLS_FLAG_POLL_INFLIGHT,
				  (void *)&map->passive.flags);
			/*
			 * clear INFLIGHT, then set RET. It pairs with
			 * the checks at the beginning of
			 * pvcalls_front_poll_passive.
			 */
			smp_wmb();
			set_bit(PVCALLS_FLAG_POLL_RET,
				(void *)&map->passive.flags);
		} else {
			dst = (uint8_t *)&bedata->rsp[req_id] +
			      sizeof(rsp->req_id);
			src = (uint8_t *)rsp + sizeof(rsp->req_id);
			memcpy(dst, src, sizeof(*rsp) - sizeof(rsp->req_id));
			/*
			 * First copy the rest of the data, then req_id. It is
			 * paired with the barrier when accessing bedata->rsp.
			 */
			smp_wmb();
			bedata->rsp[req_id].req_id = req_id;
		}

		done = 1;
		bedata->ring.rsp_cons++;
	}

	RING_FINAL_CHECK_FOR_RESPONSES(&bedata->ring, more);
	if (more)
		goto again;
	if (done)
		wake_up(&bedata->inflight_req);
	pvcalls_exit();
	return IRQ_HANDLED;
}
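/*
 * Response slot handshake: slot req_id is free for reuse as long as
 * bedata->rsp[req_id].req_id == PVCALLS_INVALID_ID.  The event handler
 * above copies the response payload first and writes the matching req_id
 * last, after an smp_wmb(); every waiter does the mirror image:
 *
 *	wait_event(bedata->inflight_req,
 *		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);
 *	smp_rmb();                         read req_id before the payload
 *	ret = bedata->rsp[req_id].ret;
 *	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;   recycle the slot
 */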
static void pvcalls_front_free_map(struct pvcalls_bedata *bedata,
				   struct sock_mapping *map)
{
	int i;

	unbind_from_irqhandler(map->active.irq, map);

	spin_lock(&bedata->socket_lock);
	if (!list_empty(&map->list))
		list_del_init(&map->list);
	spin_unlock(&bedata->socket_lock);

	for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
		gnttab_end_foreign_access(map->active.ring->ref[i], 0, 0);
	gnttab_end_foreign_access(map->active.ref, 0, 0);
	free_page((unsigned long)map->active.ring);

	kfree(map);
}

static irqreturn_t pvcalls_front_conn_handler(int irq, void *sock_map)
{
	struct sock_mapping *map = sock_map;

	if (map == NULL)
		return IRQ_HANDLED;

	wake_up_interruptible(&map->active.inflight_conn_req);

	return IRQ_HANDLED;
}

int pvcalls_front_socket(struct socket *sock)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret;

	/*
	 * PVCalls only supports domain AF_INET,
	 * type SOCK_STREAM and protocol 0 sockets for now.
	 *
	 * Check socket type here, AF_INET and protocol checks are done
	 * by the caller.
	 */
	if (sock->type != SOCK_STREAM)
		return -EOPNOTSUPP;

	pvcalls_enter();
	if (!pvcalls_front_dev) {
		pvcalls_exit();
		return -EACCES;
	}
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	map = kzalloc(sizeof(*map), GFP_KERNEL);
	if (map == NULL) {
		pvcalls_exit();
		return -ENOMEM;
	}

	spin_lock(&bedata->socket_lock);

	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		kfree(map);
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit();
		return ret;
	}

	/*
	 * sock->sk->sk_send_head is not used for ip sockets: reuse the
	 * field to store a pointer to the struct sock_mapping
	 * corresponding to the socket. This way, we can easily get the
	 * struct sock_mapping from the struct socket.
	 */
	sock->sk->sk_send_head = (void *)map;
	list_add_tail(&map->list, &bedata->socket_mappings);

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_SOCKET;
	req->u.socket.id = (uintptr_t) map;
	req->u.socket.domain = AF_INET;
	req->u.socket.type = SOCK_STREAM;
	req->u.socket.protocol = IPPROTO_IP;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

	pvcalls_exit();
	return ret;
}
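/*
 * pvcalls_front_socket() above is the template every command below
 * follows: reserve a response slot with get_request() while holding
 * bedata->socket_lock, fill in the request, bump the private producer
 * index, then push the request and notify the backend:
 *
 *	spin_lock(&bedata->socket_lock);
 *	ret = get_request(bedata, &req_id);
 *	... fill in RING_GET_REQUEST(&bedata->ring, req_id) ...
 *	bedata->ring.req_prod_pvt++;
 *	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
 *	spin_unlock(&bedata->socket_lock);
 *	if (notify)
 *		notify_remote_via_irq(bedata->irq);
 */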
static int create_active(struct sock_mapping *map, int *evtchn)
{
	void *bytes;
	int ret = -ENOMEM, irq = -1, i;

	*evtchn = -1;
	init_waitqueue_head(&map->active.inflight_conn_req);

	map->active.ring = (struct pvcalls_data_intf *)
		__get_free_page(GFP_KERNEL | __GFP_ZERO);
	if (map->active.ring == NULL)
		goto out_error;
	map->active.ring->ring_order = PVCALLS_RING_ORDER;
	bytes = (void *)__get_free_pages(GFP_KERNEL | __GFP_ZERO,
					 PVCALLS_RING_ORDER);
	if (bytes == NULL)
		goto out_error;
	for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
		map->active.ring->ref[i] = gnttab_grant_foreign_access(
			pvcalls_front_dev->otherend_id,
			pfn_to_gfn(virt_to_pfn(bytes) + i), 0);

	map->active.ref = gnttab_grant_foreign_access(
		pvcalls_front_dev->otherend_id,
		pfn_to_gfn(virt_to_pfn((void *)map->active.ring)), 0);

	map->active.data.in = bytes;
	map->active.data.out = bytes +
		XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);

	ret = xenbus_alloc_evtchn(pvcalls_front_dev, evtchn);
	if (ret)
		goto out_error;
	irq = bind_evtchn_to_irqhandler(*evtchn, pvcalls_front_conn_handler,
					0, "pvcalls-frontend", map);
	if (irq < 0) {
		ret = irq;
		goto out_error;
	}

	map->active.irq = irq;
	map->active_socket = true;
	mutex_init(&map->active.in_mutex);
	mutex_init(&map->active.out_mutex);

	return 0;

out_error:
	if (*evtchn >= 0)
		xenbus_free_evtchn(pvcalls_front_dev, *evtchn);
	/*
	 * These came from __get_free_pages()/__get_free_page(), so they
	 * must be returned with the matching deallocators, not kfree().
	 * Both helpers are no-ops for a NULL (zero) address.
	 */
	free_pages((unsigned long)map->active.data.in, PVCALLS_RING_ORDER);
	free_page((unsigned long)map->active.ring);
	return ret;
}
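/*
 * Layout of an active socket's data ring, as set up by create_active()
 * above: one shared page holds struct pvcalls_data_intf (the in/out
 * indexes, error fields and grant references), while the payload lives
 * in 2^PVCALLS_RING_ORDER contiguous pages split evenly between the two
 * directions.  Assuming PVCALLS_RING_ORDER is 4 and 4KiB pages, that is
 * 16 data pages: data.in covers the first 32KiB and data.out the second
 * 32KiB (XEN_FLEX_RING_SIZE(4) == 32KiB per direction).
 */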
int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
			  int addr_len, int flags)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret, evtchn;

	if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
		return -EOPNOTSUPP;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);

	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return ret;
	}
	ret = create_active(map, &evtchn);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return ret;
	}

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_CONNECT;
	req->u.connect.id = (uintptr_t)map;
	req->u.connect.len = addr_len;
	req->u.connect.flags = flags;
	req->u.connect.ref = map->active.ref;
	req->u.connect.evtchn = evtchn;
	memcpy(req->u.connect.addr, addr, sizeof(*addr));

	map->sock = sock;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);

	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
	pvcalls_exit_sock(sock);
	return ret;
}

static int __write_ring(struct pvcalls_data_intf *intf,
			struct pvcalls_data *data,
			struct iov_iter *msg_iter,
			int len)
{
	RING_IDX cons, prod, size, masked_prod, masked_cons;
	RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
	int32_t error;

	error = intf->out_error;
	if (error < 0)
		return error;
	cons = intf->out_cons;
	prod = intf->out_prod;
	/* read indexes before continuing */
	virt_mb();

	size = pvcalls_queued(prod, cons, array_size);
	if (size > array_size)
		return -EINVAL;
	/* a full ring is not an error, there is just nothing to do */
	if (size == array_size)
		return 0;
	if (len > array_size - size)
		len = array_size - size;

	masked_prod = pvcalls_mask(prod, array_size);
	masked_cons = pvcalls_mask(cons, array_size);

	if (masked_prod < masked_cons) {
		len = copy_from_iter(data->out + masked_prod, len, msg_iter);
	} else {
		if (len > array_size - masked_prod) {
			int ret = copy_from_iter(data->out + masked_prod,
						 array_size - masked_prod,
						 msg_iter);
			if (ret != array_size - masked_prod) {
				len = ret;
				goto out;
			}
			len = ret + copy_from_iter(data->out, len - ret,
						   msg_iter);
		} else {
			len = copy_from_iter(data->out + masked_prod, len,
					     msg_iter);
		}
	}
out:
	/* write to ring before updating pointer */
	virt_wmb();
	intf->out_prod += len;

	return len;
}
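/*
 * Worked example for the wrap-around logic in __write_ring(), assuming
 * array_size == 4096: with out_cons == 4090 and out_prod == 4094,
 * pvcalls_queued() reports 4 bytes queued, so 4092 bytes are free.
 * masked_prod == 4094 and masked_cons == 4090, so a 10-byte write takes
 * the split path: 2 bytes are copied to data->out + 4094 (the end of the
 * buffer) and the remaining 8 bytes to data->out + 0.  Only after the
 * virt_wmb() is out_prod advanced by 10, making the data visible to the
 * backend.
 */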
int pvcalls_front_sendmsg(struct socket *sock, struct msghdr *msg,
			  size_t len)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	int sent, tot_sent = 0;
	int count = 0, flags;

	flags = msg->msg_flags;
	if (flags & (MSG_CONFIRM|MSG_DONTROUTE|MSG_EOR|MSG_OOB))
		return -EOPNOTSUPP;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	mutex_lock(&map->active.out_mutex);
	if ((flags & MSG_DONTWAIT) && !pvcalls_front_write_todo(map)) {
		mutex_unlock(&map->active.out_mutex);
		pvcalls_exit_sock(sock);
		return -EAGAIN;
	}
	if (len > INT_MAX)
		len = INT_MAX;

again:
	count++;
	sent = __write_ring(map->active.ring,
			    &map->active.data, &msg->msg_iter,
			    len);
	if (sent > 0) {
		len -= sent;
		tot_sent += sent;
		notify_remote_via_irq(map->active.irq);
	}
	if (sent >= 0 && len > 0 && count < PVCALLS_FRONT_MAX_SPIN)
		goto again;
	if (sent < 0)
		tot_sent = sent;

	mutex_unlock(&map->active.out_mutex);
	pvcalls_exit_sock(sock);
	return tot_sent;
}

static int __read_ring(struct pvcalls_data_intf *intf,
		       struct pvcalls_data *data,
		       struct iov_iter *msg_iter,
		       size_t len, int flags)
{
	RING_IDX cons, prod, size, masked_prod, masked_cons;
	RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
	int32_t error;

	cons = intf->in_cons;
	prod = intf->in_prod;
	error = intf->in_error;
	/* get pointers before reading from the ring */
	virt_rmb();
	if (error < 0)
		return error;

	size = pvcalls_queued(prod, cons, array_size);
	masked_prod = pvcalls_mask(prod, array_size);
	masked_cons = pvcalls_mask(cons, array_size);

	if (size == 0)
		return 0;

	if (len > size)
		len = size;

	if (masked_prod > masked_cons) {
		len = copy_to_iter(data->in + masked_cons, len, msg_iter);
	} else {
		if (len > (array_size - masked_cons)) {
			int ret = copy_to_iter(data->in + masked_cons,
					       array_size - masked_cons,
					       msg_iter);
			if (ret != array_size - masked_cons) {
				len = ret;
				goto out;
			}
			len = ret + copy_to_iter(data->in, len - ret,
						 msg_iter);
		} else {
			len = copy_to_iter(data->in + masked_cons, len,
					   msg_iter);
		}
	}
out:
	/* read data from the ring before increasing the index */
	virt_mb();
	if (!(flags & MSG_PEEK))
		intf->in_cons += len;

	return len;
}

int pvcalls_front_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
			  int flags)
{
	struct pvcalls_bedata *bedata;
	int ret;
	struct sock_mapping *map;

	if (flags & (MSG_CMSG_CLOEXEC|MSG_ERRQUEUE|MSG_OOB|MSG_TRUNC))
		return -EOPNOTSUPP;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	mutex_lock(&map->active.in_mutex);
	if (len > XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER))
		len = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);

	while (!(flags & MSG_DONTWAIT) && !pvcalls_front_read_todo(map)) {
		wait_event_interruptible(map->active.inflight_conn_req,
					 pvcalls_front_read_todo(map));
	}
	ret = __read_ring(map->active.ring, &map->active.data,
			  &msg->msg_iter, len, flags);

	if (ret > 0)
		notify_remote_via_irq(map->active.irq);
	if (ret == 0)
		ret = (flags & MSG_DONTWAIT) ? -EAGAIN : 0;
	if (ret == -ENOTCONN)
		ret = 0;

	mutex_unlock(&map->active.in_mutex);
	pvcalls_exit_sock(sock);
	return ret;
}
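/*
 * Notes on the receive path above: MSG_PEEK is honoured by __read_ring()
 * simply by not advancing in_cons, so the same bytes stay available to
 * the next reader.  An in_error of -ENOTCONN (which the backend sets
 * once the remote end has closed the connection) is translated into a
 * return value of 0, which userspace sees as end-of-file, matching the
 * usual stream-socket recv() semantics.
 */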
int pvcalls_front_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret;

	if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
		return -EOPNOTSUPP;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return ret;
	}
	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	map->sock = sock;
	req->cmd = PVCALLS_BIND;
	req->u.bind.id = (uintptr_t)map;
	memcpy(req->u.bind.addr, addr, sizeof(*addr));
	req->u.bind.len = addr_len;

	init_waitqueue_head(&map->passive.inflight_accept_req);

	map->active_socket = false;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

	map->passive.status = PVCALLS_STATUS_BIND;
	pvcalls_exit_sock(sock);
	return 0;
}

int pvcalls_front_listen(struct socket *sock, int backlog)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	if (map->passive.status != PVCALLS_STATUS_BIND) {
		pvcalls_exit_sock(sock);
		return -EOPNOTSUPP;
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return ret;
	}
	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_LISTEN;
	req->u.listen.id = (uintptr_t) map;
	req->u.listen.backlog = backlog;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

	map->passive.status = PVCALLS_STATUS_LISTEN;
	pvcalls_exit_sock(sock);
	return ret;
}
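/*
 * Passive socket state machine: a socket starts out as
 * PVCALLS_STATUS_UNINITALIZED, moves to PVCALLS_STATUS_BIND after a
 * successful pvcalls_front_bind(), and to PVCALLS_STATUS_LISTEN after
 * pvcalls_front_listen().  The ordering is enforced by the status
 * checks: listen on an unbound socket fails with -EOPNOTSUPP, and
 * accept on a non-listening socket fails with -EINVAL.
 */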
int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	struct sock_mapping *map2 = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret, evtchn, nonblock;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	if (map->passive.status != PVCALLS_STATUS_LISTEN) {
		pvcalls_exit_sock(sock);
		return -EINVAL;
	}

	nonblock = flags & SOCK_NONBLOCK;
	/*
	 * Backend only supports 1 inflight accept request, will return
	 * errors for the others
	 */
	if (test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			     (void *)&map->passive.flags)) {
		req_id = READ_ONCE(map->passive.inflight_req_id);
		if (req_id != PVCALLS_INVALID_ID &&
		    READ_ONCE(bedata->rsp[req_id].req_id) == req_id) {
			map2 = map->passive.accept_map;
			goto received;
		}
		if (nonblock) {
			pvcalls_exit_sock(sock);
			return -EAGAIN;
		}
		if (wait_event_interruptible(map->passive.inflight_accept_req,
			!test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
					  (void *)&map->passive.flags))) {
			pvcalls_exit_sock(sock);
			return -EINTR;
		}
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return ret;
	}
	map2 = kzalloc(sizeof(*map2), GFP_ATOMIC);
	if (map2 == NULL) {
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return -ENOMEM;
	}
	ret = create_active(map2, &evtchn);
	if (ret < 0) {
		kfree(map2);
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return ret;
	}
	list_add_tail(&map2->list, &bedata->socket_mappings);

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_ACCEPT;
	req->u.accept.id = (uintptr_t) map;
	req->u.accept.ref = map2->active.ref;
	req->u.accept.id_new = (uintptr_t) map2;
	req->u.accept.evtchn = evtchn;
	map->passive.accept_map = map2;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);
	/* We could check if we have received a response before returning. */
	if (nonblock) {
		WRITE_ONCE(map->passive.inflight_req_id, req_id);
		pvcalls_exit_sock(sock);
		return -EAGAIN;
	}

	if (wait_event_interruptible(bedata->inflight_req,
		READ_ONCE(bedata->rsp[req_id].req_id) == req_id)) {
		pvcalls_exit_sock(sock);
		return -EINTR;
	}
	/* read req_id, then the content */
	smp_rmb();

received:
	map2->sock = newsock;
	newsock->sk = kzalloc(sizeof(*newsock->sk), GFP_KERNEL);
	if (!newsock->sk) {
		bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
		map->passive.inflight_req_id = PVCALLS_INVALID_ID;
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		pvcalls_front_free_map(bedata, map2);
		pvcalls_exit_sock(sock);
		return -ENOMEM;
	}
	newsock->sk->sk_send_head = (void *)map2;

	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
	map->passive.inflight_req_id = PVCALLS_INVALID_ID;

	clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags);
	wake_up(&map->passive.inflight_accept_req);

	pvcalls_exit_sock(sock);
	return ret;
}
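/*
 * Nonblocking accept, step by step: the first accept(SOCK_NONBLOCK) call
 * sets PVCALLS_FLAG_ACCEPT_INFLIGHT, sends the PVCALLS_ACCEPT request,
 * records req_id in map->passive.inflight_req_id and returns -EAGAIN.
 * A later accept() call finds the flag already set; if the response has
 * arrived in the meantime it jumps to the received label and completes
 * using the stored map->passive.accept_map, otherwise it returns -EAGAIN
 * again (or blocks, if this time the caller allows it).
 */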
static __poll_t pvcalls_front_poll_passive(struct file *file,
					   struct pvcalls_bedata *bedata,
					   struct sock_mapping *map,
					   poll_table *wait)
{
	int notify, req_id, ret;
	struct xen_pvcalls_request *req;

	if (test_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
		     (void *)&map->passive.flags)) {
		uint32_t req_id = READ_ONCE(map->passive.inflight_req_id);

		if (req_id != PVCALLS_INVALID_ID &&
		    READ_ONCE(bedata->rsp[req_id].req_id) == req_id)
			return EPOLLIN | EPOLLRDNORM;

		poll_wait(file, &map->passive.inflight_accept_req, wait);
		return 0;
	}

	if (test_and_clear_bit(PVCALLS_FLAG_POLL_RET,
			       (void *)&map->passive.flags))
		return EPOLLIN | EPOLLRDNORM;

	/*
	 * First check RET, then INFLIGHT. No barriers necessary to
	 * ensure execution ordering because of the conditional
	 * instructions creating control dependencies.
	 */
	if (test_and_set_bit(PVCALLS_FLAG_POLL_INFLIGHT,
			     (void *)&map->passive.flags)) {
		poll_wait(file, &bedata->inflight_req, wait);
		return 0;
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		return ret;
	}
	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_POLL;
	req->u.poll.id = (uintptr_t) map;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	poll_wait(file, &bedata->inflight_req, wait);
	return 0;
}
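/*
 * Lifecycle of the poll flags used above: pvcalls_front_poll_passive()
 * sets PVCALLS_FLAG_POLL_INFLIGHT when it issues a PVCALLS_POLL request;
 * the event handler clears INFLIGHT and then sets PVCALLS_FLAG_POLL_RET
 * (in that order, with an smp_wmb() in between); the next poll call
 * consumes RET with test_and_clear_bit() and reports EPOLLIN.  The
 * mirrored check order here (RET first, then INFLIGHT) is what makes the
 * handshake safe without further barriers.
 */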
static __poll_t pvcalls_front_poll_active(struct file *file,
					  struct pvcalls_bedata *bedata,
					  struct sock_mapping *map,
					  poll_table *wait)
{
	__poll_t mask = 0;
	int32_t in_error, out_error;
	struct pvcalls_data_intf *intf = map->active.ring;

	out_error = intf->out_error;
	in_error = intf->in_error;

	poll_wait(file, &map->active.inflight_conn_req, wait);
	if (pvcalls_front_write_todo(map))
		mask |= EPOLLOUT | EPOLLWRNORM;
	if (pvcalls_front_read_todo(map))
		mask |= EPOLLIN | EPOLLRDNORM;
	if (in_error != 0 || out_error != 0)
		mask |= EPOLLERR;

	return mask;
}

__poll_t pvcalls_front_poll(struct file *file, struct socket *sock,
			    poll_table *wait)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	__poll_t ret;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return EPOLLNVAL;
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	if (map->active_socket)
		ret = pvcalls_front_poll_active(file, bedata, map, wait);
	else
		ret = pvcalls_front_poll_passive(file, bedata, map, wait);
	pvcalls_exit_sock(sock);
	return ret;
}

int pvcalls_front_release(struct socket *sock)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	int req_id, notify, ret;
	struct xen_pvcalls_request *req;

	if (sock->sk == NULL)
		return 0;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map)) {
		if (PTR_ERR(map) == -ENOTCONN)
			return -EIO;
		else
			return 0;
	}
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return ret;
	}
	sock->sk->sk_send_head = NULL;

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_RELEASE;
	req->u.release.id = (uintptr_t)map;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	if (map->active_socket) {
		/*
		 * Set in_error and wake up inflight_conn_req to force
		 * recvmsg waiters to exit.
		 */
		map->active.ring->in_error = -EBADF;
		wake_up_interruptible(&map->active.inflight_conn_req);

		/*
		 * We need to make sure that sendmsg/recvmsg on this socket
		 * have not started before we've cleared sk_send_head here.
		 * The easiest way to guarantee this is to see that no pvcalls
		 * (other than us) is in progress on this socket.
		 */
		while (atomic_read(&map->refcount) > 1)
			cpu_relax();

		pvcalls_front_free_map(bedata, map);
	} else {
		wake_up(&bedata->inflight_req);
		wake_up(&map->passive.inflight_accept_req);

		while (atomic_read(&map->refcount) > 1)
			cpu_relax();

		spin_lock(&bedata->socket_lock);
		list_del(&map->list);
		spin_unlock(&bedata->socket_lock);
		if (READ_ONCE(map->passive.inflight_req_id) !=
		    PVCALLS_INVALID_ID) {
			pvcalls_front_free_map(bedata,
					       map->passive.accept_map);
		}
		kfree(map);
	}
	WRITE_ONCE(bedata->rsp[req_id].req_id, PVCALLS_INVALID_ID);

	pvcalls_exit();
	return 0;
}
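/*
 * Teardown ordering in pvcalls_front_release(): sk_send_head is cleared
 * first, so pvcalls_enter_sock() fails with -ENOTSOCK for any operation
 * starting after this point.  Setting in_error = -EBADF plus the
 * wake-ups kick out waiters already inside sendmsg/recvmsg/accept, and
 * the refcount spin waits for those late users (each still holding a
 * map->refcount reference) to drop back to the single reference owned
 * by release itself before the mapping is freed.
 */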
static const struct xenbus_device_id pvcalls_front_ids[] = {
	{ "pvcalls" },
	{ "" }
};

static int pvcalls_front_remove(struct xenbus_device *dev)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL, *n;

	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
	dev_set_drvdata(&dev->dev, NULL);
	pvcalls_front_dev = NULL;
	if (bedata->irq >= 0)
		unbind_from_irqhandler(bedata->irq, dev);

	list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
		map->sock->sk->sk_send_head = NULL;
		if (map->active_socket) {
			map->active.ring->in_error = -EBADF;
			wake_up_interruptible(&map->active.inflight_conn_req);
		}
	}

	smp_mb();
	while (atomic_read(&pvcalls_refcount) > 0)
		cpu_relax();
	list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
		if (map->active_socket) {
			/* No need to lock, refcount is 0 */
			pvcalls_front_free_map(bedata, map);
		} else {
			list_del(&map->list);
			kfree(map);
		}
	}
	if (bedata->ref != -1)
		gnttab_end_foreign_access(bedata->ref, 0, 0);
	/* The shared ring was allocated with __get_free_page(), not kmalloc. */
	free_page((unsigned long)bedata->ring.sring);
	kfree(bedata);
	xenbus_switch_state(dev, XenbusStateClosed);
	return 0;
}

static int pvcalls_front_probe(struct xenbus_device *dev,
			       const struct xenbus_device_id *id)
{
	int ret = -ENOMEM, evtchn, i;
	unsigned int max_page_order, function_calls, len;
	char *versions;
	grant_ref_t gref_head = 0;
	struct xenbus_transaction xbt;
	struct pvcalls_bedata *bedata = NULL;
	struct xen_pvcalls_sring *sring;

	if (pvcalls_front_dev != NULL) {
		dev_err(&dev->dev, "only one PV Calls connection supported\n");
		return -EINVAL;
	}

	versions = xenbus_read(XBT_NIL, dev->otherend, "versions", &len);
	if (IS_ERR(versions))
		return PTR_ERR(versions);
	if (!len) {
		/* don't leak the buffer returned by xenbus_read() */
		kfree(versions);
		return -EINVAL;
	}
	if (strcmp(versions, "1")) {
		kfree(versions);
		return -EINVAL;
	}
	kfree(versions);
	max_page_order = xenbus_read_unsigned(dev->otherend,
					      "max-page-order", 0);
	if (max_page_order < PVCALLS_RING_ORDER)
		return -ENODEV;
	function_calls = xenbus_read_unsigned(dev->otherend,
					      "function-calls", 0);
	/* See XENBUS_FUNCTIONS_CALLS in pvcalls.h */
	if (function_calls != 1)
		return -ENODEV;
	pr_info("%s max-page-order is %u\n", __func__, max_page_order);

	bedata = kzalloc(sizeof(struct pvcalls_bedata), GFP_KERNEL);
	if (!bedata)
		return -ENOMEM;

	dev_set_drvdata(&dev->dev, bedata);
	pvcalls_front_dev = dev;
	init_waitqueue_head(&bedata->inflight_req);
	INIT_LIST_HEAD(&bedata->socket_mappings);
	spin_lock_init(&bedata->socket_lock);
	bedata->irq = -1;
	bedata->ref = -1;

	for (i = 0; i < PVCALLS_NR_RSP_PER_RING; i++)
		bedata->rsp[i].req_id = PVCALLS_INVALID_ID;

	sring = (struct xen_pvcalls_sring *) __get_free_page(GFP_KERNEL |
							     __GFP_ZERO);
	if (!sring)
		goto error;
	SHARED_RING_INIT(sring);
	FRONT_RING_INIT(&bedata->ring, sring, XEN_PAGE_SIZE);

	ret = xenbus_alloc_evtchn(dev, &evtchn);
	if (ret)
		goto error;

	bedata->irq = bind_evtchn_to_irqhandler(evtchn,
						pvcalls_front_event_handler,
						0, "pvcalls-frontend", dev);
	if (bedata->irq < 0) {
		ret = bedata->irq;
		goto error;
	}

	ret = gnttab_alloc_grant_references(1, &gref_head);
	if (ret < 0)
		goto error;
	ret = gnttab_claim_grant_reference(&gref_head);
	if (ret < 0)
		goto error;
	bedata->ref = ret;
	gnttab_grant_foreign_access_ref(bedata->ref, dev->otherend_id,
					virt_to_gfn((void *)sring), 0);

 again:
	ret = xenbus_transaction_start(&xbt);
	if (ret) {
		xenbus_dev_fatal(dev, ret, "starting transaction");
		goto error;
	}
	ret = xenbus_printf(xbt, dev->nodename, "version", "%u", 1);
	if (ret)
		goto error_xenbus;
	ret = xenbus_printf(xbt, dev->nodename, "ring-ref", "%d", bedata->ref);
	if (ret)
		goto error_xenbus;
	ret = xenbus_printf(xbt, dev->nodename, "port", "%u",
			    evtchn);
	if (ret)
		goto error_xenbus;
	ret = xenbus_transaction_end(xbt, 0);
	if (ret) {
		if (ret == -EAGAIN)
			goto again;
		xenbus_dev_fatal(dev, ret, "completing transaction");
		goto error;
	}
	xenbus_switch_state(dev, XenbusStateInitialised);

	return 0;

 error_xenbus:
	xenbus_transaction_end(xbt, 1);
	xenbus_dev_fatal(dev, ret, "writing xenstore");
 error:
	pvcalls_front_remove(dev);
	return ret;
}
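/*
 * Frontend/backend handshake, as implemented by pvcalls_front_probe()
 * above and pvcalls_front_changed() below, roughly:
 *
 *	backend writes:  versions = "1", max-page-order, function-calls = 1
 *	frontend writes: version = 1, ring-ref = <grant>, port = <evtchn>
 *	frontend -> XenbusStateInitialised
 *	backend  -> XenbusStateConnected
 *	frontend -> XenbusStateConnected
 *
 * After that, all traffic flows over the shared command ring and the
 * per-socket data rings; xenstore is not touched again until close.
 */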
static void pvcalls_front_changed(struct xenbus_device *dev,
				  enum xenbus_state backend_state)
{
	switch (backend_state) {
	case XenbusStateReconfiguring:
	case XenbusStateReconfigured:
	case XenbusStateInitialising:
	case XenbusStateInitialised:
	case XenbusStateUnknown:
		break;

	case XenbusStateInitWait:
		break;

	case XenbusStateConnected:
		xenbus_switch_state(dev, XenbusStateConnected);
		break;

	case XenbusStateClosed:
		if (dev->state == XenbusStateClosed)
			break;
		/* Missed the backend's CLOSING state */
		/* fall through */
	case XenbusStateClosing:
		xenbus_frontend_closed(dev);
		break;
	}
}

static struct xenbus_driver pvcalls_front_driver = {
	.ids = pvcalls_front_ids,
	.probe = pvcalls_front_probe,
	.remove = pvcalls_front_remove,
	.otherend_changed = pvcalls_front_changed,
};

static int __init pvcalls_frontend_init(void)
{
	if (!xen_domain())
		return -ENODEV;

	pr_info("Initialising Xen pvcalls frontend driver\n");

	return xenbus_register_frontend(&pvcalls_front_driver);
}

module_init(pvcalls_frontend_init);

MODULE_DESCRIPTION("Xen PV Calls frontend driver");
MODULE_AUTHOR("Stefano Stabellini <sstabellini@kernel.org>");
MODULE_LICENSE("GPL");