// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * (c) 2017 Stefano Stabellini <stefano@aporeto.com>
 */

#include <linux/module.h>
#include <linux/net.h>
#include <linux/socket.h>

#include <net/sock.h>

#include <xen/events.h>
#include <xen/grant_table.h>
#include <xen/xen.h>
#include <xen/xenbus.h>
#include <xen/interface/io/pvcalls.h>

#include "pvcalls-front.h"

#define PVCALLS_INVALID_ID UINT_MAX
#define PVCALLS_RING_ORDER XENBUS_MAX_RING_GRANT_ORDER
#define PVCALLS_NR_RSP_PER_RING __CONST_RING_SIZE(xen_pvcalls, XEN_PAGE_SIZE)
#define PVCALLS_FRONT_MAX_SPIN 5000

static struct proto pvcalls_proto = {
	.name	= "PVCalls",
	.owner	= THIS_MODULE,
	.obj_size = sizeof(struct sock),
};

struct pvcalls_bedata {
	struct xen_pvcalls_front_ring ring;
	grant_ref_t ref;
	int irq;

	struct list_head socket_mappings;
	spinlock_t socket_lock;

	wait_queue_head_t inflight_req;
	struct xen_pvcalls_response rsp[PVCALLS_NR_RSP_PER_RING];
};
/* Only one front/back connection supported. */
static struct xenbus_device *pvcalls_front_dev;
static atomic_t pvcalls_refcount;

/* first increment refcount, then proceed */
#define pvcalls_enter() {		\
	atomic_inc(&pvcalls_refcount);	\
}

/* first complete other operations, then decrement refcount */
#define pvcalls_exit() {		\
	atomic_dec(&pvcalls_refcount);	\
}

struct sock_mapping {
	bool active_socket;
	struct list_head list;
	struct socket *sock;
	atomic_t refcount;
	union {
		struct {
			int irq;
			grant_ref_t ref;
			struct pvcalls_data_intf *ring;
			struct pvcalls_data data;
			struct mutex in_mutex;
			struct mutex out_mutex;

			wait_queue_head_t inflight_conn_req;
		} active;
		struct {
		/*
		 * Socket status, needs to be 64-bit aligned due to the
		 * test_and_* functions which have this requirement on arm64.
		 */
#define PVCALLS_STATUS_UNINITALIZED  0
#define PVCALLS_STATUS_BIND          1
#define PVCALLS_STATUS_LISTEN        2
			uint8_t status __attribute__((aligned(8)));
		/*
		 * Internal state-machine flags.
		 * Only one accept operation can be inflight for a socket.
		 * Only one poll operation can be inflight for a given socket.
		 * flags needs to be 64-bit aligned due to the test_and_*
		 * functions which have this requirement on arm64.
		 */
#define PVCALLS_FLAG_ACCEPT_INFLIGHT 0
#define PVCALLS_FLAG_POLL_INFLIGHT   1
#define PVCALLS_FLAG_POLL_RET        2
			uint8_t flags __attribute__((aligned(8)));
			uint32_t inflight_req_id;
			struct sock_mapping *accept_map;
			wait_queue_head_t inflight_accept_req;
		} passive;
	};
};
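
/*
 * Every socket operation is bracketed by pvcalls_enter_sock() and
 * pvcalls_exit_sock(), which take the global pvcalls refcount and the
 * per-socket refcount.  Teardown paths spin on these counters before
 * freeing the mapping or the connection, so the bracketing is what keeps
 * in-flight operations safe against release/remove.
 */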
static inline struct sock_mapping *pvcalls_enter_sock(struct socket *sock)
{
	struct sock_mapping *map;

	if (!pvcalls_front_dev ||
	    dev_get_drvdata(&pvcalls_front_dev->dev) == NULL)
		return ERR_PTR(-ENOTCONN);

	map = (struct sock_mapping *)sock->sk->sk_send_head;
	if (map == NULL)
		return ERR_PTR(-ENOTSOCK);

	pvcalls_enter();
	atomic_inc(&map->refcount);
	return map;
}

static inline void pvcalls_exit_sock(struct socket *sock)
{
	struct sock_mapping *map;

	map = (struct sock_mapping *)sock->sk->sk_send_head;
	atomic_dec(&map->refcount);
	pvcalls_exit();
}

static inline int get_request(struct pvcalls_bedata *bedata, int *req_id)
{
	*req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1);
	if (RING_FULL(&bedata->ring) ||
	    bedata->rsp[*req_id].req_id != PVCALLS_INVALID_ID)
		return -EAGAIN;
	return 0;
}

static bool pvcalls_front_write_todo(struct sock_mapping *map)
{
	struct pvcalls_data_intf *intf = map->active.ring;
	RING_IDX cons, prod, size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
	int32_t error;

	error = intf->out_error;
	if (error == -ENOTCONN)
		return false;
	if (error != 0)
		return true;

	cons = intf->out_cons;
	prod = intf->out_prod;
	return !!(size - pvcalls_queued(prod, cons, size));
}

static bool pvcalls_front_read_todo(struct sock_mapping *map)
{
	struct pvcalls_data_intf *intf = map->active.ring;
	RING_IDX cons, prod;
	int32_t error;

	cons = intf->in_cons;
	prod = intf->in_prod;
	error = intf->in_error;
	return (error != 0 ||
		pvcalls_queued(prod, cons,
			       XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER)) != 0);
}
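
/*
 * Backend interrupt handler for the command ring.  PVCALLS_POLL responses
 * are handled in place by flipping the passive socket's flag bits; every
 * other response is copied into bedata->rsp[req_id], with req_id written
 * last so that waiters only ever see a fully populated response.  Waiters
 * on inflight_req are woken once all pending responses have been consumed.
 */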
static irqreturn_t pvcalls_front_event_handler(int irq, void *dev_id)
{
	struct xenbus_device *dev = dev_id;
	struct pvcalls_bedata *bedata;
	struct xen_pvcalls_response *rsp;
	uint8_t *src, *dst;
	int req_id = 0, more = 0, done = 0;

	if (dev == NULL)
		return IRQ_HANDLED;

	pvcalls_enter();
	bedata = dev_get_drvdata(&dev->dev);
	if (bedata == NULL) {
		pvcalls_exit();
		return IRQ_HANDLED;
	}

again:
	while (RING_HAS_UNCONSUMED_RESPONSES(&bedata->ring)) {
		rsp = RING_GET_RESPONSE(&bedata->ring, bedata->ring.rsp_cons);

		req_id = rsp->req_id;
		if (rsp->cmd == PVCALLS_POLL) {
			struct sock_mapping *map = (struct sock_mapping *)(uintptr_t)
						   rsp->u.poll.id;

			clear_bit(PVCALLS_FLAG_POLL_INFLIGHT,
				  (void *)&map->passive.flags);
			/*
			 * clear INFLIGHT, then set RET. It pairs with
			 * the checks at the beginning of
			 * pvcalls_front_poll_passive.
			 */
			smp_wmb();
			set_bit(PVCALLS_FLAG_POLL_RET,
				(void *)&map->passive.flags);
		} else {
			dst = (uint8_t *)&bedata->rsp[req_id] +
			      sizeof(rsp->req_id);
			src = (uint8_t *)rsp + sizeof(rsp->req_id);
			memcpy(dst, src, sizeof(*rsp) - sizeof(rsp->req_id));
			/*
			 * First copy the rest of the data, then req_id. It is
			 * paired with the barrier when accessing bedata->rsp.
			 */
			smp_wmb();
			bedata->rsp[req_id].req_id = req_id;
		}

		done = 1;
		bedata->ring.rsp_cons++;
	}

	RING_FINAL_CHECK_FOR_RESPONSES(&bedata->ring, more);
	if (more)
		goto again;
	if (done)
		wake_up(&bedata->inflight_req);
	pvcalls_exit();
	return IRQ_HANDLED;
}

static void free_active_ring(struct sock_mapping *map);

static void pvcalls_front_destroy_active(struct pvcalls_bedata *bedata,
					 struct sock_mapping *map)
{
	int i;

	unbind_from_irqhandler(map->active.irq, map);

	if (bedata) {
		spin_lock(&bedata->socket_lock);
		if (!list_empty(&map->list))
			list_del_init(&map->list);
		spin_unlock(&bedata->socket_lock);
	}

	for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
		gnttab_end_foreign_access(map->active.ring->ref[i], NULL);
	gnttab_end_foreign_access(map->active.ref, NULL);
	free_active_ring(map);
}

static void pvcalls_front_free_map(struct pvcalls_bedata *bedata,
				   struct sock_mapping *map)
{
	pvcalls_front_destroy_active(bedata, map);

	kfree(map);
}

static irqreturn_t pvcalls_front_conn_handler(int irq, void *sock_map)
{
	struct sock_mapping *map = sock_map;

	if (map == NULL)
		return IRQ_HANDLED;

	wake_up_interruptible(&map->active.inflight_conn_req);

	return IRQ_HANDLED;
}
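
/*
 * All slow-path socket commands follow the same pattern: reserve a slot on
 * the command ring with get_request() under socket_lock, fill in the
 * request, push it and kick the backend, then sleep on inflight_req until
 * the matching bedata->rsp slot carries our req_id.  The slot is marked
 * PVCALLS_INVALID_ID again once the result has been read.
 */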
int pvcalls_front_socket(struct socket *sock)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret;

	/*
	 * PVCalls only supports domain AF_INET,
	 * type SOCK_STREAM and protocol 0 sockets for now.
	 *
	 * Check socket type here, AF_INET and protocol checks are done
	 * by the caller.
	 */
	if (sock->type != SOCK_STREAM)
		return -EOPNOTSUPP;

	pvcalls_enter();
	if (!pvcalls_front_dev) {
		pvcalls_exit();
		return -EACCES;
	}
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	map = kzalloc(sizeof(*map), GFP_KERNEL);
	if (map == NULL) {
		pvcalls_exit();
		return -ENOMEM;
	}

	spin_lock(&bedata->socket_lock);

	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		kfree(map);
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit();
		return ret;
	}

	/*
	 * sock->sk->sk_send_head is not used for ip sockets: reuse the
	 * field to store a pointer to the struct sock_mapping
	 * corresponding to the socket. This way, we can easily get the
	 * struct sock_mapping from the struct socket.
	 */
	sock->sk->sk_send_head = (void *)map;
	list_add_tail(&map->list, &bedata->socket_mappings);

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_SOCKET;
	req->u.socket.id = (uintptr_t) map;
	req->u.socket.domain = AF_INET;
	req->u.socket.type = SOCK_STREAM;
	req->u.socket.protocol = IPPROTO_IP;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

	pvcalls_exit();
	return ret;
}

static void free_active_ring(struct sock_mapping *map)
{
	if (!map->active.ring)
		return;

	free_pages_exact(map->active.data.in,
			 PAGE_SIZE << map->active.ring->ring_order);
	free_page((unsigned long)map->active.ring);
}

static int alloc_active_ring(struct sock_mapping *map)
{
	void *bytes;

	map->active.ring = (struct pvcalls_data_intf *)
		get_zeroed_page(GFP_KERNEL);
	if (!map->active.ring)
		goto out;

	map->active.ring->ring_order = PVCALLS_RING_ORDER;
	bytes = alloc_pages_exact(PAGE_SIZE << PVCALLS_RING_ORDER,
				  GFP_KERNEL | __GFP_ZERO);
	if (!bytes)
		goto out;

	map->active.data.in = bytes;
	map->active.data.out = bytes +
		XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);

	return 0;

out:
	free_active_ring(map);
	return -ENOMEM;
}

static int create_active(struct sock_mapping *map, evtchn_port_t *evtchn)
{
	void *bytes;
	int ret, irq = -1, i;

	*evtchn = 0;
	init_waitqueue_head(&map->active.inflight_conn_req);

	bytes = map->active.data.in;
	for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
		map->active.ring->ref[i] = gnttab_grant_foreign_access(
			pvcalls_front_dev->otherend_id,
			pfn_to_gfn(virt_to_pfn(bytes) + i), 0);

	map->active.ref = gnttab_grant_foreign_access(
		pvcalls_front_dev->otherend_id,
		pfn_to_gfn(virt_to_pfn((void *)map->active.ring)), 0);

	ret = xenbus_alloc_evtchn(pvcalls_front_dev, evtchn);
	if (ret)
		goto out_error;
	irq = bind_evtchn_to_irqhandler(*evtchn, pvcalls_front_conn_handler,
					0, "pvcalls-frontend", map);
	if (irq < 0) {
		ret = irq;
		goto out_error;
	}

	map->active.irq = irq;
	map->active_socket = true;
	mutex_init(&map->active.in_mutex);
	mutex_init(&map->active.out_mutex);

	return 0;

out_error:
	if (*evtchn > 0)
		xenbus_free_evtchn(pvcalls_front_dev, *evtchn);
	return ret;
}

int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
			  int addr_len, int flags)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret;
	evtchn_port_t evtchn;

	if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
		return -EOPNOTSUPP;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);

	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
	ret = alloc_active_ring(map);
	if (ret < 0) {
		pvcalls_exit_sock(sock);
		return ret;
	}
	ret = create_active(map, &evtchn);
	if (ret < 0) {
		free_active_ring(map);
		pvcalls_exit_sock(sock);
		return ret;
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_front_destroy_active(NULL, map);
		pvcalls_exit_sock(sock);
		return ret;
	}

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_CONNECT;
	req->u.connect.id = (uintptr_t)map;
	req->u.connect.len = addr_len;
	req->u.connect.flags = flags;
	req->u.connect.ref = map->active.ref;
	req->u.connect.evtchn = evtchn;
	memcpy(req->u.connect.addr, addr, sizeof(*addr));

	map->sock = sock;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);

	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
	pvcalls_exit_sock(sock);
	return ret;
}
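
/*
 * Copy as much of the iov as currently fits in the out data ring.  When the
 * free region wraps past the end of the ring buffer the copy is done in two
 * chunks, and the new data is only published to the backend by advancing
 * out_prod after a write barrier.
 */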
static int __write_ring(struct pvcalls_data_intf *intf,
			struct pvcalls_data *data,
			struct iov_iter *msg_iter,
			int len)
{
	RING_IDX cons, prod, size, masked_prod, masked_cons;
	RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
	int32_t error;

	error = intf->out_error;
	if (error < 0)
		return error;
	cons = intf->out_cons;
	prod = intf->out_prod;
	/* read indexes before continuing */
	virt_mb();

	size = pvcalls_queued(prod, cons, array_size);
	if (size > array_size)
		return -EINVAL;
	if (size == array_size)
		return 0;
	if (len > array_size - size)
		len = array_size - size;

	masked_prod = pvcalls_mask(prod, array_size);
	masked_cons = pvcalls_mask(cons, array_size);

	if (masked_prod < masked_cons) {
		len = copy_from_iter(data->out + masked_prod, len, msg_iter);
	} else {
		if (len > array_size - masked_prod) {
			int ret = copy_from_iter(data->out + masked_prod,
				       array_size - masked_prod, msg_iter);
			if (ret != array_size - masked_prod) {
				len = ret;
				goto out;
			}
			len = ret + copy_from_iter(data->out, len - ret, msg_iter);
		} else {
			len = copy_from_iter(data->out + masked_prod, len, msg_iter);
		}
	}
out:
	/* write to ring before updating pointer */
	virt_wmb();
	intf->out_prod += len;

	return len;
}

int pvcalls_front_sendmsg(struct socket *sock, struct msghdr *msg,
			  size_t len)
{
	struct sock_mapping *map;
	int sent, tot_sent = 0;
	int count = 0, flags;

	flags = msg->msg_flags;
	if (flags & (MSG_CONFIRM|MSG_DONTROUTE|MSG_EOR|MSG_OOB))
		return -EOPNOTSUPP;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);

	mutex_lock(&map->active.out_mutex);
	if ((flags & MSG_DONTWAIT) && !pvcalls_front_write_todo(map)) {
		mutex_unlock(&map->active.out_mutex);
		pvcalls_exit_sock(sock);
		return -EAGAIN;
	}
	if (len > INT_MAX)
		len = INT_MAX;

again:
	count++;
	sent = __write_ring(map->active.ring,
			    &map->active.data, &msg->msg_iter,
			    len);
	if (sent > 0) {
		len -= sent;
		tot_sent += sent;
		notify_remote_via_irq(map->active.irq);
	}
	if (sent >= 0 && len > 0 && count < PVCALLS_FRONT_MAX_SPIN)
		goto again;
	if (sent < 0)
		tot_sent = sent;

	mutex_unlock(&map->active.out_mutex);
	pvcalls_exit_sock(sock);
	return tot_sent;
}
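
/*
 * Copy up to len bytes of queued data from the in data ring to the iov,
 * handling the wrap of the ring buffer.  in_cons is only advanced when the
 * caller is not peeking (MSG_PEEK), so peeked data stays available for the
 * next read.
 */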
static int __read_ring(struct pvcalls_data_intf *intf,
		       struct pvcalls_data *data,
		       struct iov_iter *msg_iter,
		       size_t len, int flags)
{
	RING_IDX cons, prod, size, masked_prod, masked_cons;
	RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
	int32_t error;

	cons = intf->in_cons;
	prod = intf->in_prod;
	error = intf->in_error;
	/* get pointers before reading from the ring */
	virt_rmb();

	size = pvcalls_queued(prod, cons, array_size);
	masked_prod = pvcalls_mask(prod, array_size);
	masked_cons = pvcalls_mask(cons, array_size);

	if (size == 0)
		return error ?: size;

	if (len > size)
		len = size;

	if (masked_prod > masked_cons) {
		len = copy_to_iter(data->in + masked_cons, len, msg_iter);
	} else {
		if (len > (array_size - masked_cons)) {
			int ret = copy_to_iter(data->in + masked_cons,
				     array_size - masked_cons, msg_iter);
			if (ret != array_size - masked_cons) {
				len = ret;
				goto out;
			}
			len = ret + copy_to_iter(data->in, len - ret, msg_iter);
		} else {
			len = copy_to_iter(data->in + masked_cons, len, msg_iter);
		}
	}
out:
	/* read data from the ring before increasing the index */
	virt_mb();
	if (!(flags & MSG_PEEK))
		intf->in_cons += len;

	return len;
}

int pvcalls_front_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
			  int flags)
{
	int ret;
	struct sock_mapping *map;

	if (flags & (MSG_CMSG_CLOEXEC|MSG_ERRQUEUE|MSG_OOB|MSG_TRUNC))
		return -EOPNOTSUPP;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);

	mutex_lock(&map->active.in_mutex);
	if (len > XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER))
		len = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);

	while (!(flags & MSG_DONTWAIT) && !pvcalls_front_read_todo(map)) {
		wait_event_interruptible(map->active.inflight_conn_req,
					 pvcalls_front_read_todo(map));
	}
	ret = __read_ring(map->active.ring, &map->active.data,
			  &msg->msg_iter, len, flags);

	if (ret > 0)
		notify_remote_via_irq(map->active.irq);
	if (ret == 0)
		ret = (flags & MSG_DONTWAIT) ? -EAGAIN : 0;
	if (ret == -ENOTCONN)
		ret = 0;

	mutex_unlock(&map->active.in_mutex);
	pvcalls_exit_sock(sock);
	return ret;
}

int pvcalls_front_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret;

	if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
		return -EOPNOTSUPP;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return ret;
	}
	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	map->sock = sock;
	req->cmd = PVCALLS_BIND;
	req->u.bind.id = (uintptr_t)map;
	memcpy(req->u.bind.addr, addr, sizeof(*addr));
	req->u.bind.len = addr_len;

	init_waitqueue_head(&map->passive.inflight_accept_req);

	map->active_socket = false;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

	map->passive.status = PVCALLS_STATUS_BIND;
	pvcalls_exit_sock(sock);
	return 0;
}

int pvcalls_front_listen(struct socket *sock, int backlog)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	if (map->passive.status != PVCALLS_STATUS_BIND) {
		pvcalls_exit_sock(sock);
		return -EOPNOTSUPP;
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return ret;
	}
	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_LISTEN;
	req->u.listen.id = (uintptr_t) map;
	req->u.listen.backlog = backlog;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

	map->passive.status = PVCALLS_STATUS_LISTEN;
	pvcalls_exit_sock(sock);
	return ret;
}

int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	struct sock_mapping *map2 = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret, nonblock;
	evtchn_port_t evtchn;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	if (map->passive.status != PVCALLS_STATUS_LISTEN) {
		pvcalls_exit_sock(sock);
		return -EINVAL;
	}

	nonblock = flags & SOCK_NONBLOCK;
	/*
	 * Backend only supports 1 inflight accept request, will return
	 * errors for the others
	 */
	if (test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			     (void *)&map->passive.flags)) {
		req_id = READ_ONCE(map->passive.inflight_req_id);
		if (req_id != PVCALLS_INVALID_ID &&
		    READ_ONCE(bedata->rsp[req_id].req_id) == req_id) {
			map2 = map->passive.accept_map;
			goto received;
		}
		if (nonblock) {
			pvcalls_exit_sock(sock);
			return -EAGAIN;
		}
		if (wait_event_interruptible(map->passive.inflight_accept_req,
			!test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
					  (void *)&map->passive.flags))) {
			pvcalls_exit_sock(sock);
			return -EINTR;
		}
	}

	map2 = kzalloc(sizeof(*map2), GFP_KERNEL);
	if (map2 == NULL) {
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		pvcalls_exit_sock(sock);
		return -ENOMEM;
	}
	ret = alloc_active_ring(map2);
	if (ret < 0) {
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		kfree(map2);
		pvcalls_exit_sock(sock);
		return ret;
	}
	ret = create_active(map2, &evtchn);
	if (ret < 0) {
		free_active_ring(map2);
		kfree(map2);
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		pvcalls_exit_sock(sock);
		return ret;
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		spin_unlock(&bedata->socket_lock);
		pvcalls_front_free_map(bedata, map2);
		pvcalls_exit_sock(sock);
		return ret;
	}

	list_add_tail(&map2->list, &bedata->socket_mappings);

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_ACCEPT;
	req->u.accept.id = (uintptr_t) map;
	req->u.accept.ref = map2->active.ref;
	req->u.accept.id_new = (uintptr_t) map2;
	req->u.accept.evtchn = evtchn;
	map->passive.accept_map = map2;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);
	/* We could check if we have received a response before returning. */
	if (nonblock) {
		WRITE_ONCE(map->passive.inflight_req_id, req_id);
		pvcalls_exit_sock(sock);
		return -EAGAIN;
	}

	if (wait_event_interruptible(bedata->inflight_req,
		READ_ONCE(bedata->rsp[req_id].req_id) == req_id)) {
		pvcalls_exit_sock(sock);
		return -EINTR;
	}
	/* read req_id, then the content */
	smp_rmb();

received:
	map2->sock = newsock;
	newsock->sk = sk_alloc(sock_net(sock->sk), PF_INET, GFP_KERNEL, &pvcalls_proto, false);
	if (!newsock->sk) {
		bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
		map->passive.inflight_req_id = PVCALLS_INVALID_ID;
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		pvcalls_front_free_map(bedata, map2);
		pvcalls_exit_sock(sock);
		return -ENOMEM;
	}
	newsock->sk->sk_send_head = (void *)map2;

	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
	map->passive.inflight_req_id = PVCALLS_INVALID_ID;

	clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags);
	wake_up(&map->passive.inflight_accept_req);

	pvcalls_exit_sock(sock);
	return ret;
}
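
/*
 * Poll on a passive (listening) socket.  If an accept is already in flight
 * only report EPOLLIN once its response has arrived; otherwise keep at most
 * one PVCALLS_POLL request outstanding and report EPOLLIN when the event
 * handler sets PVCALLS_FLAG_POLL_RET.
 */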
static __poll_t pvcalls_front_poll_passive(struct file *file,
					   struct pvcalls_bedata *bedata,
					   struct sock_mapping *map,
					   poll_table *wait)
{
	int notify, req_id, ret;
	struct xen_pvcalls_request *req;

	if (test_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
		     (void *)&map->passive.flags)) {
		uint32_t req_id = READ_ONCE(map->passive.inflight_req_id);

		if (req_id != PVCALLS_INVALID_ID &&
		    READ_ONCE(bedata->rsp[req_id].req_id) == req_id)
			return EPOLLIN | EPOLLRDNORM;

		poll_wait(file, &map->passive.inflight_accept_req, wait);
		return 0;
	}

	if (test_and_clear_bit(PVCALLS_FLAG_POLL_RET,
			       (void *)&map->passive.flags))
		return EPOLLIN | EPOLLRDNORM;

	/*
	 * First check RET, then INFLIGHT. No barriers necessary to
	 * ensure execution ordering because of the conditional
	 * instructions creating control dependencies.
	 */

	if (test_and_set_bit(PVCALLS_FLAG_POLL_INFLIGHT,
			     (void *)&map->passive.flags)) {
		poll_wait(file, &bedata->inflight_req, wait);
		return 0;
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		return ret;
	}
	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_POLL;
	req->u.poll.id = (uintptr_t) map;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	poll_wait(file, &bedata->inflight_req, wait);
	return 0;
}

static __poll_t pvcalls_front_poll_active(struct file *file,
					  struct pvcalls_bedata *bedata,
					  struct sock_mapping *map,
					  poll_table *wait)
{
	__poll_t mask = 0;
	int32_t in_error, out_error;
	struct pvcalls_data_intf *intf = map->active.ring;

	out_error = intf->out_error;
	in_error = intf->in_error;

	poll_wait(file, &map->active.inflight_conn_req, wait);
	if (pvcalls_front_write_todo(map))
		mask |= EPOLLOUT | EPOLLWRNORM;
	if (pvcalls_front_read_todo(map))
		mask |= EPOLLIN | EPOLLRDNORM;
	if (in_error != 0 || out_error != 0)
		mask |= EPOLLERR;

	return mask;
}

__poll_t pvcalls_front_poll(struct file *file, struct socket *sock,
			    poll_table *wait)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	__poll_t ret;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return EPOLLNVAL;
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	if (map->active_socket)
		ret = pvcalls_front_poll_active(file, bedata, map, wait);
	else
		ret = pvcalls_front_poll_passive(file, bedata, map, wait);
	pvcalls_exit_sock(sock);
	return ret;
}
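
/*
 * Release a socket: detach sk_send_head so no new operation can find the
 * mapping, send PVCALLS_RELEASE to the backend, wake any sleepers and wait
 * for every in-flight user of the mapping (map->refcount > 1) to drain
 * before freeing it.
 */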
int pvcalls_front_release(struct socket *sock)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	int req_id, notify, ret;
	struct xen_pvcalls_request *req;

	if (sock->sk == NULL)
		return 0;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map)) {
		if (PTR_ERR(map) == -ENOTCONN)
			return -EIO;
		else
			return 0;
	}
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return ret;
	}
	sock->sk->sk_send_head = NULL;

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_RELEASE;
	req->u.release.id = (uintptr_t)map;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	if (map->active_socket) {
		/*
		 * Set in_error and wake up inflight_conn_req to force
		 * recvmsg waiters to exit.
		 */
		map->active.ring->in_error = -EBADF;
		wake_up_interruptible(&map->active.inflight_conn_req);

		/*
		 * We need to make sure that sendmsg/recvmsg on this socket have
		 * not started before we've cleared sk_send_head here. The
		 * easiest way to guarantee this is to see that no pvcalls
		 * (other than us) is in progress on this socket.
		 */
		while (atomic_read(&map->refcount) > 1)
			cpu_relax();

		pvcalls_front_free_map(bedata, map);
	} else {
		wake_up(&bedata->inflight_req);
		wake_up(&map->passive.inflight_accept_req);

		while (atomic_read(&map->refcount) > 1)
			cpu_relax();

		spin_lock(&bedata->socket_lock);
		list_del(&map->list);
		spin_unlock(&bedata->socket_lock);
		if (READ_ONCE(map->passive.inflight_req_id) != PVCALLS_INVALID_ID &&
		    READ_ONCE(map->passive.inflight_req_id) != 0) {
			pvcalls_front_free_map(bedata,
					       map->passive.accept_map);
		}
		kfree(map);
	}
	WRITE_ONCE(bedata->rsp[req_id].req_id, PVCALLS_INVALID_ID);

	pvcalls_exit();
	return 0;
}

static const struct xenbus_device_id pvcalls_front_ids[] = {
	{ "pvcalls" },
	{ "" }
};
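
/*
 * Device teardown: detach every socket from its mapping, wait for all
 * pvcalls users to leave (pvcalls_refcount reaches 0), then free the
 * mappings, the command ring and the per-connection data.
 */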
static void pvcalls_front_remove(struct xenbus_device *dev)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL, *n;

	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
	dev_set_drvdata(&dev->dev, NULL);
	pvcalls_front_dev = NULL;
	if (bedata->irq >= 0)
		unbind_from_irqhandler(bedata->irq, dev);

	list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
		map->sock->sk->sk_send_head = NULL;
		if (map->active_socket) {
			map->active.ring->in_error = -EBADF;
			wake_up_interruptible(&map->active.inflight_conn_req);
		}
	}

	smp_mb();
	while (atomic_read(&pvcalls_refcount) > 0)
		cpu_relax();
	list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
		if (map->active_socket) {
			/* No need to lock, refcount is 0 */
			pvcalls_front_free_map(bedata, map);
		} else {
			list_del(&map->list);
			kfree(map);
		}
	}
	if (bedata->ref != -1)
		gnttab_end_foreign_access(bedata->ref, NULL);
	kfree(bedata->ring.sring);
	kfree(bedata);
	xenbus_switch_state(dev, XenbusStateClosed);
}
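
/*
 * Probe: check that the backend advertises protocol version 1 and a
 * max-page-order large enough for PVCALLS_RING_ORDER, allocate the shared
 * command ring and its event channel, then publish ring-ref and port
 * through a xenstore transaction and switch to Initialised.
 */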
static int pvcalls_front_probe(struct xenbus_device *dev,
			       const struct xenbus_device_id *id)
{
	int ret = -ENOMEM, i;
	evtchn_port_t evtchn;
	unsigned int max_page_order, function_calls, len;
	char *versions;
	grant_ref_t gref_head = 0;
	struct xenbus_transaction xbt;
	struct pvcalls_bedata *bedata = NULL;
	struct xen_pvcalls_sring *sring;

	if (pvcalls_front_dev != NULL) {
		dev_err(&dev->dev, "only one PV Calls connection supported\n");
		return -EINVAL;
	}

	versions = xenbus_read(XBT_NIL, dev->otherend, "versions", &len);
	if (IS_ERR(versions))
		return PTR_ERR(versions);
	if (!len)
		return -EINVAL;
	if (strcmp(versions, "1")) {
		kfree(versions);
		return -EINVAL;
	}
	kfree(versions);
	max_page_order = xenbus_read_unsigned(dev->otherend,
					      "max-page-order", 0);
	if (max_page_order < PVCALLS_RING_ORDER)
		return -ENODEV;
	function_calls = xenbus_read_unsigned(dev->otherend,
					      "function-calls", 0);
	/* See XENBUS_FUNCTIONS_CALLS in pvcalls.h */
	if (function_calls != 1)
		return -ENODEV;
	pr_info("%s max-page-order is %u\n", __func__, max_page_order);

	bedata = kzalloc(sizeof(struct pvcalls_bedata), GFP_KERNEL);
	if (!bedata)
		return -ENOMEM;

	dev_set_drvdata(&dev->dev, bedata);
	pvcalls_front_dev = dev;
	init_waitqueue_head(&bedata->inflight_req);
	INIT_LIST_HEAD(&bedata->socket_mappings);
	spin_lock_init(&bedata->socket_lock);
	bedata->irq = -1;
	bedata->ref = -1;

	for (i = 0; i < PVCALLS_NR_RSP_PER_RING; i++)
		bedata->rsp[i].req_id = PVCALLS_INVALID_ID;

	sring = (struct xen_pvcalls_sring *) __get_free_page(GFP_KERNEL |
							     __GFP_ZERO);
	if (!sring)
		goto error;
	SHARED_RING_INIT(sring);
	FRONT_RING_INIT(&bedata->ring, sring, XEN_PAGE_SIZE);

	ret = xenbus_alloc_evtchn(dev, &evtchn);
	if (ret)
		goto error;

	bedata->irq = bind_evtchn_to_irqhandler(evtchn,
						pvcalls_front_event_handler,
						0, "pvcalls-frontend", dev);
	if (bedata->irq < 0) {
		ret = bedata->irq;
		goto error;
	}

	ret = gnttab_alloc_grant_references(1, &gref_head);
	if (ret < 0)
		goto error;
	ret = gnttab_claim_grant_reference(&gref_head);
	if (ret < 0)
		goto error;
	bedata->ref = ret;
	gnttab_grant_foreign_access_ref(bedata->ref, dev->otherend_id,
					virt_to_gfn((void *)sring), 0);

again:
	ret = xenbus_transaction_start(&xbt);
	if (ret) {
		xenbus_dev_fatal(dev, ret, "starting transaction");
		goto error;
	}
	ret = xenbus_printf(xbt, dev->nodename, "version", "%u", 1);
	if (ret)
		goto error_xenbus;
	ret = xenbus_printf(xbt, dev->nodename, "ring-ref", "%d", bedata->ref);
	if (ret)
		goto error_xenbus;
	ret = xenbus_printf(xbt, dev->nodename, "port", "%u",
			    evtchn);
	if (ret)
		goto error_xenbus;
	ret = xenbus_transaction_end(xbt, 0);
	if (ret) {
		if (ret == -EAGAIN)
			goto again;
		xenbus_dev_fatal(dev, ret, "completing transaction");
		goto error;
	}
	xenbus_switch_state(dev, XenbusStateInitialised);

	return 0;

error_xenbus:
	xenbus_transaction_end(xbt, 1);
	xenbus_dev_fatal(dev, ret, "writing xenstore");
error:
	pvcalls_front_remove(dev);
	return ret;
}

static void pvcalls_front_changed(struct xenbus_device *dev,
				  enum xenbus_state backend_state)
{
	switch (backend_state) {
	case XenbusStateReconfiguring:
	case XenbusStateReconfigured:
	case XenbusStateInitialising:
	case XenbusStateInitialised:
	case XenbusStateUnknown:
		break;

	case XenbusStateInitWait:
		break;

	case XenbusStateConnected:
		xenbus_switch_state(dev, XenbusStateConnected);
		break;

	case XenbusStateClosed:
		if (dev->state == XenbusStateClosed)
			break;
		/* Missed the backend's CLOSING state */
		fallthrough;
	case XenbusStateClosing:
		xenbus_frontend_closed(dev);
		break;
	}
}

static struct xenbus_driver pvcalls_front_driver = {
	.ids = pvcalls_front_ids,
	.probe = pvcalls_front_probe,
	.remove = pvcalls_front_remove,
	.otherend_changed = pvcalls_front_changed,
	.not_essential = true,
};

static int __init pvcalls_frontend_init(void)
{
	if (!xen_domain())
		return -ENODEV;

	pr_info("Initialising Xen pvcalls frontend driver\n");

	return xenbus_register_frontend(&pvcalls_front_driver);
}

module_init(pvcalls_frontend_init);

MODULE_DESCRIPTION("Xen PV Calls frontend driver");
MODULE_AUTHOR("Stefano Stabellini <sstabellini@kernel.org>");
MODULE_LICENSE("GPL");