1 /* 2 * common code for virtio vsock 3 * 4 * Copyright (C) 2013-2015 Red Hat, Inc. 5 * Author: Asias He <asias@redhat.com> 6 * Stefan Hajnoczi <stefanha@redhat.com> 7 * 8 * This work is licensed under the terms of the GNU GPL, version 2. 9 */ 10 #include <linux/spinlock.h> 11 #include <linux/module.h> 12 #include <linux/ctype.h> 13 #include <linux/list.h> 14 #include <linux/virtio.h> 15 #include <linux/virtio_ids.h> 16 #include <linux/virtio_config.h> 17 #include <linux/virtio_vsock.h> 18 19 #include <net/sock.h> 20 #include <net/af_vsock.h> 21 22 #define CREATE_TRACE_POINTS 23 #include <trace/events/vsock_virtio_transport_common.h> 24 25 /* How long to wait for graceful shutdown of a connection */ 26 #define VSOCK_CLOSE_TIMEOUT (8 * HZ) 27 28 static const struct virtio_transport *virtio_transport_get_ops(void) 29 { 30 const struct vsock_transport *t = vsock_core_get_transport(); 31 32 return container_of(t, struct virtio_transport, transport); 33 } 34 35 struct virtio_vsock_pkt * 36 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info, 37 size_t len, 38 u32 src_cid, 39 u32 src_port, 40 u32 dst_cid, 41 u32 dst_port) 42 { 43 struct virtio_vsock_pkt *pkt; 44 int err; 45 46 pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); 47 if (!pkt) 48 return NULL; 49 50 pkt->hdr.type = cpu_to_le16(info->type); 51 pkt->hdr.op = cpu_to_le16(info->op); 52 pkt->hdr.src_cid = cpu_to_le64(src_cid); 53 pkt->hdr.dst_cid = cpu_to_le64(dst_cid); 54 pkt->hdr.src_port = cpu_to_le32(src_port); 55 pkt->hdr.dst_port = cpu_to_le32(dst_port); 56 pkt->hdr.flags = cpu_to_le32(info->flags); 57 pkt->len = len; 58 pkt->hdr.len = cpu_to_le32(len); 59 pkt->reply = info->reply; 60 61 if (info->msg && len > 0) { 62 pkt->buf = kmalloc(len, GFP_KERNEL); 63 if (!pkt->buf) 64 goto out_pkt; 65 err = memcpy_from_msg(pkt->buf, info->msg, len); 66 if (err) 67 goto out; 68 } 69 70 trace_virtio_transport_alloc_pkt(src_cid, src_port, 71 dst_cid, dst_port, 72 len, 73 info->type, 74 info->op, 75 info->flags); 76 77 return pkt; 78 79 out: 80 kfree(pkt->buf); 81 out_pkt: 82 kfree(pkt); 83 return NULL; 84 } 85 EXPORT_SYMBOL_GPL(virtio_transport_alloc_pkt); 86 87 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, 88 struct virtio_vsock_pkt_info *info) 89 { 90 u32 src_cid, src_port, dst_cid, dst_port; 91 struct virtio_vsock_sock *vvs; 92 struct virtio_vsock_pkt *pkt; 93 u32 pkt_len = info->pkt_len; 94 95 src_cid = vm_sockets_get_local_cid(); 96 src_port = vsk->local_addr.svm_port; 97 if (!info->remote_cid) { 98 dst_cid = vsk->remote_addr.svm_cid; 99 dst_port = vsk->remote_addr.svm_port; 100 } else { 101 dst_cid = info->remote_cid; 102 dst_port = info->remote_port; 103 } 104 105 vvs = vsk->trans; 106 107 /* we can send less than pkt_len bytes */ 108 if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE) 109 pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE; 110 111 /* virtio_transport_get_credit might return less than pkt_len credit */ 112 pkt_len = virtio_transport_get_credit(vvs, pkt_len); 113 114 /* Do not send zero length OP_RW pkt */ 115 if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW) 116 return pkt_len; 117 118 pkt = virtio_transport_alloc_pkt(info, pkt_len, 119 src_cid, src_port, 120 dst_cid, dst_port); 121 if (!pkt) { 122 virtio_transport_put_credit(vvs, pkt_len); 123 return -ENOMEM; 124 } 125 126 virtio_transport_inc_tx_pkt(vvs, pkt); 127 128 return virtio_transport_get_ops()->send_pkt(pkt); 129 } 130 131 static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, 132 struct virtio_vsock_pkt *pkt) 133 { 134 vvs->rx_bytes += pkt->len; 135 } 136 137 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs, 138 struct virtio_vsock_pkt *pkt) 139 { 140 vvs->rx_bytes -= pkt->len; 141 vvs->fwd_cnt += pkt->len; 142 } 143 144 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt) 145 { 146 spin_lock_bh(&vvs->tx_lock); 147 pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt); 148 pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc); 149 spin_unlock_bh(&vvs->tx_lock); 150 } 151 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt); 152 153 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit) 154 { 155 u32 ret; 156 157 spin_lock_bh(&vvs->tx_lock); 158 ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); 159 if (ret > credit) 160 ret = credit; 161 vvs->tx_cnt += ret; 162 spin_unlock_bh(&vvs->tx_lock); 163 164 return ret; 165 } 166 EXPORT_SYMBOL_GPL(virtio_transport_get_credit); 167 168 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit) 169 { 170 spin_lock_bh(&vvs->tx_lock); 171 vvs->tx_cnt -= credit; 172 spin_unlock_bh(&vvs->tx_lock); 173 } 174 EXPORT_SYMBOL_GPL(virtio_transport_put_credit); 175 176 static int virtio_transport_send_credit_update(struct vsock_sock *vsk, 177 int type, 178 struct virtio_vsock_hdr *hdr) 179 { 180 struct virtio_vsock_pkt_info info = { 181 .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE, 182 .type = type, 183 }; 184 185 return virtio_transport_send_pkt_info(vsk, &info); 186 } 187 188 static ssize_t 189 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, 190 struct msghdr *msg, 191 size_t len) 192 { 193 struct virtio_vsock_sock *vvs = vsk->trans; 194 struct virtio_vsock_pkt *pkt; 195 size_t bytes, total = 0; 196 int err = -EFAULT; 197 198 spin_lock_bh(&vvs->rx_lock); 199 while (total < len && !list_empty(&vvs->rx_queue)) { 200 pkt = list_first_entry(&vvs->rx_queue, 201 struct virtio_vsock_pkt, list); 202 203 bytes = len - total; 204 if (bytes > pkt->len - pkt->off) 205 bytes = pkt->len - pkt->off; 206 207 /* sk_lock is held by caller so no one else can dequeue. 208 * Unlock rx_lock since memcpy_to_msg() may sleep. 209 */ 210 spin_unlock_bh(&vvs->rx_lock); 211 212 err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes); 213 if (err) 214 goto out; 215 216 spin_lock_bh(&vvs->rx_lock); 217 218 total += bytes; 219 pkt->off += bytes; 220 if (pkt->off == pkt->len) { 221 virtio_transport_dec_rx_pkt(vvs, pkt); 222 list_del(&pkt->list); 223 virtio_transport_free_pkt(pkt); 224 } 225 } 226 spin_unlock_bh(&vvs->rx_lock); 227 228 /* Send a credit pkt to peer */ 229 virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM, 230 NULL); 231 232 return total; 233 234 out: 235 if (total) 236 err = total; 237 return err; 238 } 239 240 ssize_t 241 virtio_transport_stream_dequeue(struct vsock_sock *vsk, 242 struct msghdr *msg, 243 size_t len, int flags) 244 { 245 if (flags & MSG_PEEK) 246 return -EOPNOTSUPP; 247 248 return virtio_transport_stream_do_dequeue(vsk, msg, len); 249 } 250 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue); 251 252 int 253 virtio_transport_dgram_dequeue(struct vsock_sock *vsk, 254 struct msghdr *msg, 255 size_t len, int flags) 256 { 257 return -EOPNOTSUPP; 258 } 259 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue); 260 261 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk) 262 { 263 struct virtio_vsock_sock *vvs = vsk->trans; 264 s64 bytes; 265 266 spin_lock_bh(&vvs->rx_lock); 267 bytes = vvs->rx_bytes; 268 spin_unlock_bh(&vvs->rx_lock); 269 270 return bytes; 271 } 272 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data); 273 274 static s64 virtio_transport_has_space(struct vsock_sock *vsk) 275 { 276 struct virtio_vsock_sock *vvs = vsk->trans; 277 s64 bytes; 278 279 bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); 280 if (bytes < 0) 281 bytes = 0; 282 283 return bytes; 284 } 285 286 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk) 287 { 288 struct virtio_vsock_sock *vvs = vsk->trans; 289 s64 bytes; 290 291 spin_lock_bh(&vvs->tx_lock); 292 bytes = virtio_transport_has_space(vsk); 293 spin_unlock_bh(&vvs->tx_lock); 294 295 return bytes; 296 } 297 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space); 298 299 int virtio_transport_do_socket_init(struct vsock_sock *vsk, 300 struct vsock_sock *psk) 301 { 302 struct virtio_vsock_sock *vvs; 303 304 vvs = kzalloc(sizeof(*vvs), GFP_KERNEL); 305 if (!vvs) 306 return -ENOMEM; 307 308 vsk->trans = vvs; 309 vvs->vsk = vsk; 310 if (psk) { 311 struct virtio_vsock_sock *ptrans = psk->trans; 312 313 vvs->buf_size = ptrans->buf_size; 314 vvs->buf_size_min = ptrans->buf_size_min; 315 vvs->buf_size_max = ptrans->buf_size_max; 316 vvs->peer_buf_alloc = ptrans->peer_buf_alloc; 317 } else { 318 vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE; 319 vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE; 320 vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE; 321 } 322 323 vvs->buf_alloc = vvs->buf_size; 324 325 spin_lock_init(&vvs->rx_lock); 326 spin_lock_init(&vvs->tx_lock); 327 INIT_LIST_HEAD(&vvs->rx_queue); 328 329 return 0; 330 } 331 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init); 332 333 u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk) 334 { 335 struct virtio_vsock_sock *vvs = vsk->trans; 336 337 return vvs->buf_size; 338 } 339 EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size); 340 341 u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk) 342 { 343 struct virtio_vsock_sock *vvs = vsk->trans; 344 345 return vvs->buf_size_min; 346 } 347 EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size); 348 349 u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk) 350 { 351 struct virtio_vsock_sock *vvs = vsk->trans; 352 353 return vvs->buf_size_max; 354 } 355 EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size); 356 357 void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val) 358 { 359 struct virtio_vsock_sock *vvs = vsk->trans; 360 361 if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) 362 val = VIRTIO_VSOCK_MAX_BUF_SIZE; 363 if (val < vvs->buf_size_min) 364 vvs->buf_size_min = val; 365 if (val > vvs->buf_size_max) 366 vvs->buf_size_max = val; 367 vvs->buf_size = val; 368 vvs->buf_alloc = val; 369 } 370 EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size); 371 372 void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val) 373 { 374 struct virtio_vsock_sock *vvs = vsk->trans; 375 376 if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) 377 val = VIRTIO_VSOCK_MAX_BUF_SIZE; 378 if (val > vvs->buf_size) 379 vvs->buf_size = val; 380 vvs->buf_size_min = val; 381 } 382 EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size); 383 384 void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val) 385 { 386 struct virtio_vsock_sock *vvs = vsk->trans; 387 388 if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) 389 val = VIRTIO_VSOCK_MAX_BUF_SIZE; 390 if (val < vvs->buf_size) 391 vvs->buf_size = val; 392 vvs->buf_size_max = val; 393 } 394 EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size); 395 396 int 397 virtio_transport_notify_poll_in(struct vsock_sock *vsk, 398 size_t target, 399 bool *data_ready_now) 400 { 401 if (vsock_stream_has_data(vsk)) 402 *data_ready_now = true; 403 else 404 *data_ready_now = false; 405 406 return 0; 407 } 408 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in); 409 410 int 411 virtio_transport_notify_poll_out(struct vsock_sock *vsk, 412 size_t target, 413 bool *space_avail_now) 414 { 415 s64 free_space; 416 417 free_space = vsock_stream_has_space(vsk); 418 if (free_space > 0) 419 *space_avail_now = true; 420 else if (free_space == 0) 421 *space_avail_now = false; 422 423 return 0; 424 } 425 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out); 426 427 int virtio_transport_notify_recv_init(struct vsock_sock *vsk, 428 size_t target, struct vsock_transport_recv_notify_data *data) 429 { 430 return 0; 431 } 432 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init); 433 434 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk, 435 size_t target, struct vsock_transport_recv_notify_data *data) 436 { 437 return 0; 438 } 439 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block); 440 441 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk, 442 size_t target, struct vsock_transport_recv_notify_data *data) 443 { 444 return 0; 445 } 446 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue); 447 448 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk, 449 size_t target, ssize_t copied, bool data_read, 450 struct vsock_transport_recv_notify_data *data) 451 { 452 return 0; 453 } 454 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue); 455 456 int virtio_transport_notify_send_init(struct vsock_sock *vsk, 457 struct vsock_transport_send_notify_data *data) 458 { 459 return 0; 460 } 461 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init); 462 463 int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk, 464 struct vsock_transport_send_notify_data *data) 465 { 466 return 0; 467 } 468 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block); 469 470 int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk, 471 struct vsock_transport_send_notify_data *data) 472 { 473 return 0; 474 } 475 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue); 476 477 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk, 478 ssize_t written, struct vsock_transport_send_notify_data *data) 479 { 480 return 0; 481 } 482 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue); 483 484 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk) 485 { 486 struct virtio_vsock_sock *vvs = vsk->trans; 487 488 return vvs->buf_size; 489 } 490 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat); 491 492 bool virtio_transport_stream_is_active(struct vsock_sock *vsk) 493 { 494 return true; 495 } 496 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active); 497 498 bool virtio_transport_stream_allow(u32 cid, u32 port) 499 { 500 return true; 501 } 502 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow); 503 504 int virtio_transport_dgram_bind(struct vsock_sock *vsk, 505 struct sockaddr_vm *addr) 506 { 507 return -EOPNOTSUPP; 508 } 509 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind); 510 511 bool virtio_transport_dgram_allow(u32 cid, u32 port) 512 { 513 return false; 514 } 515 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow); 516 517 int virtio_transport_connect(struct vsock_sock *vsk) 518 { 519 struct virtio_vsock_pkt_info info = { 520 .op = VIRTIO_VSOCK_OP_REQUEST, 521 .type = VIRTIO_VSOCK_TYPE_STREAM, 522 }; 523 524 return virtio_transport_send_pkt_info(vsk, &info); 525 } 526 EXPORT_SYMBOL_GPL(virtio_transport_connect); 527 528 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode) 529 { 530 struct virtio_vsock_pkt_info info = { 531 .op = VIRTIO_VSOCK_OP_SHUTDOWN, 532 .type = VIRTIO_VSOCK_TYPE_STREAM, 533 .flags = (mode & RCV_SHUTDOWN ? 534 VIRTIO_VSOCK_SHUTDOWN_RCV : 0) | 535 (mode & SEND_SHUTDOWN ? 536 VIRTIO_VSOCK_SHUTDOWN_SEND : 0), 537 }; 538 539 return virtio_transport_send_pkt_info(vsk, &info); 540 } 541 EXPORT_SYMBOL_GPL(virtio_transport_shutdown); 542 543 int 544 virtio_transport_dgram_enqueue(struct vsock_sock *vsk, 545 struct sockaddr_vm *remote_addr, 546 struct msghdr *msg, 547 size_t dgram_len) 548 { 549 return -EOPNOTSUPP; 550 } 551 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue); 552 553 ssize_t 554 virtio_transport_stream_enqueue(struct vsock_sock *vsk, 555 struct msghdr *msg, 556 size_t len) 557 { 558 struct virtio_vsock_pkt_info info = { 559 .op = VIRTIO_VSOCK_OP_RW, 560 .type = VIRTIO_VSOCK_TYPE_STREAM, 561 .msg = msg, 562 .pkt_len = len, 563 }; 564 565 return virtio_transport_send_pkt_info(vsk, &info); 566 } 567 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue); 568 569 void virtio_transport_destruct(struct vsock_sock *vsk) 570 { 571 struct virtio_vsock_sock *vvs = vsk->trans; 572 573 kfree(vvs); 574 } 575 EXPORT_SYMBOL_GPL(virtio_transport_destruct); 576 577 static int virtio_transport_reset(struct vsock_sock *vsk, 578 struct virtio_vsock_pkt *pkt) 579 { 580 struct virtio_vsock_pkt_info info = { 581 .op = VIRTIO_VSOCK_OP_RST, 582 .type = VIRTIO_VSOCK_TYPE_STREAM, 583 .reply = !!pkt, 584 }; 585 586 /* Send RST only if the original pkt is not a RST pkt */ 587 if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) 588 return 0; 589 590 return virtio_transport_send_pkt_info(vsk, &info); 591 } 592 593 /* Normally packets are associated with a socket. There may be no socket if an 594 * attempt was made to connect to a socket that does not exist. 595 */ 596 static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt) 597 { 598 struct virtio_vsock_pkt_info info = { 599 .op = VIRTIO_VSOCK_OP_RST, 600 .type = le16_to_cpu(pkt->hdr.type), 601 .reply = true, 602 }; 603 604 /* Send RST only if the original pkt is not a RST pkt */ 605 if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) 606 return 0; 607 608 pkt = virtio_transport_alloc_pkt(&info, 0, 609 le32_to_cpu(pkt->hdr.dst_cid), 610 le32_to_cpu(pkt->hdr.dst_port), 611 le32_to_cpu(pkt->hdr.src_cid), 612 le32_to_cpu(pkt->hdr.src_port)); 613 if (!pkt) 614 return -ENOMEM; 615 616 return virtio_transport_get_ops()->send_pkt(pkt); 617 } 618 619 static void virtio_transport_wait_close(struct sock *sk, long timeout) 620 { 621 if (timeout) { 622 DEFINE_WAIT(wait); 623 624 do { 625 prepare_to_wait(sk_sleep(sk), &wait, 626 TASK_INTERRUPTIBLE); 627 if (sk_wait_event(sk, &timeout, 628 sock_flag(sk, SOCK_DONE))) 629 break; 630 } while (!signal_pending(current) && timeout); 631 632 finish_wait(sk_sleep(sk), &wait); 633 } 634 } 635 636 static void virtio_transport_do_close(struct vsock_sock *vsk, 637 bool cancel_timeout) 638 { 639 struct sock *sk = sk_vsock(vsk); 640 641 sock_set_flag(sk, SOCK_DONE); 642 vsk->peer_shutdown = SHUTDOWN_MASK; 643 if (vsock_stream_has_data(vsk) <= 0) 644 sk->sk_state = SS_DISCONNECTING; 645 sk->sk_state_change(sk); 646 647 if (vsk->close_work_scheduled && 648 (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) { 649 vsk->close_work_scheduled = false; 650 651 vsock_remove_sock(vsk); 652 653 /* Release refcnt obtained when we scheduled the timeout */ 654 sock_put(sk); 655 } 656 } 657 658 static void virtio_transport_close_timeout(struct work_struct *work) 659 { 660 struct vsock_sock *vsk = 661 container_of(work, struct vsock_sock, close_work.work); 662 struct sock *sk = sk_vsock(vsk); 663 664 sock_hold(sk); 665 lock_sock(sk); 666 667 if (!sock_flag(sk, SOCK_DONE)) { 668 (void)virtio_transport_reset(vsk, NULL); 669 670 virtio_transport_do_close(vsk, false); 671 } 672 673 vsk->close_work_scheduled = false; 674 675 release_sock(sk); 676 sock_put(sk); 677 } 678 679 /* User context, vsk->sk is locked */ 680 static bool virtio_transport_close(struct vsock_sock *vsk) 681 { 682 struct sock *sk = &vsk->sk; 683 684 if (!(sk->sk_state == SS_CONNECTED || 685 sk->sk_state == SS_DISCONNECTING)) 686 return true; 687 688 /* Already received SHUTDOWN from peer, reply with RST */ 689 if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) { 690 (void)virtio_transport_reset(vsk, NULL); 691 return true; 692 } 693 694 if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK) 695 (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK); 696 697 if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING)) 698 virtio_transport_wait_close(sk, sk->sk_lingertime); 699 700 if (sock_flag(sk, SOCK_DONE)) { 701 return true; 702 } 703 704 sock_hold(sk); 705 INIT_DELAYED_WORK(&vsk->close_work, 706 virtio_transport_close_timeout); 707 vsk->close_work_scheduled = true; 708 schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT); 709 return false; 710 } 711 712 void virtio_transport_release(struct vsock_sock *vsk) 713 { 714 struct sock *sk = &vsk->sk; 715 bool remove_sock = true; 716 717 lock_sock(sk); 718 if (sk->sk_type == SOCK_STREAM) 719 remove_sock = virtio_transport_close(vsk); 720 release_sock(sk); 721 722 if (remove_sock) 723 vsock_remove_sock(vsk); 724 } 725 EXPORT_SYMBOL_GPL(virtio_transport_release); 726 727 static int 728 virtio_transport_recv_connecting(struct sock *sk, 729 struct virtio_vsock_pkt *pkt) 730 { 731 struct vsock_sock *vsk = vsock_sk(sk); 732 int err; 733 int skerr; 734 735 switch (le16_to_cpu(pkt->hdr.op)) { 736 case VIRTIO_VSOCK_OP_RESPONSE: 737 sk->sk_state = SS_CONNECTED; 738 sk->sk_socket->state = SS_CONNECTED; 739 vsock_insert_connected(vsk); 740 sk->sk_state_change(sk); 741 break; 742 case VIRTIO_VSOCK_OP_INVALID: 743 break; 744 case VIRTIO_VSOCK_OP_RST: 745 skerr = ECONNRESET; 746 err = 0; 747 goto destroy; 748 default: 749 skerr = EPROTO; 750 err = -EINVAL; 751 goto destroy; 752 } 753 return 0; 754 755 destroy: 756 virtio_transport_reset(vsk, pkt); 757 sk->sk_state = SS_UNCONNECTED; 758 sk->sk_err = skerr; 759 sk->sk_error_report(sk); 760 return err; 761 } 762 763 static int 764 virtio_transport_recv_connected(struct sock *sk, 765 struct virtio_vsock_pkt *pkt) 766 { 767 struct vsock_sock *vsk = vsock_sk(sk); 768 struct virtio_vsock_sock *vvs = vsk->trans; 769 int err = 0; 770 771 switch (le16_to_cpu(pkt->hdr.op)) { 772 case VIRTIO_VSOCK_OP_RW: 773 pkt->len = le32_to_cpu(pkt->hdr.len); 774 pkt->off = 0; 775 776 spin_lock_bh(&vvs->rx_lock); 777 virtio_transport_inc_rx_pkt(vvs, pkt); 778 list_add_tail(&pkt->list, &vvs->rx_queue); 779 spin_unlock_bh(&vvs->rx_lock); 780 781 sk->sk_data_ready(sk); 782 return err; 783 case VIRTIO_VSOCK_OP_CREDIT_UPDATE: 784 sk->sk_write_space(sk); 785 break; 786 case VIRTIO_VSOCK_OP_SHUTDOWN: 787 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV) 788 vsk->peer_shutdown |= RCV_SHUTDOWN; 789 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND) 790 vsk->peer_shutdown |= SEND_SHUTDOWN; 791 if (vsk->peer_shutdown == SHUTDOWN_MASK && 792 vsock_stream_has_data(vsk) <= 0) 793 sk->sk_state = SS_DISCONNECTING; 794 if (le32_to_cpu(pkt->hdr.flags)) 795 sk->sk_state_change(sk); 796 break; 797 case VIRTIO_VSOCK_OP_RST: 798 virtio_transport_do_close(vsk, true); 799 break; 800 default: 801 err = -EINVAL; 802 break; 803 } 804 805 virtio_transport_free_pkt(pkt); 806 return err; 807 } 808 809 static void 810 virtio_transport_recv_disconnecting(struct sock *sk, 811 struct virtio_vsock_pkt *pkt) 812 { 813 struct vsock_sock *vsk = vsock_sk(sk); 814 815 if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) 816 virtio_transport_do_close(vsk, true); 817 } 818 819 static int 820 virtio_transport_send_response(struct vsock_sock *vsk, 821 struct virtio_vsock_pkt *pkt) 822 { 823 struct virtio_vsock_pkt_info info = { 824 .op = VIRTIO_VSOCK_OP_RESPONSE, 825 .type = VIRTIO_VSOCK_TYPE_STREAM, 826 .remote_cid = le32_to_cpu(pkt->hdr.src_cid), 827 .remote_port = le32_to_cpu(pkt->hdr.src_port), 828 .reply = true, 829 }; 830 831 return virtio_transport_send_pkt_info(vsk, &info); 832 } 833 834 /* Handle server socket */ 835 static int 836 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) 837 { 838 struct vsock_sock *vsk = vsock_sk(sk); 839 struct vsock_sock *vchild; 840 struct sock *child; 841 842 if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) { 843 virtio_transport_reset(vsk, pkt); 844 return -EINVAL; 845 } 846 847 if (sk_acceptq_is_full(sk)) { 848 virtio_transport_reset(vsk, pkt); 849 return -ENOMEM; 850 } 851 852 child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL, 853 sk->sk_type, 0); 854 if (!child) { 855 virtio_transport_reset(vsk, pkt); 856 return -ENOMEM; 857 } 858 859 sk->sk_ack_backlog++; 860 861 lock_sock_nested(child, SINGLE_DEPTH_NESTING); 862 863 child->sk_state = SS_CONNECTED; 864 865 vchild = vsock_sk(child); 866 vsock_addr_init(&vchild->local_addr, le32_to_cpu(pkt->hdr.dst_cid), 867 le32_to_cpu(pkt->hdr.dst_port)); 868 vsock_addr_init(&vchild->remote_addr, le32_to_cpu(pkt->hdr.src_cid), 869 le32_to_cpu(pkt->hdr.src_port)); 870 871 vsock_insert_connected(vchild); 872 vsock_enqueue_accept(sk, child); 873 virtio_transport_send_response(vchild, pkt); 874 875 release_sock(child); 876 877 sk->sk_data_ready(sk); 878 return 0; 879 } 880 881 static bool virtio_transport_space_update(struct sock *sk, 882 struct virtio_vsock_pkt *pkt) 883 { 884 struct vsock_sock *vsk = vsock_sk(sk); 885 struct virtio_vsock_sock *vvs = vsk->trans; 886 bool space_available; 887 888 /* buf_alloc and fwd_cnt is always included in the hdr */ 889 spin_lock_bh(&vvs->tx_lock); 890 vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc); 891 vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt); 892 space_available = virtio_transport_has_space(vsk); 893 spin_unlock_bh(&vvs->tx_lock); 894 return space_available; 895 } 896 897 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex 898 * lock. 899 */ 900 void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) 901 { 902 struct sockaddr_vm src, dst; 903 struct vsock_sock *vsk; 904 struct sock *sk; 905 bool space_available; 906 907 vsock_addr_init(&src, le32_to_cpu(pkt->hdr.src_cid), 908 le32_to_cpu(pkt->hdr.src_port)); 909 vsock_addr_init(&dst, le32_to_cpu(pkt->hdr.dst_cid), 910 le32_to_cpu(pkt->hdr.dst_port)); 911 912 trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port, 913 dst.svm_cid, dst.svm_port, 914 le32_to_cpu(pkt->hdr.len), 915 le16_to_cpu(pkt->hdr.type), 916 le16_to_cpu(pkt->hdr.op), 917 le32_to_cpu(pkt->hdr.flags), 918 le32_to_cpu(pkt->hdr.buf_alloc), 919 le32_to_cpu(pkt->hdr.fwd_cnt)); 920 921 if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) { 922 (void)virtio_transport_reset_no_sock(pkt); 923 goto free_pkt; 924 } 925 926 /* The socket must be in connected or bound table 927 * otherwise send reset back 928 */ 929 sk = vsock_find_connected_socket(&src, &dst); 930 if (!sk) { 931 sk = vsock_find_bound_socket(&dst); 932 if (!sk) { 933 (void)virtio_transport_reset_no_sock(pkt); 934 goto free_pkt; 935 } 936 } 937 938 vsk = vsock_sk(sk); 939 940 space_available = virtio_transport_space_update(sk, pkt); 941 942 lock_sock(sk); 943 944 /* Update CID in case it has changed after a transport reset event */ 945 vsk->local_addr.svm_cid = dst.svm_cid; 946 947 if (space_available) 948 sk->sk_write_space(sk); 949 950 switch (sk->sk_state) { 951 case VSOCK_SS_LISTEN: 952 virtio_transport_recv_listen(sk, pkt); 953 virtio_transport_free_pkt(pkt); 954 break; 955 case SS_CONNECTING: 956 virtio_transport_recv_connecting(sk, pkt); 957 virtio_transport_free_pkt(pkt); 958 break; 959 case SS_CONNECTED: 960 virtio_transport_recv_connected(sk, pkt); 961 break; 962 case SS_DISCONNECTING: 963 virtio_transport_recv_disconnecting(sk, pkt); 964 virtio_transport_free_pkt(pkt); 965 break; 966 default: 967 virtio_transport_free_pkt(pkt); 968 break; 969 } 970 release_sock(sk); 971 972 /* Release refcnt obtained when we fetched this socket out of the 973 * bound or connected list. 974 */ 975 sock_put(sk); 976 return; 977 978 free_pkt: 979 virtio_transport_free_pkt(pkt); 980 } 981 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt); 982 983 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt) 984 { 985 kfree(pkt->buf); 986 kfree(pkt); 987 } 988 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt); 989 990 MODULE_LICENSE("GPL v2"); 991 MODULE_AUTHOR("Asias He"); 992 MODULE_DESCRIPTION("common code for virtio vsock"); 993