// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * (c) 2017 Stefano Stabellini <stefano@aporeto.com>
 */

#include <linux/module.h>
#include <linux/net.h>
#include <linux/socket.h>

#include <net/sock.h>

#include <xen/events.h>
#include <xen/grant_table.h>
#include <xen/xen.h>
#include <xen/xenbus.h>
#include <xen/interface/io/pvcalls.h>

#include "pvcalls-front.h"

#define PVCALLS_INVALID_ID UINT_MAX
#define PVCALLS_RING_ORDER XENBUS_MAX_RING_GRANT_ORDER
#define PVCALLS_NR_RSP_PER_RING __CONST_RING_SIZE(xen_pvcalls, XEN_PAGE_SIZE)
#define PVCALLS_FRONT_MAX_SPIN 5000

static struct proto pvcalls_proto = {
	.name	= "PVCalls",
	.owner	= THIS_MODULE,
	.obj_size = sizeof(struct sock),
};

struct pvcalls_bedata {
	struct xen_pvcalls_front_ring ring;
	grant_ref_t ref;
	int irq;

	struct list_head socket_mappings;
	spinlock_t socket_lock;

	wait_queue_head_t inflight_req;
	struct xen_pvcalls_response rsp[PVCALLS_NR_RSP_PER_RING];
};
/* Only one front/back connection supported. */
static struct xenbus_device *pvcalls_front_dev;
static atomic_t pvcalls_refcount;

/* first increment refcount, then proceed */
#define pvcalls_enter() {			\
	atomic_inc(&pvcalls_refcount);		\
}

/* first complete other operations, then decrement refcount */
#define pvcalls_exit() {			\
	atomic_dec(&pvcalls_refcount);		\
}

struct sock_mapping {
	bool active_socket;
	struct list_head list;
	struct socket *sock;
	atomic_t refcount;
	union {
		struct {
			int irq;
			grant_ref_t ref;
			struct pvcalls_data_intf *ring;
			struct pvcalls_data data;
			struct mutex in_mutex;
			struct mutex out_mutex;

			wait_queue_head_t inflight_conn_req;
		} active;
		struct {
		/*
		 * Socket status, needs to be 64-bit aligned due to the
		 * test_and_* functions which have this requirement on arm64.
		 */
#define PVCALLS_STATUS_UNINITIALIZED  0
#define PVCALLS_STATUS_BIND           1
#define PVCALLS_STATUS_LISTEN         2
			uint8_t status __attribute__((aligned(8)));
		/*
		 * Internal state-machine flags.
		 * Only one accept operation can be inflight for a socket.
		 * Only one poll operation can be inflight for a given socket.
		 * flags needs to be 64-bit aligned due to the test_and_*
		 * functions which have this requirement on arm64.
		 */
#define PVCALLS_FLAG_ACCEPT_INFLIGHT 0
#define PVCALLS_FLAG_POLL_INFLIGHT   1
#define PVCALLS_FLAG_POLL_RET        2
			uint8_t flags __attribute__((aligned(8)));
			uint32_t inflight_req_id;
			struct sock_mapping *accept_map;
			wait_queue_head_t inflight_accept_req;
		} passive;
	};
};
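
/*
 * A mapping is either "active" (connected or accepted, backed by its own
 * data ring and event channel) or "passive" (bound/listening, tracked by
 * the status/flags state machine above). pvcalls_front_bind() and
 * create_active() decide which half of the union is live, so the two
 * halves never coexist for the same mapping.
 */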

static inline struct sock_mapping *pvcalls_enter_sock(struct socket *sock)
{
	struct sock_mapping *map;

	if (!pvcalls_front_dev ||
	    dev_get_drvdata(&pvcalls_front_dev->dev) == NULL)
		return ERR_PTR(-ENOTCONN);

	map = (struct sock_mapping *)sock->sk->sk_send_head;
	if (map == NULL)
		return ERR_PTR(-ENOTSOCK);

	pvcalls_enter();
	atomic_inc(&map->refcount);
	return map;
}

static inline void pvcalls_exit_sock(struct socket *sock)
{
	struct sock_mapping *map;

	map = (struct sock_mapping *)sock->sk->sk_send_head;
	atomic_dec(&map->refcount);
	pvcalls_exit();
}

static inline int get_request(struct pvcalls_bedata *bedata, int *req_id)
{
	*req_id = bedata->ring.req_prod_pvt & (RING_SIZE(&bedata->ring) - 1);
	if (RING_FULL(&bedata->ring) ||
	    bedata->rsp[*req_id].req_id != PVCALLS_INVALID_ID)
		return -EAGAIN;
	return 0;
}
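
/*
 * The request id returned by get_request() is simply the ring slot the
 * next request will occupy, so it doubles as an index into bedata->rsp[].
 * A slot stays busy until whoever waited on the response resets
 * rsp[req_id].req_id to PVCALLS_INVALID_ID; until then get_request()
 * returns -EAGAIN for that slot even if the ring itself has room.
 */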

static bool pvcalls_front_write_todo(struct sock_mapping *map)
{
	struct pvcalls_data_intf *intf = map->active.ring;
	RING_IDX cons, prod, size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
	int32_t error;

	error = intf->out_error;
	if (error == -ENOTCONN)
		return false;
	if (error != 0)
		return true;

	cons = intf->out_cons;
	prod = intf->out_prod;
	return !!(size - pvcalls_queued(prod, cons, size));
}

static bool pvcalls_front_read_todo(struct sock_mapping *map)
{
	struct pvcalls_data_intf *intf = map->active.ring;
	RING_IDX cons, prod;
	int32_t error;

	cons = intf->in_cons;
	prod = intf->in_prod;
	error = intf->in_error;
	return (error != 0 ||
		pvcalls_queued(prod, cons,
			       XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER)) != 0);
}

static irqreturn_t pvcalls_front_event_handler(int irq, void *dev_id)
{
	struct xenbus_device *dev = dev_id;
	struct pvcalls_bedata *bedata;
	struct xen_pvcalls_response *rsp;
	uint8_t *src, *dst;
	int req_id = 0, more = 0, done = 0;

	if (dev == NULL)
		return IRQ_HANDLED;

	pvcalls_enter();
	bedata = dev_get_drvdata(&dev->dev);
	if (bedata == NULL) {
		pvcalls_exit();
		return IRQ_HANDLED;
	}

again:
	while (RING_HAS_UNCONSUMED_RESPONSES(&bedata->ring)) {
		rsp = RING_GET_RESPONSE(&bedata->ring, bedata->ring.rsp_cons);

		req_id = rsp->req_id;
		if (rsp->cmd == PVCALLS_POLL) {
			struct sock_mapping *map = (struct sock_mapping *)(uintptr_t)
				rsp->u.poll.id;

			clear_bit(PVCALLS_FLAG_POLL_INFLIGHT,
				  (void *)&map->passive.flags);
			/*
			 * clear INFLIGHT, then set RET. It pairs with
			 * the checks at the beginning of
			 * pvcalls_front_poll_passive.
			 */
			smp_wmb();
			set_bit(PVCALLS_FLAG_POLL_RET,
				(void *)&map->passive.flags);
		} else {
			dst = (uint8_t *)&bedata->rsp[req_id] +
				sizeof(rsp->req_id);
			src = (uint8_t *)rsp + sizeof(rsp->req_id);
			memcpy(dst, src, sizeof(*rsp) - sizeof(rsp->req_id));
			/*
			 * First copy the rest of the data, then req_id. It is
			 * paired with the barrier when accessing bedata->rsp.
			 */
			smp_wmb();
			bedata->rsp[req_id].req_id = req_id;
		}

		done = 1;
		bedata->ring.rsp_cons++;
	}

	RING_FINAL_CHECK_FOR_RESPONSES(&bedata->ring, more);
	if (more)
		goto again;
	if (done)
		wake_up(&bedata->inflight_req);
	pvcalls_exit();
	return IRQ_HANDLED;
}
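
/*
 * Consumers of bedata->rsp[] mirror the publication order above: they
 * wait for READ_ONCE(bedata->rsp[req_id].req_id) == req_id, issue an
 * smp_rmb(), and only then read the rest of the response. This pattern
 * repeats in every synchronous command below (socket, connect, bind,
 * listen, accept, release).
 */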

static void pvcalls_front_free_map(struct pvcalls_bedata *bedata,
				   struct sock_mapping *map)
{
	int i;

	unbind_from_irqhandler(map->active.irq, map);

	spin_lock(&bedata->socket_lock);
	if (!list_empty(&map->list))
		list_del_init(&map->list);
	spin_unlock(&bedata->socket_lock);

	for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
		gnttab_end_foreign_access(map->active.ring->ref[i], 0, 0);
	gnttab_end_foreign_access(map->active.ref, 0, 0);
	free_page((unsigned long)map->active.ring);

	kfree(map);
}

static irqreturn_t pvcalls_front_conn_handler(int irq, void *sock_map)
{
	struct sock_mapping *map = sock_map;

	if (map == NULL)
		return IRQ_HANDLED;

	wake_up_interruptible(&map->active.inflight_conn_req);

	return IRQ_HANDLED;
}
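
/*
 * pvcalls_front_free_map() revokes the backend's access to all
 * (1 << PVCALLS_RING_ORDER) data pages and to the indirection page
 * before freeing them, so it must only run once the backend is done
 * with the socket: after a PVCALLS_RELEASE response, or on device
 * removal once every in-flight operation has drained.
 */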

int pvcalls_front_socket(struct socket *sock)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret;

	/*
	 * PVCalls only supports domain AF_INET,
	 * type SOCK_STREAM and protocol 0 sockets for now.
	 *
	 * Check socket type here, AF_INET and protocol checks are done
	 * by the caller.
	 */
	if (sock->type != SOCK_STREAM)
		return -EOPNOTSUPP;

	pvcalls_enter();
	if (!pvcalls_front_dev) {
		pvcalls_exit();
		return -EACCES;
	}
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	map = kzalloc(sizeof(*map), GFP_KERNEL);
	if (map == NULL) {
		pvcalls_exit();
		return -ENOMEM;
	}

	spin_lock(&bedata->socket_lock);

	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		kfree(map);
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit();
		return ret;
	}

	/*
	 * sock->sk->sk_send_head is not used for ip sockets: reuse the
	 * field to store a pointer to the struct sock_mapping
	 * corresponding to the socket. This way, we can easily get the
	 * struct sock_mapping from the struct socket.
	 */
	sock->sk->sk_send_head = (void *)map;
	list_add_tail(&map->list, &bedata->socket_mappings);

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_SOCKET;
	req->u.socket.id = (uintptr_t) map;
	req->u.socket.domain = AF_INET;
	req->u.socket.type = SOCK_STREAM;
	req->u.socket.protocol = IPPROTO_IP;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

	pvcalls_exit();
	return ret;
}

static void free_active_ring(struct sock_mapping *map)
{
	if (!map->active.ring)
		return;

	free_pages((unsigned long)map->active.data.in,
		   map->active.ring->ring_order);
	free_page((unsigned long)map->active.ring);
}

static int alloc_active_ring(struct sock_mapping *map)
{
	void *bytes;

	map->active.ring = (struct pvcalls_data_intf *)
		get_zeroed_page(GFP_KERNEL);
	if (!map->active.ring)
		goto out;

	map->active.ring->ring_order = PVCALLS_RING_ORDER;
	bytes = (void *)__get_free_pages(GFP_KERNEL | __GFP_ZERO,
					 PVCALLS_RING_ORDER);
	if (!bytes)
		goto out;

	map->active.data.in = bytes;
	map->active.data.out = bytes +
		XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);

	return 0;

out:
	free_active_ring(map);
	return -ENOMEM;
}
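
/*
 * Per-socket data ring layout: one zeroed page holds the
 * pvcalls_data_intf indirection structure (indexes, errors and the
 * grant references), while a separate 2^PVCALLS_RING_ORDER page
 * allocation is split in half, the first XEN_FLEX_RING_SIZE() bytes
 * forming the "in" (receive) ring and the second half the "out"
 * (transmit) ring.
 */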

static int create_active(struct sock_mapping *map, int *evtchn)
{
	void *bytes;
	int ret = -ENOMEM, irq = -1, i;

	*evtchn = -1;
	init_waitqueue_head(&map->active.inflight_conn_req);

	bytes = map->active.data.in;
	for (i = 0; i < (1 << PVCALLS_RING_ORDER); i++)
		map->active.ring->ref[i] = gnttab_grant_foreign_access(
			pvcalls_front_dev->otherend_id,
			pfn_to_gfn(virt_to_pfn(bytes) + i), 0);

	map->active.ref = gnttab_grant_foreign_access(
		pvcalls_front_dev->otherend_id,
		pfn_to_gfn(virt_to_pfn((void *)map->active.ring)), 0);

	ret = xenbus_alloc_evtchn(pvcalls_front_dev, evtchn);
	if (ret)
		goto out_error;
	irq = bind_evtchn_to_irqhandler(*evtchn, pvcalls_front_conn_handler,
					0, "pvcalls-frontend", map);
	if (irq < 0) {
		ret = irq;
		goto out_error;
	}

	map->active.irq = irq;
	map->active_socket = true;
	mutex_init(&map->active.in_mutex);
	mutex_init(&map->active.out_mutex);

	return 0;

out_error:
	if (*evtchn >= 0)
		xenbus_free_evtchn(pvcalls_front_dev, *evtchn);
	return ret;
}

int pvcalls_front_connect(struct socket *sock, struct sockaddr *addr,
			  int addr_len, int flags)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret, evtchn;

	if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
		return -EOPNOTSUPP;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);

	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
	ret = alloc_active_ring(map);
	if (ret < 0) {
		pvcalls_exit_sock(sock);
		return ret;
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		free_active_ring(map);
		pvcalls_exit_sock(sock);
		return ret;
	}
	ret = create_active(map, &evtchn);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		free_active_ring(map);
		pvcalls_exit_sock(sock);
		return ret;
	}

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_CONNECT;
	req->u.connect.id = (uintptr_t)map;
	req->u.connect.len = addr_len;
	req->u.connect.flags = flags;
	req->u.connect.ref = map->active.ref;
	req->u.connect.evtchn = evtchn;
	memcpy(req->u.connect.addr, addr, sizeof(*addr));

	map->sock = sock;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);

	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
	pvcalls_exit_sock(sock);
	return ret;
}

static int __write_ring(struct pvcalls_data_intf *intf,
			struct pvcalls_data *data,
			struct iov_iter *msg_iter,
			int len)
{
	RING_IDX cons, prod, size, masked_prod, masked_cons;
	RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
	int32_t error;

	error = intf->out_error;
	if (error < 0)
		return error;
	cons = intf->out_cons;
	prod = intf->out_prod;
	/* read indexes before continuing */
	virt_mb();

	size = pvcalls_queued(prod, cons, array_size);
	if (size > array_size)
		return -EINVAL;
	if (size == array_size)
		return 0;
	if (len > array_size - size)
		len = array_size - size;

	masked_prod = pvcalls_mask(prod, array_size);
	masked_cons = pvcalls_mask(cons, array_size);

	if (masked_prod < masked_cons) {
		len = copy_from_iter(data->out + masked_prod, len, msg_iter);
	} else {
		if (len > array_size - masked_prod) {
			int ret = copy_from_iter(data->out + masked_prod,
				       array_size - masked_prod, msg_iter);
			if (ret != array_size - masked_prod) {
				len = ret;
				goto out;
			}
			len = ret + copy_from_iter(data->out, len - ret,
						   msg_iter);
		} else {
			len = copy_from_iter(data->out + masked_prod, len,
					     msg_iter);
		}
	}
out:
	/* write to ring before updating pointer */
	virt_wmb();
	intf->out_prod += len;

	return len;
}
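
/*
 * Example of the wraparound handling in __write_ring() above: with
 * array_size = 8, prod = 6 and cons = 2 there are 4 bytes queued, so
 * at most 4 more fit. A 4-byte write is copied as 2 bytes at offsets
 * 6-7 followed by 2 bytes at offsets 0-1, and out_prod then advances
 * from 6 to 10 (masked index 2). The indexes are free-running;
 * pvcalls_mask() folds them into the array.
 */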

int pvcalls_front_sendmsg(struct socket *sock, struct msghdr *msg,
			  size_t len)
{
	struct sock_mapping *map;
	int sent, tot_sent = 0;
	int count = 0, flags;

	flags = msg->msg_flags;
	if (flags & (MSG_CONFIRM|MSG_DONTROUTE|MSG_EOR|MSG_OOB))
		return -EOPNOTSUPP;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);

	mutex_lock(&map->active.out_mutex);
	if ((flags & MSG_DONTWAIT) && !pvcalls_front_write_todo(map)) {
		mutex_unlock(&map->active.out_mutex);
		pvcalls_exit_sock(sock);
		return -EAGAIN;
	}
	if (len > INT_MAX)
		len = INT_MAX;

again:
	count++;
	sent = __write_ring(map->active.ring,
			    &map->active.data, &msg->msg_iter,
			    len);
	if (sent > 0) {
		len -= sent;
		tot_sent += sent;
		notify_remote_via_irq(map->active.irq);
	}
	if (sent >= 0 && len > 0 && count < PVCALLS_FRONT_MAX_SPIN)
		goto again;
	if (sent < 0)
		tot_sent = sent;

	mutex_unlock(&map->active.out_mutex);
	pvcalls_exit_sock(sock);
	return tot_sent;
}

static int __read_ring(struct pvcalls_data_intf *intf,
		       struct pvcalls_data *data,
		       struct iov_iter *msg_iter,
		       size_t len, int flags)
{
	RING_IDX cons, prod, size, masked_prod, masked_cons;
	RING_IDX array_size = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);
	int32_t error;

	cons = intf->in_cons;
	prod = intf->in_prod;
	error = intf->in_error;
	/* get pointers before reading from the ring */
	virt_rmb();

	size = pvcalls_queued(prod, cons, array_size);
	masked_prod = pvcalls_mask(prod, array_size);
	masked_cons = pvcalls_mask(cons, array_size);

	if (size == 0)
		return error ?: size;

	if (len > size)
		len = size;

	if (masked_prod > masked_cons) {
		len = copy_to_iter(data->in + masked_cons, len, msg_iter);
	} else {
		if (len > (array_size - masked_cons)) {
			int ret = copy_to_iter(data->in + masked_cons,
				     array_size - masked_cons, msg_iter);
			if (ret != array_size - masked_cons) {
				len = ret;
				goto out;
			}
			len = ret + copy_to_iter(data->in, len - ret,
						 msg_iter);
		} else {
			len = copy_to_iter(data->in + masked_cons, len,
					   msg_iter);
		}
	}
out:
	/* read data from the ring before increasing the index */
	virt_mb();
	if (!(flags & MSG_PEEK))
		intf->in_cons += len;

	return len;
}
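
/*
 * __read_ring() only advances in_cons when MSG_PEEK is unset, so a
 * peeking recvmsg() leaves the queued data in place for the next
 * reader. On the transmit side, pvcalls_front_sendmsg() retries
 * __write_ring() up to PVCALLS_FRONT_MAX_SPIN times rather than
 * sleeping when the out ring is full.
 */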

int pvcalls_front_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
			  int flags)
{
	int ret;
	struct sock_mapping *map;

	if (flags & (MSG_CMSG_CLOEXEC|MSG_ERRQUEUE|MSG_OOB|MSG_TRUNC))
		return -EOPNOTSUPP;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);

	mutex_lock(&map->active.in_mutex);
	if (len > XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER))
		len = XEN_FLEX_RING_SIZE(PVCALLS_RING_ORDER);

	while (!(flags & MSG_DONTWAIT) && !pvcalls_front_read_todo(map)) {
		wait_event_interruptible(map->active.inflight_conn_req,
					 pvcalls_front_read_todo(map));
	}
	ret = __read_ring(map->active.ring, &map->active.data,
			  &msg->msg_iter, len, flags);

	if (ret > 0)
		notify_remote_via_irq(map->active.irq);
	if (ret == 0)
		ret = (flags & MSG_DONTWAIT) ? -EAGAIN : 0;
	if (ret == -ENOTCONN)
		ret = 0;

	mutex_unlock(&map->active.in_mutex);
	pvcalls_exit_sock(sock);
	return ret;
}

int pvcalls_front_bind(struct socket *sock, struct sockaddr *addr, int addr_len)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret;

	if (addr->sa_family != AF_INET || sock->type != SOCK_STREAM)
		return -EOPNOTSUPP;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return ret;
	}
	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	map->sock = sock;
	req->cmd = PVCALLS_BIND;
	req->u.bind.id = (uintptr_t)map;
	memcpy(req->u.bind.addr, addr, sizeof(*addr));
	req->u.bind.len = addr_len;

	init_waitqueue_head(&map->passive.inflight_accept_req);

	map->active_socket = false;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

	/* only report success to the caller if the backend agreed */
	if (ret == 0)
		map->passive.status = PVCALLS_STATUS_BIND;
	pvcalls_exit_sock(sock);
	return ret;
}

int pvcalls_front_listen(struct socket *sock, int backlog)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	if (map->passive.status != PVCALLS_STATUS_BIND) {
		pvcalls_exit_sock(sock);
		return -EOPNOTSUPP;
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return ret;
	}
	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_LISTEN;
	req->u.listen.id = (uintptr_t) map;
	req->u.listen.backlog = backlog;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	/* read req_id, then the content */
	smp_rmb();
	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;

	map->passive.status = PVCALLS_STATUS_LISTEN;
	pvcalls_exit_sock(sock);
	return ret;
}
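
/*
 * Passive sockets move through a strict status progression:
 * PVCALLS_STATUS_UNINITIALIZED (the kzalloc default) ->
 * PVCALLS_STATUS_BIND (set by pvcalls_front_bind()) ->
 * PVCALLS_STATUS_LISTEN (set by pvcalls_front_listen()). listen
 * refuses sockets that are not bound, and accept below refuses
 * sockets that are not listening.
 */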

int pvcalls_front_accept(struct socket *sock, struct socket *newsock, int flags)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	struct sock_mapping *map2 = NULL;
	struct xen_pvcalls_request *req;
	int notify, req_id, ret, evtchn, nonblock;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return PTR_ERR(map);
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	if (map->passive.status != PVCALLS_STATUS_LISTEN) {
		pvcalls_exit_sock(sock);
		return -EINVAL;
	}

	nonblock = flags & SOCK_NONBLOCK;
	/*
	 * Backend only supports 1 inflight accept request, will return
	 * errors for the others
	 */
	if (test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			     (void *)&map->passive.flags)) {
		req_id = READ_ONCE(map->passive.inflight_req_id);
		if (req_id != PVCALLS_INVALID_ID &&
		    READ_ONCE(bedata->rsp[req_id].req_id) == req_id) {
			map2 = map->passive.accept_map;
			goto received;
		}
		if (nonblock) {
			pvcalls_exit_sock(sock);
			return -EAGAIN;
		}
		if (wait_event_interruptible(map->passive.inflight_accept_req,
			!test_and_set_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
					  (void *)&map->passive.flags))) {
			pvcalls_exit_sock(sock);
			return -EINTR;
		}
	}

	map2 = kzalloc(sizeof(*map2), GFP_KERNEL);
	if (map2 == NULL) {
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		pvcalls_exit_sock(sock);
		return -ENOMEM;
	}
	ret = alloc_active_ring(map2);
	if (ret < 0) {
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		kfree(map2);
		pvcalls_exit_sock(sock);
		return ret;
	}
	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		spin_unlock(&bedata->socket_lock);
		free_active_ring(map2);
		kfree(map2);
		pvcalls_exit_sock(sock);
		return ret;
	}

	ret = create_active(map2, &evtchn);
	if (ret < 0) {
		free_active_ring(map2);
		kfree(map2);
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return ret;
	}
	list_add_tail(&map2->list, &bedata->socket_mappings);

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_ACCEPT;
	req->u.accept.id = (uintptr_t) map;
	req->u.accept.ref = map2->active.ref;
	req->u.accept.id_new = (uintptr_t) map2;
	req->u.accept.evtchn = evtchn;
	map->passive.accept_map = map2;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);
	/* We could check if we have received a response before returning. */
	if (nonblock) {
		WRITE_ONCE(map->passive.inflight_req_id, req_id);
		pvcalls_exit_sock(sock);
		return -EAGAIN;
	}

	if (wait_event_interruptible(bedata->inflight_req,
		READ_ONCE(bedata->rsp[req_id].req_id) == req_id)) {
		pvcalls_exit_sock(sock);
		return -EINTR;
	}
	/* read req_id, then the content */
	smp_rmb();

received:
	map2->sock = newsock;
	newsock->sk = sk_alloc(sock_net(sock->sk), PF_INET, GFP_KERNEL,
			       &pvcalls_proto, false);
	if (!newsock->sk) {
		bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
		map->passive.inflight_req_id = PVCALLS_INVALID_ID;
		clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
			  (void *)&map->passive.flags);
		pvcalls_front_free_map(bedata, map2);
		pvcalls_exit_sock(sock);
		return -ENOMEM;
	}
	newsock->sk->sk_send_head = (void *)map2;

	ret = bedata->rsp[req_id].ret;
	bedata->rsp[req_id].req_id = PVCALLS_INVALID_ID;
	map->passive.inflight_req_id = PVCALLS_INVALID_ID;

	clear_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT, (void *)&map->passive.flags);
	wake_up(&map->passive.inflight_accept_req);

	pvcalls_exit_sock(sock);
	return ret;
}
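
/*
 * PVCALLS_FLAG_ACCEPT_INFLIGHT serializes accept: whoever wins the
 * test_and_set_bit() owns the single outstanding PVCALLS_ACCEPT
 * request. A nonblocking caller that finds the flag already set either
 * reaps a response that has arrived in the meantime (located via
 * inflight_req_id) or returns -EAGAIN; a blocking caller sleeps on
 * inflight_accept_req until the owner clears the flag and wakes it.
 */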

static __poll_t pvcalls_front_poll_passive(struct file *file,
					   struct pvcalls_bedata *bedata,
					   struct sock_mapping *map,
					   poll_table *wait)
{
	int notify, req_id, ret;
	struct xen_pvcalls_request *req;

	if (test_bit(PVCALLS_FLAG_ACCEPT_INFLIGHT,
		     (void *)&map->passive.flags)) {
		uint32_t req_id = READ_ONCE(map->passive.inflight_req_id);

		if (req_id != PVCALLS_INVALID_ID &&
		    READ_ONCE(bedata->rsp[req_id].req_id) == req_id)
			return EPOLLIN | EPOLLRDNORM;

		poll_wait(file, &map->passive.inflight_accept_req, wait);
		return 0;
	}

	if (test_and_clear_bit(PVCALLS_FLAG_POLL_RET,
			       (void *)&map->passive.flags))
		return EPOLLIN | EPOLLRDNORM;

	/*
	 * First check RET, then INFLIGHT. No barriers necessary to
	 * ensure execution ordering because of the conditional
	 * instructions creating control dependencies.
	 */

	if (test_and_set_bit(PVCALLS_FLAG_POLL_INFLIGHT,
			     (void *)&map->passive.flags)) {
		poll_wait(file, &bedata->inflight_req, wait);
		return 0;
	}

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		return ret;
	}
	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_POLL;
	req->u.poll.id = (uintptr_t) map;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	poll_wait(file, &bedata->inflight_req, wait);
	return 0;
}
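
/*
 * The POLL flags form a two-way handshake with the event handler:
 * poll_passive sets POLL_INFLIGHT before issuing PVCALLS_POLL, and the
 * handler clears POLL_INFLIGHT, issues smp_wmb(), then sets POLL_RET.
 * The next poll_passive call consumes POLL_RET and reports EPOLLIN, so
 * at most one poll request per passive socket is ever in flight.
 */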

static __poll_t pvcalls_front_poll_active(struct file *file,
					  struct pvcalls_bedata *bedata,
					  struct sock_mapping *map,
					  poll_table *wait)
{
	__poll_t mask = 0;
	int32_t in_error, out_error;
	struct pvcalls_data_intf *intf = map->active.ring;

	out_error = intf->out_error;
	in_error = intf->in_error;

	poll_wait(file, &map->active.inflight_conn_req, wait);
	if (pvcalls_front_write_todo(map))
		mask |= EPOLLOUT | EPOLLWRNORM;
	if (pvcalls_front_read_todo(map))
		mask |= EPOLLIN | EPOLLRDNORM;
	if (in_error != 0 || out_error != 0)
		mask |= EPOLLERR;

	return mask;
}

__poll_t pvcalls_front_poll(struct file *file, struct socket *sock,
			    poll_table *wait)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	__poll_t ret;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map))
		return EPOLLNVAL;
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	if (map->active_socket)
		ret = pvcalls_front_poll_active(file, bedata, map, wait);
	else
		ret = pvcalls_front_poll_passive(file, bedata, map, wait);
	pvcalls_exit_sock(sock);
	return ret;
}

int pvcalls_front_release(struct socket *sock)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map;
	int req_id, notify, ret;
	struct xen_pvcalls_request *req;

	if (sock->sk == NULL)
		return 0;

	map = pvcalls_enter_sock(sock);
	if (IS_ERR(map)) {
		if (PTR_ERR(map) == -ENOTCONN)
			return -EIO;
		else
			return 0;
	}
	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);

	spin_lock(&bedata->socket_lock);
	ret = get_request(bedata, &req_id);
	if (ret < 0) {
		spin_unlock(&bedata->socket_lock);
		pvcalls_exit_sock(sock);
		return ret;
	}
	sock->sk->sk_send_head = NULL;

	req = RING_GET_REQUEST(&bedata->ring, req_id);
	req->req_id = req_id;
	req->cmd = PVCALLS_RELEASE;
	req->u.release.id = (uintptr_t)map;

	bedata->ring.req_prod_pvt++;
	RING_PUSH_REQUESTS_AND_CHECK_NOTIFY(&bedata->ring, notify);
	spin_unlock(&bedata->socket_lock);
	if (notify)
		notify_remote_via_irq(bedata->irq);

	wait_event(bedata->inflight_req,
		   READ_ONCE(bedata->rsp[req_id].req_id) == req_id);

	if (map->active_socket) {
		/*
		 * Set in_error and wake up inflight_conn_req to force
		 * recvmsg waiters to exit.
		 */
		map->active.ring->in_error = -EBADF;
		wake_up_interruptible(&map->active.inflight_conn_req);

		/*
		 * We need to make sure that sendmsg/recvmsg on this socket
		 * have not started before we've cleared sk_send_head here.
		 * The easiest way to guarantee this is to see that no pvcalls
		 * (other than us) is in progress on this socket.
		 */
		while (atomic_read(&map->refcount) > 1)
			cpu_relax();

		pvcalls_front_free_map(bedata, map);
	} else {
		wake_up(&bedata->inflight_req);
		wake_up(&map->passive.inflight_accept_req);

		while (atomic_read(&map->refcount) > 1)
			cpu_relax();

		spin_lock(&bedata->socket_lock);
		list_del(&map->list);
		spin_unlock(&bedata->socket_lock);
		if (READ_ONCE(map->passive.inflight_req_id) != PVCALLS_INVALID_ID &&
		    READ_ONCE(map->passive.inflight_req_id) != 0) {
			pvcalls_front_free_map(bedata,
					       map->passive.accept_map);
		}
		kfree(map);
	}
	WRITE_ONCE(bedata->rsp[req_id].req_id, PVCALLS_INVALID_ID);

	pvcalls_exit();
	return 0;
}
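
/*
 * Teardown discipline: release clears sk_send_head under socket_lock so
 * no new caller can look up the mapping, then busy-waits for
 * map->refcount to drop to 1 (only this caller still holds a
 * reference) before freeing. pvcalls_front_remove() below applies the
 * same idea module-wide by draining the global pvcalls_refcount before
 * freeing every remaining mapping.
 */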

static const struct xenbus_device_id pvcalls_front_ids[] = {
	{ "pvcalls" },
	{ "" }
};

static int pvcalls_front_remove(struct xenbus_device *dev)
{
	struct pvcalls_bedata *bedata;
	struct sock_mapping *map = NULL, *n;

	bedata = dev_get_drvdata(&pvcalls_front_dev->dev);
	dev_set_drvdata(&dev->dev, NULL);
	pvcalls_front_dev = NULL;
	if (bedata->irq >= 0)
		unbind_from_irqhandler(bedata->irq, dev);

	list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
		map->sock->sk->sk_send_head = NULL;
		if (map->active_socket) {
			map->active.ring->in_error = -EBADF;
			wake_up_interruptible(&map->active.inflight_conn_req);
		}
	}

	smp_mb();
	while (atomic_read(&pvcalls_refcount) > 0)
		cpu_relax();
	list_for_each_entry_safe(map, n, &bedata->socket_mappings, list) {
		if (map->active_socket) {
			/* No need to lock, refcount is 0 */
			pvcalls_front_free_map(bedata, map);
		} else {
			list_del(&map->list);
			kfree(map);
		}
	}
	if (bedata->ref != -1)
		gnttab_end_foreign_access(bedata->ref, 0, 0);
	kfree(bedata->ring.sring);
	kfree(bedata);
	xenbus_switch_state(dev, XenbusStateClosed);
	return 0;
}

static int pvcalls_front_probe(struct xenbus_device *dev,
			       const struct xenbus_device_id *id)
{
	int ret = -ENOMEM, evtchn, i;
	unsigned int max_page_order, function_calls, len;
	char *versions;
	grant_ref_t gref_head = 0;
	struct xenbus_transaction xbt;
	struct pvcalls_bedata *bedata = NULL;
	struct xen_pvcalls_sring *sring;

	if (pvcalls_front_dev != NULL) {
		dev_err(&dev->dev, "only one PV Calls connection supported\n");
		return -EINVAL;
	}

	versions = xenbus_read(XBT_NIL, dev->otherend, "versions", &len);
	if (IS_ERR(versions))
		return PTR_ERR(versions);
	if (!len) {
		kfree(versions);
		return -EINVAL;
	}
	if (strcmp(versions, "1")) {
		kfree(versions);
		return -EINVAL;
	}
	kfree(versions);
	max_page_order = xenbus_read_unsigned(dev->otherend,
					      "max-page-order", 0);
	if (max_page_order < PVCALLS_RING_ORDER)
		return -ENODEV;
	function_calls = xenbus_read_unsigned(dev->otherend,
					      "function-calls", 0);
	/* See XENBUS_FUNCTIONS_CALLS in pvcalls.h */
	if (function_calls != 1)
		return -ENODEV;
	pr_info("%s max-page-order is %u\n", __func__, max_page_order);

	bedata = kzalloc(sizeof(struct pvcalls_bedata), GFP_KERNEL);
	if (!bedata)
		return -ENOMEM;

	dev_set_drvdata(&dev->dev, bedata);
	pvcalls_front_dev = dev;
	init_waitqueue_head(&bedata->inflight_req);
	INIT_LIST_HEAD(&bedata->socket_mappings);
	spin_lock_init(&bedata->socket_lock);
	bedata->irq = -1;
	bedata->ref = -1;

	for (i = 0; i < PVCALLS_NR_RSP_PER_RING; i++)
		bedata->rsp[i].req_id = PVCALLS_INVALID_ID;

	sring = (struct xen_pvcalls_sring *) __get_free_page(GFP_KERNEL |
							     __GFP_ZERO);
	if (!sring)
		goto error;
	SHARED_RING_INIT(sring);
	FRONT_RING_INIT(&bedata->ring, sring, XEN_PAGE_SIZE);

	ret = xenbus_alloc_evtchn(dev, &evtchn);
	if (ret)
		goto error;

	bedata->irq = bind_evtchn_to_irqhandler(evtchn,
						pvcalls_front_event_handler,
						0, "pvcalls-frontend", dev);
	if (bedata->irq < 0) {
		ret = bedata->irq;
		goto error;
	}

	ret = gnttab_alloc_grant_references(1, &gref_head);
	if (ret < 0)
		goto error;
	ret = gnttab_claim_grant_reference(&gref_head);
	if (ret < 0)
		goto error;
	bedata->ref = ret;
	gnttab_grant_foreign_access_ref(bedata->ref, dev->otherend_id,
					virt_to_gfn((void *)sring), 0);

again:
	ret = xenbus_transaction_start(&xbt);
	if (ret) {
		xenbus_dev_fatal(dev, ret, "starting transaction");
		goto error;
	}
	ret = xenbus_printf(xbt, dev->nodename, "version", "%u", 1);
	if (ret)
		goto error_xenbus;
	ret = xenbus_printf(xbt, dev->nodename, "ring-ref", "%d", bedata->ref);
	if (ret)
		goto error_xenbus;
	ret = xenbus_printf(xbt, dev->nodename, "port", "%u",
			    evtchn);
	if (ret)
		goto error_xenbus;
	ret = xenbus_transaction_end(xbt, 0);
	if (ret) {
		if (ret == -EAGAIN)
			goto again;
		xenbus_dev_fatal(dev, ret, "completing transaction");
		goto error;
	}
	xenbus_switch_state(dev, XenbusStateInitialised);

	return 0;

error_xenbus:
	xenbus_transaction_end(xbt, 1);
	xenbus_dev_fatal(dev, ret, "writing xenstore");
error:
	pvcalls_front_remove(dev);
	return ret;
}

static void pvcalls_front_changed(struct xenbus_device *dev,
				  enum xenbus_state backend_state)
{
	switch (backend_state) {
	case XenbusStateReconfiguring:
	case XenbusStateReconfigured:
	case XenbusStateInitialising:
	case XenbusStateInitialised:
	case XenbusStateUnknown:
		break;

	case XenbusStateInitWait:
		break;

	case XenbusStateConnected:
		xenbus_switch_state(dev, XenbusStateConnected);
		break;

	case XenbusStateClosed:
		if (dev->state == XenbusStateClosed)
			break;
		/* Missed the backend's CLOSING state */
		/* fall through */
	case XenbusStateClosing:
		xenbus_frontend_closed(dev);
		break;
	}
}

static struct xenbus_driver pvcalls_front_driver = {
	.ids = pvcalls_front_ids,
	.probe = pvcalls_front_probe,
	.remove = pvcalls_front_remove,
	.otherend_changed = pvcalls_front_changed,
};

static int __init pvcalls_frontend_init(void)
{
	if (!xen_domain())
		return -ENODEV;

	pr_info("Initialising Xen pvcalls frontend driver\n");

	return xenbus_register_frontend(&pvcalls_front_driver);
}

module_init(pvcalls_frontend_init);

MODULE_DESCRIPTION("Xen PV Calls frontend driver");
MODULE_AUTHOR("Stefano Stabellini <sstabellini@kernel.org>");
MODULE_LICENSE("GPL");