1 /* 2 * (c) 2017 Stefano Stabellini <stefano@aporeto.com> 3 * 4 * This program is free software; you can redistribute it and/or modify 5 * it under the terms of the GNU General Public License as published by 6 * the Free Software Foundation; either version 2 of the License, or 7 * (at your option) any later version. 8 * 9 * This program is distributed in the hope that it will be useful, 10 * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 * GNU General Public License for more details. 13 */ 14 15 #include <linux/module.h> 16 #include <linux/net.h> 17 #include <linux/socket.h> 18 19 #include <net/sock.h> 20 21 #include <xen/events.h> 22 #include <xen/grant_table.h> 23 #include <xen/xen.h> 24 #include <xen/xenbus.h> 25 #include <xen/interface/io/pvcalls.h> 26 27 #include "pvcalls-front.h" 28 29 #define PVCALLS_INVALID_ID UINT_MAX 30 #define PVCALLS_RING_ORDER XENBUS_MAX_RING_GRANT_ORDER 31 #define PVCALLS_NR_RSP_PER_RING __CONST_RING_SIZE(xen_pvcalls, XEN_PAGE_SIZE) 32 #define PVCALLS_FRONT_MAX_SPIN 5000 33 34 struct pvcalls_bedata { 35 struct xen_pvcalls_front_ring ring; 36 grant_ref_t ref; 37 int irq; 38 39 struct list_head socket_mappings; 40 spinlock_t socket_lock; 41 42 wait_queue_head_t inflight_req; 43 struct xen_pvcalls_response rsp[PVCALLS_NR_RSP_PER_RING]; 44 }; 45 /* Only one front/back connection supported. */ 46 static struct xenbus_device *pvcalls_front_dev; 47 static atomic_t pvcalls_refcount; 48 49 /* first increment refcount, then proceed */ 50 #define pvcalls_enter() { \ 51 atomic_inc(&pvcalls_refcount); \ 52 } 53 54 /* first complete other operations, then decrement refcount */ 55 #define pvcalls_exit() { \ 56 atomic_dec(&pvcalls_refcount); \ 57 } 58 59 struct sock_mapping { 60 bool active_socket; 61 struct list_head list; 62 struct socket *sock; 63 atomic_t refcount; 64 union { 65 struct { 66 int irq; 67 grant_ref_t ref; 68 struct pvcalls_data_intf *ring; 69 struct pvcalls_data data; 70 struct mutex in_mutex; 71 struct mutex out_mutex; 72 73 wait_queue_head_t inflight_conn_req; 74 } active; 75 struct { 76 /* 77 * Socket status, needs to be 64-bit aligned due to the 78 * test_and_* functions which have this requirement on arm64. 79 */ 80 #define PVCALLS_STATUS_UNINITALIZED 0 81 #define PVCALLS_STATUS_BIND 1 82 #define PVCALLS_STATUS_LISTEN 2 83 uint8_t status __attribute__((aligned(8))); 84 /* 85 * Internal state-machine flags. 86 * Only one accept operation can be inflight for a socket. 87 * Only one poll operation can be inflight for a given socket. 88 * flags needs to be 64-bit aligned due to the test_and_* 89 * functions which have this requirement on arm64. 90 */ 91 #define PVCALLS_FLAG_ACCEPT_INFLIGHT 0 92 #define PVCALLS_FLAG_POLL_INFLIGHT 1 93 #define PVCALLS_FLAG_POLL_RET 2 94 uint8_t flags __attribute__((aligned(8))); 95 uint32_t inflight_req_id; 96 struct sock_mapping *accept_map; 97 wait_queue_head_t inflight_accept_req; 98 } passive; 99 }; 100 }; 101 102 static inline struct sock_mapping *pvcalls_enter_sock(struct socket *sock) 103 { 104 struct sock_mapping *map; 105 106 if (!pvcalls_front_dev || 107 dev_get_drvdata(&pvcalls_front_dev->dev) == NULL) 108 return ERR_PTR(-ENOTCONN); 109 110 map = (struct sock_mapping *)sock->sk->sk_send_head; 111 if (map == NULL) 112 return ERR_PTR(-ENOTSOCK); 113 114 pvcalls_enter(); 115 atomic_inc(&map->refcount); 116 return map; 117 } 118 119 static inline void pvcalls_exit_sock(struct socket *sock) 120 { 121 struct sock_mapping *map; 122 123 map = (struct sock_mapping *)sock->sk->sk_send_head; 124 atomic_dec(&map->refcount); 125 pvcalls_exit(); 126 } 127 128 static inline int get_request(struct pvcalls_bedata *bedata, int *req_id) 129 { 130 *req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1); 131 if (RING_FULL(&bedata->ring) || 132 bedata->rsp[*req_id].req_id != PVCALLS_INVALID_ID) 133 return -EAGAIN; 134 return 0; 135 } 136 137 static bool pvcalls_front_write_todo(struct sock_mapping *map) 138 { 139 struct pvcalls_data_intf *intf = map->active.ring; 140 RING_IDX cons, prod, size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER); 141 int32_t error; 142 143 error = intf->out_error; 144 if (error == -ENOTCONN) 145 return false; 146 if (error != 0) 147 return true; 148 149 cons = intf->out_cons; 150 prod = intf->out_prod; 151 return !!(size - pvcalls_queued(prod, cons, size)); 152 } 153 154 static bool pvcalls_front_read_todo(struct sock_mapping *map) 155 { 156 struct pvcalls_data_intf *intf = map->active.ring; 157 RING_IDX cons, prod; 158 int32_t error; 159 160 cons = intf->in_cons; 161 prod = intf->in_prod; 162 error = intf->in_error; 163 return (error != 0 || 164 pvcalls_queued(prod, cons, 165 XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER)) != 0); 166 } 167 168 static irqreturn_t pvcalls_front_event_handler(int irq, void *dev_id) 169 { 170 struct xenbus_device *dev = dev_id; 171 struct pvcalls_bedata *bedata; 172 struct xen_pvcalls_response *rsp; 173 uint8_t *src, *dst; 174 int req_id = 0, more = 0, done = 0; 175 176 if (dev == NULL) 177 return IRQ_HANDLED; 178 179 pvcalls_enter(); 180 bedata = dev_get_drvdata(&dev->dev); 181 if (bedata == NULL) { 182 pvcalls_exit(); 183 return IRQ_HANDLED; 184 } 185 186 again: 187 while (RING_HAS_UNCONSUMED_RESPONSES(&bedata->ring)) { 188 rsp = RING_GET_RESPONSE(&bedata->ring, bedata->ring.rsp_cons); 189 190 req_id = rsp->req_id; 191 if (rsp->cmd == PVCALLS_POLL) { 192 struct sock_mapping *map = (struct sock_mapping *)(uintptr_t) 193 rsp->u.poll.id; 194 195 clear_bit(PVCALLS_FLAG_POLL_INFLIGHT, 196 (void *)&map->passive.flags); 197 /* 198 * clear INFLIGHT, then set RET. It pairs with 199 * the checks at the beginning of 200 * pvcalls_front_poll_passive. 201 */ 202 smp_wmb(); 203 set_bit(PVCALLS_FLAG_POLL_RET, 204 (void *)&map->passive.flags); 205 } else { 206 dst = (uint8_t *)&bedata->rsp[req_id] + 207 sizeof(rsp->req_id); 208 src = (uint8_t *)rsp + sizeof(rsp->req_id); 209 memcpy(dst, src, sizeof(*rsp) - sizeof(rsp->req_id)); 210 /* 211 * First copy the rest of the data, then req_id. It is 212 * paired with the barrier when accessing bedata->rsp. 213 */ 214 smp_wmb(); 215 bedata->rsp[req_id].req_id = req_id; 216 } 217 218 done = 1; 219 bedata->ring.rsp_cons++; 220 } 221 222 RING_FINAL_CHECK_FOR_RESPONSES(&bedata->ring, more); 223 if (more) 224 goto again; 225 if (done) 226 wake_up(&bedata->inflight_req); 227 pvcalls_exit(); 228 return IRQ_HANDLED; 229 } 230 231 static void pvcalls_front_free_map(struct pvcalls_bedata *bedata, 232 struct sock_mapping *map) 233 { 234 int i; 235 236 unbind_from_irqhandler(map->active.irq, map); 237 238 spin_lock(&bedata->socket_lock); 239 if (!list_empty(&map->list)) 240 list_del_init(&map->list); 241 spin_unlock(&bedata->socket_lock); 242 243 for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++) 244 gnttab_end_foreign_access(map->active.ring->ref[i], 0, 0); 245 gnttab_end_foreign_access(map->active.ref, 0, 0); 246 free_page((unsigned long)map->active.ring); 247 248 kfree(map); 249 } 250 251 static irqreturn_t pvcalls_front_conn_handler(int irq, void *sock_map) 252 { 253 struct sock_mapping *map = sock_map; 254 255 if (map == NULL) 256 return IRQ_HANDLED; 257 258 wake_up_interruptible(&map->active.inflight_conn_req); 259 260 return IRQ_HANDLED; 261 } 262 263 int pvcalls_front_socket(struct socket *sock) 264 { 265 struct pvcalls_bedata *bedata; 266 struct sock_mapping *map = NULL; 267 struct xen_pvcalls_request *req; 268 int notify, req_id, ret; 269 270 /* 271 * PVCalls only supports domain AF_INET, 272 * type SOCK_STREAM and protocol 0 sockets for now. 273 * 274 * Check socket type here, AF_INET and protocol checks are done 275 * by the caller. 276 */ 277 if (sock->type != SOCK_STREAM) 278 return -EOPNOTSUPP; 279 280 pvcalls_enter(); 281 if (!pvcalls_front_dev) { 282 pvcalls_exit(); 283 return -EACCES; 284 } 285 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 286 287 map = kzalloc(sizeof(*map), GFP_KERNEL); 288 if (map == NULL) { 289 pvcalls_exit(); 290 return -ENOMEM; 291 } 292 293 spin_lock(&bedata->socket_lock); 294 295 ret = get_request(bedata, &req_id); 296 if (ret < 0) { 297 kfree(map); 298 spin_unlock(&bedata->socket_lock); 299 pvcalls_exit(); 300 return ret; 301 } 302 303 /* 304 * sock->sk->sk_send_head is not used for ip sockets: reuse the 305 * field to store a pointer to the struct sock_mapping 306 * corresponding to the socket. This way, we can easily get the 307 * struct sock_mapping from the struct socket. 308 */ 309 sock->sk->sk_send_head = (void *)map; 310 list_add_tail(&map->list, &bedata->socket_mappings); 311 312 req = RING_GET_REQUEST(&bedata->ring, req_id); 313 req->req_id = req_id; 314 req->cmd = PVCALLS_SOCKET; 315 req->u.socket.id = (uintptr_t) map; 316 req->u.socket.domain = AF_INET; 317 req->u.socket.type = SOCK_STREAM; 318 req->u.socket.protocol = IPPROTO_IP; 319 320 bedata->ring.req_prod_pvt++; 321 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify); 322 spin_unlock(&bedata->socket_lock); 323 if (notify) 324 notify_remote_via_irq(bedata->irq); 325 326 wait_event(bedata->inflight_req, 327 READ_ONCE(bedata->rsp[req_id].req_id) == req_id); 328 329 /* read req_id, then the content */ 330 smp_rmb(); 331 ret = bedata->rsp[req_id].ret; 332 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID; 333 334 pvcalls_exit(); 335 return ret; 336 } 337 338 static int create_active(struct sock_mapping *map, int *evtchn) 339 { 340 void *bytes; 341 int ret = -ENOMEM, irq = -1, i; 342 343 *evtchn = -1; 344 init_waitqueue_head(&map->active.inflight_conn_req); 345 346 map->active.ring = (struct pvcalls_data_intf *) 347 __get_free_page(GFP_KERNEL | __GFP_ZERO); 348 if (map->active.ring == NULL) 349 goto out_error; 350 map->active.ring->ring_order = PVCALLS_RING_ORDER; 351 bytes = (void *)__get_free_pages(GFP_KERNEL | __GFP_ZERO, 352 PVCALLS_RING_ORDER); 353 if (bytes == NULL) 354 goto out_error; 355 for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++) 356 map->active.ring->ref[i] = gnttab_grant_foreign_access( 357 pvcalls_front_dev->otherend_id, 358 pfn_to_gfn(virt_to_pfn(bytes) + i), 0); 359 360 map->active.ref = gnttab_grant_foreign_access( 361 pvcalls_front_dev->otherend_id, 362 pfn_to_gfn(virt_to_pfn((void *)map->active.ring)), 0); 363 364 map->active.data.in = bytes; 365 map->active.data.out = bytes + 366 XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER); 367 368 ret = xenbus_alloc_evtchn(pvcalls_front_dev, evtchn); 369 if (ret) 370 goto out_error; 371 irq = bind_evtchn_to_irqhandler(*evtchn, pvcalls_front_conn_handler, 372 0, "pvcalls-frontend", map); 373 if (irq < 0) { 374 ret = irq; 375 goto out_error; 376 } 377 378 map->active.irq = irq; 379 map->active_socket = true; 380 mutex_init(&map->active.in_mutex); 381 mutex_init(&map->active.out_mutex); 382 383 return 0; 384 385 out_error: 386 if (*evtchn >= 0) 387 xenbus_free_evtchn(pvcalls_front_dev, *evtchn); 388 free_pages((unsigned long)map->active.data.in, PVCALLS_RING_ORDER); 389 free_page((unsigned long)map->active.ring); 390 return ret; 391 } 392 393 int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr, 394 int addr_len, int flags) 395 { 396 struct pvcalls_bedata *bedata; 397 struct sock_mapping *map = NULL; 398 struct xen_pvcalls_request *req; 399 int notify, req_id, ret, evtchn; 400 401 if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM) 402 return -EOPNOTSUPP; 403 404 map = pvcalls_enter_sock(sock); 405 if (IS_ERR(map)) 406 return PTR_ERR(map); 407 408 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 409 410 spin_lock(&bedata->socket_lock); 411 ret = get_request(bedata, &req_id); 412 if (ret < 0) { 413 spin_unlock(&bedata->socket_lock); 414 pvcalls_exit_sock(sock); 415 return ret; 416 } 417 ret = create_active(map, &evtchn); 418 if (ret < 0) { 419 spin_unlock(&bedata->socket_lock); 420 pvcalls_exit_sock(sock); 421 return ret; 422 } 423 424 req = RING_GET_REQUEST(&bedata->ring, req_id); 425 req->req_id = req_id; 426 req->cmd = PVCALLS_CONNECT; 427 req->u.connect.id = (uintptr_t)map; 428 req->u.connect.len = addr_len; 429 req->u.connect.flags = flags; 430 req->u.connect.ref = map->active.ref; 431 req->u.connect.evtchn = evtchn; 432 memcpy(req->u.connect.addr, addr, sizeof(*addr)); 433 434 map->sock = sock; 435 436 bedata->ring.req_prod_pvt++; 437 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify); 438 spin_unlock(&bedata->socket_lock); 439 440 if (notify) 441 notify_remote_via_irq(bedata->irq); 442 443 wait_event(bedata->inflight_req, 444 READ_ONCE(bedata->rsp[req_id].req_id) == req_id); 445 446 /* read req_id, then the content */ 447 smp_rmb(); 448 ret = bedata->rsp[req_id].ret; 449 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID; 450 pvcalls_exit_sock(sock); 451 return ret; 452 } 453 454 static int __write_ring(struct pvcalls_data_intf *intf, 455 struct pvcalls_data *data, 456 struct iov_iter *msg_iter, 457 int len) 458 { 459 RING_IDX cons, prod, size, masked_prod, masked_cons; 460 RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER); 461 int32_t error; 462 463 error = intf->out_error; 464 if (error < 0) 465 return error; 466 cons = intf->out_cons; 467 prod = intf->out_prod; 468 /* read indexes before continuing */ 469 virt_mb(); 470 471 size = pvcalls_queued(prod, cons, array_size); 472 if (size >= array_size) 473 return -EINVAL; 474 if (len > array_size - size) 475 len = array_size - size; 476 477 masked_prod = pvcalls_mask(prod, array_size); 478 masked_cons = pvcalls_mask(cons, array_size); 479 480 if (masked_prod < masked_cons) { 481 len = copy_from_iter(data->out + masked_prod, len, msg_iter); 482 } else { 483 if (len > array_size - masked_prod) { 484 int ret = copy_from_iter(data->out + masked_prod, 485 array_size - masked_prod, msg_iter); 486 if (ret != array_size - masked_prod) { 487 len = ret; 488 goto out; 489 } 490 len = ret + copy_from_iter(data->out, len - ret, msg_iter); 491 } else { 492 len = copy_from_iter(data->out + masked_prod, len, msg_iter); 493 } 494 } 495 out: 496 /* write to ring before updating pointer */ 497 virt_wmb(); 498 intf->out_prod += len; 499 500 return len; 501 } 502 503 int pvcalls_front_sendmsg(struct socket *sock, struct msghdr *msg, 504 size_t len) 505 { 506 struct pvcalls_bedata *bedata; 507 struct sock_mapping *map; 508 int sent, tot_sent = 0; 509 int count = 0, flags; 510 511 flags = msg->msg_flags; 512 if (flags & (MSG_CONFIRM|MSG_DONTROUTE|MSG_EOR|MSG_OOB)) 513 return -EOPNOTSUPP; 514 515 map = pvcalls_enter_sock(sock); 516 if (IS_ERR(map)) 517 return PTR_ERR(map); 518 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 519 520 mutex_lock(&map->active.out_mutex); 521 if ((flags & MSG_DONTWAIT) && !pvcalls_front_write_todo(map)) { 522 mutex_unlock(&map->active.out_mutex); 523 pvcalls_exit_sock(sock); 524 return -EAGAIN; 525 } 526 if (len > INT_MAX) 527 len = INT_MAX; 528 529 again: 530 count++; 531 sent = __write_ring(map->active.ring, 532 &map->active.data, &msg->msg_iter, 533 len); 534 if (sent > 0) { 535 len -= sent; 536 tot_sent += sent; 537 notify_remote_via_irq(map->active.irq); 538 } 539 if (sent >= 0 && len > 0 && count < PVCALLS_FRONT_MAX_SPIN) 540 goto again; 541 if (sent < 0) 542 tot_sent = sent; 543 544 mutex_unlock(&map->active.out_mutex); 545 pvcalls_exit_sock(sock); 546 return tot_sent; 547 } 548 549 static int __read_ring(struct pvcalls_data_intf *intf, 550 struct pvcalls_data *data, 551 struct iov_iter *msg_iter, 552 size_t len, int flags) 553 { 554 RING_IDX cons, prod, size, masked_prod, masked_cons; 555 RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER); 556 int32_t error; 557 558 cons = intf->in_cons; 559 prod = intf->in_prod; 560 error = intf->in_error; 561 /* get pointers before reading from the ring */ 562 virt_rmb(); 563 if (error < 0) 564 return error; 565 566 size = pvcalls_queued(prod, cons, array_size); 567 masked_prod = pvcalls_mask(prod, array_size); 568 masked_cons = pvcalls_mask(cons, array_size); 569 570 if (size == 0) 571 return 0; 572 573 if (len > size) 574 len = size; 575 576 if (masked_prod > masked_cons) { 577 len = copy_to_iter(data->in + masked_cons, len, msg_iter); 578 } else { 579 if (len > (array_size - masked_cons)) { 580 int ret = copy_to_iter(data->in + masked_cons, 581 array_size - masked_cons, msg_iter); 582 if (ret != array_size - masked_cons) { 583 len = ret; 584 goto out; 585 } 586 len = ret + copy_to_iter(data->in, len - ret, msg_iter); 587 } else { 588 len = copy_to_iter(data->in + masked_cons, len, msg_iter); 589 } 590 } 591 out: 592 /* read data from the ring before increasing the index */ 593 virt_mb(); 594 if (!(flags & MSG_PEEK)) 595 intf->in_cons += len; 596 597 return len; 598 } 599 600 int pvcalls_front_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, 601 int flags) 602 { 603 struct pvcalls_bedata *bedata; 604 int ret; 605 struct sock_mapping *map; 606 607 if (flags & (MSG_CMSG_CLOEXEC|MSG_ERRQUEUE|MSG_OOB|MSG_TRUNC)) 608 return -EOPNOTSUPP; 609 610 map = pvcalls_enter_sock(sock); 611 if (IS_ERR(map)) 612 return PTR_ERR(map); 613 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 614 615 mutex_lock(&map->active.in_mutex); 616 if (len > XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER)) 617 len = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER); 618 619 while (!(flags & MSG_DONTWAIT) && !pvcalls_front_read_todo(map)) { 620 wait_event_interruptible(map->active.inflight_conn_req, 621 pvcalls_front_read_todo(map)); 622 } 623 ret = __read_ring(map->active.ring, &map->active.data, 624 &msg->msg_iter, len, flags); 625 626 if (ret > 0) 627 notify_remote_via_irq(map->active.irq); 628 if (ret == 0) 629 ret = (flags & MSG_DONTWAIT) ? -EAGAIN : 0; 630 if (ret == -ENOTCONN) 631 ret = 0; 632 633 mutex_unlock(&map->active.in_mutex); 634 pvcalls_exit_sock(sock); 635 return ret; 636 } 637 638 int pvcalls_front_bind(struct socket *sock, struct sockaddr *addr, int addr_len) 639 { 640 struct pvcalls_bedata *bedata; 641 struct sock_mapping *map = NULL; 642 struct xen_pvcalls_request *req; 643 int notify, req_id, ret; 644 645 if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM) 646 return -EOPNOTSUPP; 647 648 map = pvcalls_enter_sock(sock); 649 if (IS_ERR(map)) 650 return PTR_ERR(map); 651 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 652 653 spin_lock(&bedata->socket_lock); 654 ret = get_request(bedata, &req_id); 655 if (ret < 0) { 656 spin_unlock(&bedata->socket_lock); 657 pvcalls_exit_sock(sock); 658 return ret; 659 } 660 req = RING_GET_REQUEST(&bedata->ring, req_id); 661 req->req_id = req_id; 662 map->sock = sock; 663 req->cmd = PVCALLS_BIND; 664 req->u.bind.id = (uintptr_t)map; 665 memcpy(req->u.bind.addr, addr, sizeof(*addr)); 666 req->u.bind.len = addr_len; 667 668 init_waitqueue_head(&map->passive.inflight_accept_req); 669 670 map->active_socket = false; 671 672 bedata->ring.req_prod_pvt++; 673 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify); 674 spin_unlock(&bedata->socket_lock); 675 if (notify) 676 notify_remote_via_irq(bedata->irq); 677 678 wait_event(bedata->inflight_req, 679 READ_ONCE(bedata->rsp[req_id].req_id) == req_id); 680 681 /* read req_id, then the content */ 682 smp_rmb(); 683 ret = bedata->rsp[req_id].ret; 684 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID; 685 686 map->passive.status = PVCALLS_STATUS_BIND; 687 pvcalls_exit_sock(sock); 688 return 0; 689 } 690 691 int pvcalls_front_listen(struct socket *sock, int backlog) 692 { 693 struct pvcalls_bedata *bedata; 694 struct sock_mapping *map; 695 struct xen_pvcalls_request *req; 696 int notify, req_id, ret; 697 698 map = pvcalls_enter_sock(sock); 699 if (IS_ERR(map)) 700 return PTR_ERR(map); 701 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 702 703 if (map->passive.status != PVCALLS_STATUS_BIND) { 704 pvcalls_exit_sock(sock); 705 return -EOPNOTSUPP; 706 } 707 708 spin_lock(&bedata->socket_lock); 709 ret = get_request(bedata, &req_id); 710 if (ret < 0) { 711 spin_unlock(&bedata->socket_lock); 712 pvcalls_exit_sock(sock); 713 return ret; 714 } 715 req = RING_GET_REQUEST(&bedata->ring, req_id); 716 req->req_id = req_id; 717 req->cmd = PVCALLS_LISTEN; 718 req->u.listen.id = (uintptr_t) map; 719 req->u.listen.backlog = backlog; 720 721 bedata->ring.req_prod_pvt++; 722 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify); 723 spin_unlock(&bedata->socket_lock); 724 if (notify) 725 notify_remote_via_irq(bedata->irq); 726 727 wait_event(bedata->inflight_req, 728 READ_ONCE(bedata->rsp[req_id].req_id) == req_id); 729 730 /* read req_id, then the content */ 731 smp_rmb(); 732 ret = bedata->rsp[req_id].ret; 733 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID; 734 735 map->passive.status = PVCALLS_STATUS_LISTEN; 736 pvcalls_exit_sock(sock); 737 return ret; 738 } 739 740 int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags) 741 { 742 struct pvcalls_bedata *bedata; 743 struct sock_mapping *map; 744 struct sock_mapping *map2 = NULL; 745 struct xen_pvcalls_request *req; 746 int notify, req_id, ret, evtchn, nonblock; 747 748 map = pvcalls_enter_sock(sock); 749 if (IS_ERR(map)) 750 return PTR_ERR(map); 751 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 752 753 if (map->passive.status != PVCALLS_STATUS_LISTEN) { 754 pvcalls_exit_sock(sock); 755 return -EINVAL; 756 } 757 758 nonblock = flags & SOCK_NONBLOCK; 759 /* 760 * Backend only supports 1 inflight accept request, will return 761 * errors for the others 762 */ 763 if (test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, 764 (void *)&map->passive.flags)) { 765 req_id = READ_ONCE(map->passive.inflight_req_id); 766 if (req_id != PVCALLS_INVALID_ID && 767 READ_ONCE(bedata->rsp[req_id].req_id) == req_id) { 768 map2 = map->passive.accept_map; 769 goto received; 770 } 771 if (nonblock) { 772 pvcalls_exit_sock(sock); 773 return -EAGAIN; 774 } 775 if (wait_event_interruptible(map->passive.inflight_accept_req, 776 !test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, 777 (void *)&map->passive.flags))) { 778 pvcalls_exit_sock(sock); 779 return -EINTR; 780 } 781 } 782 783 spin_lock(&bedata->socket_lock); 784 ret = get_request(bedata, &req_id); 785 if (ret < 0) { 786 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, 787 (void *)&map->passive.flags); 788 spin_unlock(&bedata->socket_lock); 789 pvcalls_exit_sock(sock); 790 return ret; 791 } 792 map2 = kzalloc(sizeof(*map2), GFP_ATOMIC); 793 if (map2 == NULL) { 794 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, 795 (void *)&map->passive.flags); 796 spin_unlock(&bedata->socket_lock); 797 pvcalls_exit_sock(sock); 798 return -ENOMEM; 799 } 800 ret = create_active(map2, &evtchn); 801 if (ret < 0) { 802 kfree(map2); 803 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, 804 (void *)&map->passive.flags); 805 spin_unlock(&bedata->socket_lock); 806 pvcalls_exit_sock(sock); 807 return ret; 808 } 809 list_add_tail(&map2->list, &bedata->socket_mappings); 810 811 req = RING_GET_REQUEST(&bedata->ring, req_id); 812 req->req_id = req_id; 813 req->cmd = PVCALLS_ACCEPT; 814 req->u.accept.id = (uintptr_t) map; 815 req->u.accept.ref = map2->active.ref; 816 req->u.accept.id_new = (uintptr_t) map2; 817 req->u.accept.evtchn = evtchn; 818 map->passive.accept_map = map2; 819 820 bedata->ring.req_prod_pvt++; 821 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify); 822 spin_unlock(&bedata->socket_lock); 823 if (notify) 824 notify_remote_via_irq(bedata->irq); 825 /* We could check if we have received a response before returning. */ 826 if (nonblock) { 827 WRITE_ONCE(map->passive.inflight_req_id, req_id); 828 pvcalls_exit_sock(sock); 829 return -EAGAIN; 830 } 831 832 if (wait_event_interruptible(bedata->inflight_req, 833 READ_ONCE(bedata->rsp[req_id].req_id) == req_id)) { 834 pvcalls_exit_sock(sock); 835 return -EINTR; 836 } 837 /* read req_id, then the content */ 838 smp_rmb(); 839 840 received: 841 map2->sock = newsock; 842 newsock->sk = kzalloc(sizeof(*newsock->sk), GFP_KERNEL); 843 if (!newsock->sk) { 844 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID; 845 map->passive.inflight_req_id = PVCALLS_INVALID_ID; 846 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, 847 (void *)&map->passive.flags); 848 pvcalls_front_free_map(bedata, map2); 849 pvcalls_exit_sock(sock); 850 return -ENOMEM; 851 } 852 newsock->sk->sk_send_head = (void *)map2; 853 854 ret = bedata->rsp[req_id].ret; 855 bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID; 856 map->passive.inflight_req_id = PVCALLS_INVALID_ID; 857 858 clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags); 859 wake_up(&map->passive.inflight_accept_req); 860 861 pvcalls_exit_sock(sock); 862 return ret; 863 } 864 865 static __poll_t pvcalls_front_poll_passive(struct file *file, 866 struct pvcalls_bedata *bedata, 867 struct sock_mapping *map, 868 poll_table *wait) 869 { 870 int notify, req_id, ret; 871 struct xen_pvcalls_request *req; 872 873 if (test_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, 874 (void *)&map->passive.flags)) { 875 uint32_t req_id = READ_ONCE(map->passive.inflight_req_id); 876 877 if (req_id != PVCALLS_INVALID_ID && 878 READ_ONCE(bedata->rsp[req_id].req_id) == req_id) 879 return EPOLLIN | EPOLLRDNORM; 880 881 poll_wait(file, &map->passive.inflight_accept_req, wait); 882 return 0; 883 } 884 885 if (test_and_clear_bit(PVCALLS_FLAG_POLL_RET, 886 (void *)&map->passive.flags)) 887 return EPOLLIN | EPOLLRDNORM; 888 889 /* 890 * First check RET, then INFLIGHT. No barriers necessary to 891 * ensure execution ordering because of the conditional 892 * instructions creating control dependencies. 893 */ 894 895 if (test_and_set_bit(PVCALLS_FLAG_POLL_INFLIGHT, 896 (void *)&map->passive.flags)) { 897 poll_wait(file, &bedata->inflight_req, wait); 898 return 0; 899 } 900 901 spin_lock(&bedata->socket_lock); 902 ret = get_request(bedata, &req_id); 903 if (ret < 0) { 904 spin_unlock(&bedata->socket_lock); 905 return ret; 906 } 907 req = RING_GET_REQUEST(&bedata->ring, req_id); 908 req->req_id = req_id; 909 req->cmd = PVCALLS_POLL; 910 req->u.poll.id = (uintptr_t) map; 911 912 bedata->ring.req_prod_pvt++; 913 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify); 914 spin_unlock(&bedata->socket_lock); 915 if (notify) 916 notify_remote_via_irq(bedata->irq); 917 918 poll_wait(file, &bedata->inflight_req, wait); 919 return 0; 920 } 921 922 static __poll_t pvcalls_front_poll_active(struct file *file, 923 struct pvcalls_bedata *bedata, 924 struct sock_mapping *map, 925 poll_table *wait) 926 { 927 __poll_t mask = 0; 928 int32_t in_error, out_error; 929 struct pvcalls_data_intf *intf = map->active.ring; 930 931 out_error = intf->out_error; 932 in_error = intf->in_error; 933 934 poll_wait(file, &map->active.inflight_conn_req, wait); 935 if (pvcalls_front_write_todo(map)) 936 mask |= EPOLLOUT | EPOLLWRNORM; 937 if (pvcalls_front_read_todo(map)) 938 mask |= EPOLLIN | EPOLLRDNORM; 939 if (in_error != 0 || out_error != 0) 940 mask |= EPOLLERR; 941 942 return mask; 943 } 944 945 __poll_t pvcalls_front_poll(struct file *file, struct socket *sock, 946 poll_table *wait) 947 { 948 struct pvcalls_bedata *bedata; 949 struct sock_mapping *map; 950 __poll_t ret; 951 952 map = pvcalls_enter_sock(sock); 953 if (IS_ERR(map)) 954 return EPOLLNVAL; 955 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 956 957 if (map->active_socket) 958 ret = pvcalls_front_poll_active(file, bedata, map, wait); 959 else 960 ret = pvcalls_front_poll_passive(file, bedata, map, wait); 961 pvcalls_exit_sock(sock); 962 return ret; 963 } 964 965 int pvcalls_front_release(struct socket *sock) 966 { 967 struct pvcalls_bedata *bedata; 968 struct sock_mapping *map; 969 int req_id, notify, ret; 970 struct xen_pvcalls_request *req; 971 972 if (sock->sk == NULL) 973 return 0; 974 975 map = pvcalls_enter_sock(sock); 976 if (IS_ERR(map)) { 977 if (PTR_ERR(map) == -ENOTCONN) 978 return -EIO; 979 else 980 return 0; 981 } 982 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 983 984 spin_lock(&bedata->socket_lock); 985 ret = get_request(bedata, &req_id); 986 if (ret < 0) { 987 spin_unlock(&bedata->socket_lock); 988 pvcalls_exit_sock(sock); 989 return ret; 990 } 991 sock->sk->sk_send_head = NULL; 992 993 req = RING_GET_REQUEST(&bedata->ring, req_id); 994 req->req_id = req_id; 995 req->cmd = PVCALLS_RELEASE; 996 req->u.release.id = (uintptr_t)map; 997 998 bedata->ring.req_prod_pvt++; 999 RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify); 1000 spin_unlock(&bedata->socket_lock); 1001 if (notify) 1002 notify_remote_via_irq(bedata->irq); 1003 1004 wait_event(bedata->inflight_req, 1005 READ_ONCE(bedata->rsp[req_id].req_id) == req_id); 1006 1007 if (map->active_socket) { 1008 /* 1009 * Set in_error and wake up inflight_conn_req to force 1010 * recvmsg waiters to exit. 1011 */ 1012 map->active.ring->in_error = -EBADF; 1013 wake_up_interruptible(&map->active.inflight_conn_req); 1014 1015 /* 1016 * We need to make sure that sendmsg/recvmsg on this socket have 1017 * not started before we've cleared sk_send_head here. The 1018 * easiest way to guarantee this is to see that no pvcalls 1019 * (other than us) is in progress on this socket. 1020 */ 1021 while (atomic_read(&map->refcount) > 1) 1022 cpu_relax(); 1023 1024 pvcalls_front_free_map(bedata, map); 1025 } else { 1026 wake_up(&bedata->inflight_req); 1027 wake_up(&map->passive.inflight_accept_req); 1028 1029 while (atomic_read(&map->refcount) > 1) 1030 cpu_relax(); 1031 1032 spin_lock(&bedata->socket_lock); 1033 list_del(&map->list); 1034 spin_unlock(&bedata->socket_lock); 1035 if (READ_ONCE(map->passive.inflight_req_id) != 1036 PVCALLS_INVALID_ID) { 1037 pvcalls_front_free_map(bedata, 1038 map->passive.accept_map); 1039 } 1040 kfree(map); 1041 } 1042 WRITE_ONCE(bedata->rsp[req_id].req_id, PVCALLS_INVALID_ID); 1043 1044 pvcalls_exit(); 1045 return 0; 1046 } 1047 1048 static const struct xenbus_device_id pvcalls_front_ids[] = { 1049 { "pvcalls" }, 1050 { "" } 1051 }; 1052 1053 static int pvcalls_front_remove(struct xenbus_device *dev) 1054 { 1055 struct pvcalls_bedata *bedata; 1056 struct sock_mapping *map = NULL, *n; 1057 1058 bedata = dev_get_drvdata(&pvcalls_front_dev->dev); 1059 dev_set_drvdata(&dev->dev, NULL); 1060 pvcalls_front_dev = NULL; 1061 if (bedata->irq >= 0) 1062 unbind_from_irqhandler(bedata->irq, dev); 1063 1064 list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) { 1065 map->sock->sk->sk_send_head = NULL; 1066 if (map->active_socket) { 1067 map->active.ring->in_error = -EBADF; 1068 wake_up_interruptible(&map->active.inflight_conn_req); 1069 } 1070 } 1071 1072 smp_mb(); 1073 while (atomic_read(&pvcalls_refcount) > 0) 1074 cpu_relax(); 1075 list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) { 1076 if (map->active_socket) { 1077 /* No need to lock, refcount is 0 */ 1078 pvcalls_front_free_map(bedata, map); 1079 } else { 1080 list_del(&map->list); 1081 kfree(map); 1082 } 1083 } 1084 if (bedata->ref != -1) 1085 gnttab_end_foreign_access(bedata->ref, 0, 0); 1086 kfree(bedata->ring.sring); 1087 kfree(bedata); 1088 xenbus_switch_state(dev, XenbusStateClosed); 1089 return 0; 1090 } 1091 1092 static int pvcalls_front_probe(struct xenbus_device *dev, 1093 const struct xenbus_device_id *id) 1094 { 1095 int ret = -ENOMEM, evtchn, i; 1096 unsigned int max_page_order, function_calls, len; 1097 char *versions; 1098 grant_ref_t gref_head = 0; 1099 struct xenbus_transaction xbt; 1100 struct pvcalls_bedata *bedata = NULL; 1101 struct xen_pvcalls_sring *sring; 1102 1103 if (pvcalls_front_dev != NULL) { 1104 dev_err(&dev->dev, "only one PV Calls connection supported\n"); 1105 return -EINVAL; 1106 } 1107 1108 versions = xenbus_read(XBT_NIL, dev->otherend, "versions", &len); 1109 if (IS_ERR(versions)) 1110 return PTR_ERR(versions); 1111 if (!len) 1112 return -EINVAL; 1113 if (strcmp(versions, "1")) { 1114 kfree(versions); 1115 return -EINVAL; 1116 } 1117 kfree(versions); 1118 max_page_order = xenbus_read_unsigned(dev->otherend, 1119 "max-page-order", 0); 1120 if (max_page_order < PVCALLS_RING_ORDER) 1121 return -ENODEV; 1122 function_calls = xenbus_read_unsigned(dev->otherend, 1123 "function-calls", 0); 1124 /* See XENBUS_FUNCTIONS_CALLS in pvcalls.h */ 1125 if (function_calls != 1) 1126 return -ENODEV; 1127 pr_info("%s max-page-order is %u\n", __func__, max_page_order); 1128 1129 bedata = kzalloc(sizeof(struct pvcalls_bedata), GFP_KERNEL); 1130 if (!bedata) 1131 return -ENOMEM; 1132 1133 dev_set_drvdata(&dev->dev, bedata); 1134 pvcalls_front_dev = dev; 1135 init_waitqueue_head(&bedata->inflight_req); 1136 INIT_LIST_HEAD(&bedata->socket_mappings); 1137 spin_lock_init(&bedata->socket_lock); 1138 bedata->irq = -1; 1139 bedata->ref = -1; 1140 1141 for (i = 0; i < PVCALLS_NR_RSP_PER_RING; i++) 1142 bedata->rsp[i].req_id = PVCALLS_INVALID_ID; 1143 1144 sring = (struct xen_pvcalls_sring *) __get_free_page(GFP_KERNEL | 1145 __GFP_ZERO); 1146 if (!sring) 1147 goto error; 1148 SHARED_RING_INIT(sring); 1149 FRONT_RING_INIT(&bedata->ring, sring, XEN_PAGE_SIZE); 1150 1151 ret = xenbus_alloc_evtchn(dev, &evtchn); 1152 if (ret) 1153 goto error; 1154 1155 bedata->irq = bind_evtchn_to_irqhandler(evtchn, 1156 pvcalls_front_event_handler, 1157 0, "pvcalls-frontend", dev); 1158 if (bedata->irq < 0) { 1159 ret = bedata->irq; 1160 goto error; 1161 } 1162 1163 ret = gnttab_alloc_grant_references(1, &gref_head); 1164 if (ret < 0) 1165 goto error; 1166 ret = gnttab_claim_grant_reference(&gref_head); 1167 if (ret < 0) 1168 goto error; 1169 bedata->ref = ret; 1170 gnttab_grant_foreign_access_ref(bedata->ref, dev->otherend_id, 1171 virt_to_gfn((void *)sring), 0); 1172 1173 again: 1174 ret = xenbus_transaction_start(&xbt); 1175 if (ret) { 1176 xenbus_dev_fatal(dev, ret, "starting transaction"); 1177 goto error; 1178 } 1179 ret = xenbus_printf(xbt, dev->nodename, "version", "%u", 1); 1180 if (ret) 1181 goto error_xenbus; 1182 ret = xenbus_printf(xbt, dev->nodename, "ring-ref", "%d", bedata->ref); 1183 if (ret) 1184 goto error_xenbus; 1185 ret = xenbus_printf(xbt, dev->nodename, "port", "%u", 1186 evtchn); 1187 if (ret) 1188 goto error_xenbus; 1189 ret = xenbus_transaction_end(xbt, 0); 1190 if (ret) { 1191 if (ret == -EAGAIN) 1192 goto again; 1193 xenbus_dev_fatal(dev, ret, "completing transaction"); 1194 goto error; 1195 } 1196 xenbus_switch_state(dev, XenbusStateInitialised); 1197 1198 return 0; 1199 1200 error_xenbus: 1201 xenbus_transaction_end(xbt, 1); 1202 xenbus_dev_fatal(dev, ret, "writing xenstore"); 1203 error: 1204 pvcalls_front_remove(dev); 1205 return ret; 1206 } 1207 1208 static void pvcalls_front_changed(struct xenbus_device *dev, 1209 enum xenbus_state backend_state) 1210 { 1211 switch (backend_state) { 1212 case XenbusStateReconfiguring: 1213 case XenbusStateReconfigured: 1214 case XenbusStateInitialising: 1215 case XenbusStateInitialised: 1216 case XenbusStateUnknown: 1217 break; 1218 1219 case XenbusStateInitWait: 1220 break; 1221 1222 case XenbusStateConnected: 1223 xenbus_switch_state(dev, XenbusStateConnected); 1224 break; 1225 1226 case XenbusStateClosed: 1227 if (dev->state == XenbusStateClosed) 1228 break; 1229 /* Missed the backend's CLOSING state */ 1230 /* fall through */ 1231 case XenbusStateClosing: 1232 xenbus_frontend_closed(dev); 1233 break; 1234 } 1235 } 1236 1237 static struct xenbus_driver pvcalls_front_driver = { 1238 .ids = pvcalls_front_ids, 1239 .probe = pvcalls_front_probe, 1240 .remove = pvcalls_front_remove, 1241 .otherend_changed = pvcalls_front_changed, 1242 }; 1243 1244 static int __init pvcalls_frontend_init(void) 1245 { 1246 if (!xen_domain()) 1247 return -ENODEV; 1248 1249 pr_info("Initialising Xen pvcalls frontend driver\n"); 1250 1251 return xenbus_register_frontend(&pvcalls_front_driver); 1252 } 1253 1254 module_init(pvcalls_frontend_init); 1255 1256 MODULE_DESCRIPTION("Xen PV Calls frontend driver"); 1257 MODULE_AUTHOR("Stefano Stabellini <sstabellini@kernel.org>"); 1258 MODULE_LICENSE("GPL"); 1259