1 /* 2 * common code for virtio vsock 3 * 4 * Copyright (C) 2013-2015 Red Hat, Inc. 5 * Author: Asias He <asias@redhat.com> 6 * Stefan Hajnoczi <stefanha@redhat.com> 7 * 8 * This work is licensed under the terms of the GNU GPL, version 2. 9 */ 10 #include <linux/spinlock.h> 11 #include <linux/module.h> 12 #include <linux/ctype.h> 13 #include <linux/list.h> 14 #include <linux/virtio.h> 15 #include <linux/virtio_ids.h> 16 #include <linux/virtio_config.h> 17 #include <linux/virtio_vsock.h> 18 19 #include <net/sock.h> 20 #include <net/af_vsock.h> 21 22 #define CREATE_TRACE_POINTS 23 #include <trace/events/vsock_virtio_transport_common.h> 24 25 /* How long to wait for graceful shutdown of a connection */ 26 #define VSOCK_CLOSE_TIMEOUT (8 * HZ) 27 28 static const struct virtio_transport *virtio_transport_get_ops(void) 29 { 30 const struct vsock_transport *t = vsock_core_get_transport(); 31 32 return container_of(t, struct virtio_transport, transport); 33 } 34 35 static struct virtio_vsock_pkt * 36 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info, 37 size_t len, 38 u32 src_cid, 39 u32 src_port, 40 u32 dst_cid, 41 u32 dst_port) 42 { 43 struct virtio_vsock_pkt *pkt; 44 int err; 45 46 pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); 47 if (!pkt) 48 return NULL; 49 50 pkt->hdr.type = cpu_to_le16(info->type); 51 pkt->hdr.op = cpu_to_le16(info->op); 52 pkt->hdr.src_cid = cpu_to_le64(src_cid); 53 pkt->hdr.dst_cid = cpu_to_le64(dst_cid); 54 pkt->hdr.src_port = cpu_to_le32(src_port); 55 pkt->hdr.dst_port = cpu_to_le32(dst_port); 56 pkt->hdr.flags = cpu_to_le32(info->flags); 57 pkt->len = len; 58 pkt->hdr.len = cpu_to_le32(len); 59 pkt->reply = info->reply; 60 61 if (info->msg && len > 0) { 62 pkt->buf = kmalloc(len, GFP_KERNEL); 63 if (!pkt->buf) 64 goto out_pkt; 65 err = memcpy_from_msg(pkt->buf, info->msg, len); 66 if (err) 67 goto out; 68 } 69 70 trace_virtio_transport_alloc_pkt(src_cid, src_port, 71 dst_cid, dst_port, 72 len, 73 info->type, 74 info->op, 75 info->flags); 76 77 return pkt; 78 79 out: 80 kfree(pkt->buf); 81 out_pkt: 82 kfree(pkt); 83 return NULL; 84 } 85 86 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, 87 struct virtio_vsock_pkt_info *info) 88 { 89 u32 src_cid, src_port, dst_cid, dst_port; 90 struct virtio_vsock_sock *vvs; 91 struct virtio_vsock_pkt *pkt; 92 u32 pkt_len = info->pkt_len; 93 94 src_cid = vm_sockets_get_local_cid(); 95 src_port = vsk->local_addr.svm_port; 96 if (!info->remote_cid) { 97 dst_cid = vsk->remote_addr.svm_cid; 98 dst_port = vsk->remote_addr.svm_port; 99 } else { 100 dst_cid = info->remote_cid; 101 dst_port = info->remote_port; 102 } 103 104 vvs = vsk->trans; 105 106 /* we can send less than pkt_len bytes */ 107 if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE) 108 pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE; 109 110 /* virtio_transport_get_credit might return less than pkt_len credit */ 111 pkt_len = virtio_transport_get_credit(vvs, pkt_len); 112 113 /* Do not send zero length OP_RW pkt */ 114 if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW) 115 return pkt_len; 116 117 pkt = virtio_transport_alloc_pkt(info, pkt_len, 118 src_cid, src_port, 119 dst_cid, dst_port); 120 if (!pkt) { 121 virtio_transport_put_credit(vvs, pkt_len); 122 return -ENOMEM; 123 } 124 125 virtio_transport_inc_tx_pkt(vvs, pkt); 126 127 return virtio_transport_get_ops()->send_pkt(pkt); 128 } 129 130 static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, 131 struct virtio_vsock_pkt *pkt) 132 { 133 vvs->rx_bytes += pkt->len; 134 } 135 136 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs, 137 struct virtio_vsock_pkt *pkt) 138 { 139 vvs->rx_bytes -= pkt->len; 140 vvs->fwd_cnt += pkt->len; 141 } 142 143 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt) 144 { 145 spin_lock_bh(&vvs->tx_lock); 146 pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt); 147 pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc); 148 spin_unlock_bh(&vvs->tx_lock); 149 } 150 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt); 151 152 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit) 153 { 154 u32 ret; 155 156 spin_lock_bh(&vvs->tx_lock); 157 ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); 158 if (ret > credit) 159 ret = credit; 160 vvs->tx_cnt += ret; 161 spin_unlock_bh(&vvs->tx_lock); 162 163 return ret; 164 } 165 EXPORT_SYMBOL_GPL(virtio_transport_get_credit); 166 167 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit) 168 { 169 spin_lock_bh(&vvs->tx_lock); 170 vvs->tx_cnt -= credit; 171 spin_unlock_bh(&vvs->tx_lock); 172 } 173 EXPORT_SYMBOL_GPL(virtio_transport_put_credit); 174 175 static int virtio_transport_send_credit_update(struct vsock_sock *vsk, 176 int type, 177 struct virtio_vsock_hdr *hdr) 178 { 179 struct virtio_vsock_pkt_info info = { 180 .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE, 181 .type = type, 182 }; 183 184 return virtio_transport_send_pkt_info(vsk, &info); 185 } 186 187 static ssize_t 188 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, 189 struct msghdr *msg, 190 size_t len) 191 { 192 struct virtio_vsock_sock *vvs = vsk->trans; 193 struct virtio_vsock_pkt *pkt; 194 size_t bytes, total = 0; 195 int err = -EFAULT; 196 197 spin_lock_bh(&vvs->rx_lock); 198 while (total < len && !list_empty(&vvs->rx_queue)) { 199 pkt = list_first_entry(&vvs->rx_queue, 200 struct virtio_vsock_pkt, list); 201 202 bytes = len - total; 203 if (bytes > pkt->len - pkt->off) 204 bytes = pkt->len - pkt->off; 205 206 /* sk_lock is held by caller so no one else can dequeue. 207 * Unlock rx_lock since memcpy_to_msg() may sleep. 208 */ 209 spin_unlock_bh(&vvs->rx_lock); 210 211 err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes); 212 if (err) 213 goto out; 214 215 spin_lock_bh(&vvs->rx_lock); 216 217 total += bytes; 218 pkt->off += bytes; 219 if (pkt->off == pkt->len) { 220 virtio_transport_dec_rx_pkt(vvs, pkt); 221 list_del(&pkt->list); 222 virtio_transport_free_pkt(pkt); 223 } 224 } 225 spin_unlock_bh(&vvs->rx_lock); 226 227 /* Send a credit pkt to peer */ 228 virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM, 229 NULL); 230 231 return total; 232 233 out: 234 if (total) 235 err = total; 236 return err; 237 } 238 239 ssize_t 240 virtio_transport_stream_dequeue(struct vsock_sock *vsk, 241 struct msghdr *msg, 242 size_t len, int flags) 243 { 244 if (flags & MSG_PEEK) 245 return -EOPNOTSUPP; 246 247 return virtio_transport_stream_do_dequeue(vsk, msg, len); 248 } 249 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue); 250 251 int 252 virtio_transport_dgram_dequeue(struct vsock_sock *vsk, 253 struct msghdr *msg, 254 size_t len, int flags) 255 { 256 return -EOPNOTSUPP; 257 } 258 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue); 259 260 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk) 261 { 262 struct virtio_vsock_sock *vvs = vsk->trans; 263 s64 bytes; 264 265 spin_lock_bh(&vvs->rx_lock); 266 bytes = vvs->rx_bytes; 267 spin_unlock_bh(&vvs->rx_lock); 268 269 return bytes; 270 } 271 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data); 272 273 static s64 virtio_transport_has_space(struct vsock_sock *vsk) 274 { 275 struct virtio_vsock_sock *vvs = vsk->trans; 276 s64 bytes; 277 278 bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); 279 if (bytes < 0) 280 bytes = 0; 281 282 return bytes; 283 } 284 285 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk) 286 { 287 struct virtio_vsock_sock *vvs = vsk->trans; 288 s64 bytes; 289 290 spin_lock_bh(&vvs->tx_lock); 291 bytes = virtio_transport_has_space(vsk); 292 spin_unlock_bh(&vvs->tx_lock); 293 294 return bytes; 295 } 296 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space); 297 298 int virtio_transport_do_socket_init(struct vsock_sock *vsk, 299 struct vsock_sock *psk) 300 { 301 struct virtio_vsock_sock *vvs; 302 303 vvs = kzalloc(sizeof(*vvs), GFP_KERNEL); 304 if (!vvs) 305 return -ENOMEM; 306 307 vsk->trans = vvs; 308 vvs->vsk = vsk; 309 if (psk) { 310 struct virtio_vsock_sock *ptrans = psk->trans; 311 312 vvs->buf_size = ptrans->buf_size; 313 vvs->buf_size_min = ptrans->buf_size_min; 314 vvs->buf_size_max = ptrans->buf_size_max; 315 vvs->peer_buf_alloc = ptrans->peer_buf_alloc; 316 } else { 317 vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE; 318 vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE; 319 vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE; 320 } 321 322 vvs->buf_alloc = vvs->buf_size; 323 324 spin_lock_init(&vvs->rx_lock); 325 spin_lock_init(&vvs->tx_lock); 326 INIT_LIST_HEAD(&vvs->rx_queue); 327 328 return 0; 329 } 330 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init); 331 332 u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk) 333 { 334 struct virtio_vsock_sock *vvs = vsk->trans; 335 336 return vvs->buf_size; 337 } 338 EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size); 339 340 u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk) 341 { 342 struct virtio_vsock_sock *vvs = vsk->trans; 343 344 return vvs->buf_size_min; 345 } 346 EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size); 347 348 u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk) 349 { 350 struct virtio_vsock_sock *vvs = vsk->trans; 351 352 return vvs->buf_size_max; 353 } 354 EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size); 355 356 void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val) 357 { 358 struct virtio_vsock_sock *vvs = vsk->trans; 359 360 if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) 361 val = VIRTIO_VSOCK_MAX_BUF_SIZE; 362 if (val < vvs->buf_size_min) 363 vvs->buf_size_min = val; 364 if (val > vvs->buf_size_max) 365 vvs->buf_size_max = val; 366 vvs->buf_size = val; 367 vvs->buf_alloc = val; 368 } 369 EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size); 370 371 void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val) 372 { 373 struct virtio_vsock_sock *vvs = vsk->trans; 374 375 if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) 376 val = VIRTIO_VSOCK_MAX_BUF_SIZE; 377 if (val > vvs->buf_size) 378 vvs->buf_size = val; 379 vvs->buf_size_min = val; 380 } 381 EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size); 382 383 void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val) 384 { 385 struct virtio_vsock_sock *vvs = vsk->trans; 386 387 if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) 388 val = VIRTIO_VSOCK_MAX_BUF_SIZE; 389 if (val < vvs->buf_size) 390 vvs->buf_size = val; 391 vvs->buf_size_max = val; 392 } 393 EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size); 394 395 int 396 virtio_transport_notify_poll_in(struct vsock_sock *vsk, 397 size_t target, 398 bool *data_ready_now) 399 { 400 if (vsock_stream_has_data(vsk)) 401 *data_ready_now = true; 402 else 403 *data_ready_now = false; 404 405 return 0; 406 } 407 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in); 408 409 int 410 virtio_transport_notify_poll_out(struct vsock_sock *vsk, 411 size_t target, 412 bool *space_avail_now) 413 { 414 s64 free_space; 415 416 free_space = vsock_stream_has_space(vsk); 417 if (free_space > 0) 418 *space_avail_now = true; 419 else if (free_space == 0) 420 *space_avail_now = false; 421 422 return 0; 423 } 424 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out); 425 426 int virtio_transport_notify_recv_init(struct vsock_sock *vsk, 427 size_t target, struct vsock_transport_recv_notify_data *data) 428 { 429 return 0; 430 } 431 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init); 432 433 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk, 434 size_t target, struct vsock_transport_recv_notify_data *data) 435 { 436 return 0; 437 } 438 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block); 439 440 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk, 441 size_t target, struct vsock_transport_recv_notify_data *data) 442 { 443 return 0; 444 } 445 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue); 446 447 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk, 448 size_t target, ssize_t copied, bool data_read, 449 struct vsock_transport_recv_notify_data *data) 450 { 451 return 0; 452 } 453 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue); 454 455 int virtio_transport_notify_send_init(struct vsock_sock *vsk, 456 struct vsock_transport_send_notify_data *data) 457 { 458 return 0; 459 } 460 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init); 461 462 int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk, 463 struct vsock_transport_send_notify_data *data) 464 { 465 return 0; 466 } 467 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block); 468 469 int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk, 470 struct vsock_transport_send_notify_data *data) 471 { 472 return 0; 473 } 474 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue); 475 476 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk, 477 ssize_t written, struct vsock_transport_send_notify_data *data) 478 { 479 return 0; 480 } 481 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue); 482 483 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk) 484 { 485 struct virtio_vsock_sock *vvs = vsk->trans; 486 487 return vvs->buf_size; 488 } 489 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat); 490 491 bool virtio_transport_stream_is_active(struct vsock_sock *vsk) 492 { 493 return true; 494 } 495 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active); 496 497 bool virtio_transport_stream_allow(u32 cid, u32 port) 498 { 499 return true; 500 } 501 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow); 502 503 int virtio_transport_dgram_bind(struct vsock_sock *vsk, 504 struct sockaddr_vm *addr) 505 { 506 return -EOPNOTSUPP; 507 } 508 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind); 509 510 bool virtio_transport_dgram_allow(u32 cid, u32 port) 511 { 512 return false; 513 } 514 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow); 515 516 int virtio_transport_connect(struct vsock_sock *vsk) 517 { 518 struct virtio_vsock_pkt_info info = { 519 .op = VIRTIO_VSOCK_OP_REQUEST, 520 .type = VIRTIO_VSOCK_TYPE_STREAM, 521 }; 522 523 return virtio_transport_send_pkt_info(vsk, &info); 524 } 525 EXPORT_SYMBOL_GPL(virtio_transport_connect); 526 527 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode) 528 { 529 struct virtio_vsock_pkt_info info = { 530 .op = VIRTIO_VSOCK_OP_SHUTDOWN, 531 .type = VIRTIO_VSOCK_TYPE_STREAM, 532 .flags = (mode & RCV_SHUTDOWN ? 533 VIRTIO_VSOCK_SHUTDOWN_RCV : 0) | 534 (mode & SEND_SHUTDOWN ? 535 VIRTIO_VSOCK_SHUTDOWN_SEND : 0), 536 }; 537 538 return virtio_transport_send_pkt_info(vsk, &info); 539 } 540 EXPORT_SYMBOL_GPL(virtio_transport_shutdown); 541 542 int 543 virtio_transport_dgram_enqueue(struct vsock_sock *vsk, 544 struct sockaddr_vm *remote_addr, 545 struct msghdr *msg, 546 size_t dgram_len) 547 { 548 return -EOPNOTSUPP; 549 } 550 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue); 551 552 ssize_t 553 virtio_transport_stream_enqueue(struct vsock_sock *vsk, 554 struct msghdr *msg, 555 size_t len) 556 { 557 struct virtio_vsock_pkt_info info = { 558 .op = VIRTIO_VSOCK_OP_RW, 559 .type = VIRTIO_VSOCK_TYPE_STREAM, 560 .msg = msg, 561 .pkt_len = len, 562 }; 563 564 return virtio_transport_send_pkt_info(vsk, &info); 565 } 566 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue); 567 568 void virtio_transport_destruct(struct vsock_sock *vsk) 569 { 570 struct virtio_vsock_sock *vvs = vsk->trans; 571 572 kfree(vvs); 573 } 574 EXPORT_SYMBOL_GPL(virtio_transport_destruct); 575 576 static int virtio_transport_reset(struct vsock_sock *vsk, 577 struct virtio_vsock_pkt *pkt) 578 { 579 struct virtio_vsock_pkt_info info = { 580 .op = VIRTIO_VSOCK_OP_RST, 581 .type = VIRTIO_VSOCK_TYPE_STREAM, 582 .reply = !!pkt, 583 }; 584 585 /* Send RST only if the original pkt is not a RST pkt */ 586 if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) 587 return 0; 588 589 return virtio_transport_send_pkt_info(vsk, &info); 590 } 591 592 /* Normally packets are associated with a socket. There may be no socket if an 593 * attempt was made to connect to a socket that does not exist. 594 */ 595 static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt) 596 { 597 struct virtio_vsock_pkt_info info = { 598 .op = VIRTIO_VSOCK_OP_RST, 599 .type = le16_to_cpu(pkt->hdr.type), 600 .reply = true, 601 }; 602 603 /* Send RST only if the original pkt is not a RST pkt */ 604 if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) 605 return 0; 606 607 pkt = virtio_transport_alloc_pkt(&info, 0, 608 le64_to_cpu(pkt->hdr.dst_cid), 609 le32_to_cpu(pkt->hdr.dst_port), 610 le64_to_cpu(pkt->hdr.src_cid), 611 le32_to_cpu(pkt->hdr.src_port)); 612 if (!pkt) 613 return -ENOMEM; 614 615 return virtio_transport_get_ops()->send_pkt(pkt); 616 } 617 618 static void virtio_transport_wait_close(struct sock *sk, long timeout) 619 { 620 if (timeout) { 621 DEFINE_WAIT_FUNC(wait, woken_wake_function); 622 623 add_wait_queue(sk_sleep(sk), &wait); 624 625 do { 626 if (sk_wait_event(sk, &timeout, 627 sock_flag(sk, SOCK_DONE), &wait)) 628 break; 629 } while (!signal_pending(current) && timeout); 630 631 remove_wait_queue(sk_sleep(sk), &wait); 632 } 633 } 634 635 static void virtio_transport_do_close(struct vsock_sock *vsk, 636 bool cancel_timeout) 637 { 638 struct sock *sk = sk_vsock(vsk); 639 640 sock_set_flag(sk, SOCK_DONE); 641 vsk->peer_shutdown = SHUTDOWN_MASK; 642 if (vsock_stream_has_data(vsk) <= 0) 643 sk->sk_state = SS_DISCONNECTING; 644 sk->sk_state_change(sk); 645 646 if (vsk->close_work_scheduled && 647 (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) { 648 vsk->close_work_scheduled = false; 649 650 vsock_remove_sock(vsk); 651 652 /* Release refcnt obtained when we scheduled the timeout */ 653 sock_put(sk); 654 } 655 } 656 657 static void virtio_transport_close_timeout(struct work_struct *work) 658 { 659 struct vsock_sock *vsk = 660 container_of(work, struct vsock_sock, close_work.work); 661 struct sock *sk = sk_vsock(vsk); 662 663 sock_hold(sk); 664 lock_sock(sk); 665 666 if (!sock_flag(sk, SOCK_DONE)) { 667 (void)virtio_transport_reset(vsk, NULL); 668 669 virtio_transport_do_close(vsk, false); 670 } 671 672 vsk->close_work_scheduled = false; 673 674 release_sock(sk); 675 sock_put(sk); 676 } 677 678 /* User context, vsk->sk is locked */ 679 static bool virtio_transport_close(struct vsock_sock *vsk) 680 { 681 struct sock *sk = &vsk->sk; 682 683 if (!(sk->sk_state == SS_CONNECTED || 684 sk->sk_state == SS_DISCONNECTING)) 685 return true; 686 687 /* Already received SHUTDOWN from peer, reply with RST */ 688 if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) { 689 (void)virtio_transport_reset(vsk, NULL); 690 return true; 691 } 692 693 if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK) 694 (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK); 695 696 if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING)) 697 virtio_transport_wait_close(sk, sk->sk_lingertime); 698 699 if (sock_flag(sk, SOCK_DONE)) { 700 return true; 701 } 702 703 sock_hold(sk); 704 INIT_DELAYED_WORK(&vsk->close_work, 705 virtio_transport_close_timeout); 706 vsk->close_work_scheduled = true; 707 schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT); 708 return false; 709 } 710 711 void virtio_transport_release(struct vsock_sock *vsk) 712 { 713 struct sock *sk = &vsk->sk; 714 bool remove_sock = true; 715 716 lock_sock(sk); 717 if (sk->sk_type == SOCK_STREAM) 718 remove_sock = virtio_transport_close(vsk); 719 release_sock(sk); 720 721 if (remove_sock) 722 vsock_remove_sock(vsk); 723 } 724 EXPORT_SYMBOL_GPL(virtio_transport_release); 725 726 static int 727 virtio_transport_recv_connecting(struct sock *sk, 728 struct virtio_vsock_pkt *pkt) 729 { 730 struct vsock_sock *vsk = vsock_sk(sk); 731 int err; 732 int skerr; 733 734 switch (le16_to_cpu(pkt->hdr.op)) { 735 case VIRTIO_VSOCK_OP_RESPONSE: 736 sk->sk_state = SS_CONNECTED; 737 sk->sk_socket->state = SS_CONNECTED; 738 vsock_insert_connected(vsk); 739 sk->sk_state_change(sk); 740 break; 741 case VIRTIO_VSOCK_OP_INVALID: 742 break; 743 case VIRTIO_VSOCK_OP_RST: 744 skerr = ECONNRESET; 745 err = 0; 746 goto destroy; 747 default: 748 skerr = EPROTO; 749 err = -EINVAL; 750 goto destroy; 751 } 752 return 0; 753 754 destroy: 755 virtio_transport_reset(vsk, pkt); 756 sk->sk_state = SS_UNCONNECTED; 757 sk->sk_err = skerr; 758 sk->sk_error_report(sk); 759 return err; 760 } 761 762 static int 763 virtio_transport_recv_connected(struct sock *sk, 764 struct virtio_vsock_pkt *pkt) 765 { 766 struct vsock_sock *vsk = vsock_sk(sk); 767 struct virtio_vsock_sock *vvs = vsk->trans; 768 int err = 0; 769 770 switch (le16_to_cpu(pkt->hdr.op)) { 771 case VIRTIO_VSOCK_OP_RW: 772 pkt->len = le32_to_cpu(pkt->hdr.len); 773 pkt->off = 0; 774 775 spin_lock_bh(&vvs->rx_lock); 776 virtio_transport_inc_rx_pkt(vvs, pkt); 777 list_add_tail(&pkt->list, &vvs->rx_queue); 778 spin_unlock_bh(&vvs->rx_lock); 779 780 sk->sk_data_ready(sk); 781 return err; 782 case VIRTIO_VSOCK_OP_CREDIT_UPDATE: 783 sk->sk_write_space(sk); 784 break; 785 case VIRTIO_VSOCK_OP_SHUTDOWN: 786 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV) 787 vsk->peer_shutdown |= RCV_SHUTDOWN; 788 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND) 789 vsk->peer_shutdown |= SEND_SHUTDOWN; 790 if (vsk->peer_shutdown == SHUTDOWN_MASK && 791 vsock_stream_has_data(vsk) <= 0) 792 sk->sk_state = SS_DISCONNECTING; 793 if (le32_to_cpu(pkt->hdr.flags)) 794 sk->sk_state_change(sk); 795 break; 796 case VIRTIO_VSOCK_OP_RST: 797 virtio_transport_do_close(vsk, true); 798 break; 799 default: 800 err = -EINVAL; 801 break; 802 } 803 804 virtio_transport_free_pkt(pkt); 805 return err; 806 } 807 808 static void 809 virtio_transport_recv_disconnecting(struct sock *sk, 810 struct virtio_vsock_pkt *pkt) 811 { 812 struct vsock_sock *vsk = vsock_sk(sk); 813 814 if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) 815 virtio_transport_do_close(vsk, true); 816 } 817 818 static int 819 virtio_transport_send_response(struct vsock_sock *vsk, 820 struct virtio_vsock_pkt *pkt) 821 { 822 struct virtio_vsock_pkt_info info = { 823 .op = VIRTIO_VSOCK_OP_RESPONSE, 824 .type = VIRTIO_VSOCK_TYPE_STREAM, 825 .remote_cid = le64_to_cpu(pkt->hdr.src_cid), 826 .remote_port = le32_to_cpu(pkt->hdr.src_port), 827 .reply = true, 828 }; 829 830 return virtio_transport_send_pkt_info(vsk, &info); 831 } 832 833 /* Handle server socket */ 834 static int 835 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) 836 { 837 struct vsock_sock *vsk = vsock_sk(sk); 838 struct vsock_sock *vchild; 839 struct sock *child; 840 841 if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) { 842 virtio_transport_reset(vsk, pkt); 843 return -EINVAL; 844 } 845 846 if (sk_acceptq_is_full(sk)) { 847 virtio_transport_reset(vsk, pkt); 848 return -ENOMEM; 849 } 850 851 child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL, 852 sk->sk_type, 0); 853 if (!child) { 854 virtio_transport_reset(vsk, pkt); 855 return -ENOMEM; 856 } 857 858 sk->sk_ack_backlog++; 859 860 lock_sock_nested(child, SINGLE_DEPTH_NESTING); 861 862 child->sk_state = SS_CONNECTED; 863 864 vchild = vsock_sk(child); 865 vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid), 866 le32_to_cpu(pkt->hdr.dst_port)); 867 vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid), 868 le32_to_cpu(pkt->hdr.src_port)); 869 870 vsock_insert_connected(vchild); 871 vsock_enqueue_accept(sk, child); 872 virtio_transport_send_response(vchild, pkt); 873 874 release_sock(child); 875 876 sk->sk_data_ready(sk); 877 return 0; 878 } 879 880 static bool virtio_transport_space_update(struct sock *sk, 881 struct virtio_vsock_pkt *pkt) 882 { 883 struct vsock_sock *vsk = vsock_sk(sk); 884 struct virtio_vsock_sock *vvs = vsk->trans; 885 bool space_available; 886 887 /* buf_alloc and fwd_cnt is always included in the hdr */ 888 spin_lock_bh(&vvs->tx_lock); 889 vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc); 890 vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt); 891 space_available = virtio_transport_has_space(vsk); 892 spin_unlock_bh(&vvs->tx_lock); 893 return space_available; 894 } 895 896 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex 897 * lock. 898 */ 899 void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) 900 { 901 struct sockaddr_vm src, dst; 902 struct vsock_sock *vsk; 903 struct sock *sk; 904 bool space_available; 905 906 vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid), 907 le32_to_cpu(pkt->hdr.src_port)); 908 vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid), 909 le32_to_cpu(pkt->hdr.dst_port)); 910 911 trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port, 912 dst.svm_cid, dst.svm_port, 913 le32_to_cpu(pkt->hdr.len), 914 le16_to_cpu(pkt->hdr.type), 915 le16_to_cpu(pkt->hdr.op), 916 le32_to_cpu(pkt->hdr.flags), 917 le32_to_cpu(pkt->hdr.buf_alloc), 918 le32_to_cpu(pkt->hdr.fwd_cnt)); 919 920 if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) { 921 (void)virtio_transport_reset_no_sock(pkt); 922 goto free_pkt; 923 } 924 925 /* The socket must be in connected or bound table 926 * otherwise send reset back 927 */ 928 sk = vsock_find_connected_socket(&src, &dst); 929 if (!sk) { 930 sk = vsock_find_bound_socket(&dst); 931 if (!sk) { 932 (void)virtio_transport_reset_no_sock(pkt); 933 goto free_pkt; 934 } 935 } 936 937 vsk = vsock_sk(sk); 938 939 space_available = virtio_transport_space_update(sk, pkt); 940 941 lock_sock(sk); 942 943 /* Update CID in case it has changed after a transport reset event */ 944 vsk->local_addr.svm_cid = dst.svm_cid; 945 946 if (space_available) 947 sk->sk_write_space(sk); 948 949 switch (sk->sk_state) { 950 case VSOCK_SS_LISTEN: 951 virtio_transport_recv_listen(sk, pkt); 952 virtio_transport_free_pkt(pkt); 953 break; 954 case SS_CONNECTING: 955 virtio_transport_recv_connecting(sk, pkt); 956 virtio_transport_free_pkt(pkt); 957 break; 958 case SS_CONNECTED: 959 virtio_transport_recv_connected(sk, pkt); 960 break; 961 case SS_DISCONNECTING: 962 virtio_transport_recv_disconnecting(sk, pkt); 963 virtio_transport_free_pkt(pkt); 964 break; 965 default: 966 virtio_transport_free_pkt(pkt); 967 break; 968 } 969 release_sock(sk); 970 971 /* Release refcnt obtained when we fetched this socket out of the 972 * bound or connected list. 973 */ 974 sock_put(sk); 975 return; 976 977 free_pkt: 978 virtio_transport_free_pkt(pkt); 979 } 980 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt); 981 982 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt) 983 { 984 kfree(pkt->buf); 985 kfree(pkt); 986 } 987 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt); 988 989 MODULE_LICENSE("GPL v2"); 990 MODULE_AUTHOR("Asias He"); 991 MODULE_DESCRIPTION("common code for virtio vsock"); 992