1 /* 2 * common code for virtio vsock 3 * 4 * Copyright (C) 2013-2015 Red Hat, Inc. 5 * Author: Asias He <asias@redhat.com> 6 * Stefan Hajnoczi <stefanha@redhat.com> 7 * 8 * This work is licensed under the terms of the GNU GPL, version 2. 9 */ 10 #include <linux/spinlock.h> 11 #include <linux/module.h> 12 #include <linux/sched/signal.h> 13 #include <linux/ctype.h> 14 #include <linux/list.h> 15 #include <linux/virtio.h> 16 #include <linux/virtio_ids.h> 17 #include <linux/virtio_config.h> 18 #include <linux/virtio_vsock.h> 19 20 #include <net/sock.h> 21 #include <net/af_vsock.h> 22 23 #define CREATE_TRACE_POINTS 24 #include <trace/events/vsock_virtio_transport_common.h> 25 26 /* How long to wait for graceful shutdown of a connection */ 27 #define VSOCK_CLOSE_TIMEOUT (8 * HZ) 28 29 static const struct virtio_transport *virtio_transport_get_ops(void) 30 { 31 const struct vsock_transport *t = vsock_core_get_transport(); 32 33 return container_of(t, struct virtio_transport, transport); 34 } 35 36 static struct virtio_vsock_pkt * 37 virtio_transport_alloc_pkt(struct virtio_vsock_pkt_info *info, 38 size_t len, 39 u32 src_cid, 40 u32 src_port, 41 u32 dst_cid, 42 u32 dst_port) 43 { 44 struct virtio_vsock_pkt *pkt; 45 int err; 46 47 pkt = kzalloc(sizeof(*pkt), GFP_KERNEL); 48 if (!pkt) 49 return NULL; 50 51 pkt->hdr.type = cpu_to_le16(info->type); 52 pkt->hdr.op = cpu_to_le16(info->op); 53 pkt->hdr.src_cid = cpu_to_le64(src_cid); 54 pkt->hdr.dst_cid = cpu_to_le64(dst_cid); 55 pkt->hdr.src_port = cpu_to_le32(src_port); 56 pkt->hdr.dst_port = cpu_to_le32(dst_port); 57 pkt->hdr.flags = cpu_to_le32(info->flags); 58 pkt->len = len; 59 pkt->hdr.len = cpu_to_le32(len); 60 pkt->reply = info->reply; 61 pkt->vsk = info->vsk; 62 63 if (info->msg && len > 0) { 64 pkt->buf = kmalloc(len, GFP_KERNEL); 65 if (!pkt->buf) 66 goto out_pkt; 67 err = memcpy_from_msg(pkt->buf, info->msg, len); 68 if (err) 69 goto out; 70 } 71 72 trace_virtio_transport_alloc_pkt(src_cid, src_port, 73 dst_cid, dst_port, 74 len, 75 info->type, 76 info->op, 77 info->flags); 78 79 return pkt; 80 81 out: 82 kfree(pkt->buf); 83 out_pkt: 84 kfree(pkt); 85 return NULL; 86 } 87 88 static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, 89 struct virtio_vsock_pkt_info *info) 90 { 91 u32 src_cid, src_port, dst_cid, dst_port; 92 struct virtio_vsock_sock *vvs; 93 struct virtio_vsock_pkt *pkt; 94 u32 pkt_len = info->pkt_len; 95 96 src_cid = vm_sockets_get_local_cid(); 97 src_port = vsk->local_addr.svm_port; 98 if (!info->remote_cid) { 99 dst_cid = vsk->remote_addr.svm_cid; 100 dst_port = vsk->remote_addr.svm_port; 101 } else { 102 dst_cid = info->remote_cid; 103 dst_port = info->remote_port; 104 } 105 106 vvs = vsk->trans; 107 108 /* we can send less than pkt_len bytes */ 109 if (pkt_len > VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE) 110 pkt_len = VIRTIO_VSOCK_DEFAULT_RX_BUF_SIZE; 111 112 /* virtio_transport_get_credit might return less than pkt_len credit */ 113 pkt_len = virtio_transport_get_credit(vvs, pkt_len); 114 115 /* Do not send zero length OP_RW pkt */ 116 if (pkt_len == 0 && info->op == VIRTIO_VSOCK_OP_RW) 117 return pkt_len; 118 119 pkt = virtio_transport_alloc_pkt(info, pkt_len, 120 src_cid, src_port, 121 dst_cid, dst_port); 122 if (!pkt) { 123 virtio_transport_put_credit(vvs, pkt_len); 124 return -ENOMEM; 125 } 126 127 virtio_transport_inc_tx_pkt(vvs, pkt); 128 129 return virtio_transport_get_ops()->send_pkt(pkt); 130 } 131 132 static void virtio_transport_inc_rx_pkt(struct virtio_vsock_sock *vvs, 133 struct virtio_vsock_pkt *pkt) 134 { 135 vvs->rx_bytes += pkt->len; 136 } 137 138 static void virtio_transport_dec_rx_pkt(struct virtio_vsock_sock *vvs, 139 struct virtio_vsock_pkt *pkt) 140 { 141 vvs->rx_bytes -= pkt->len; 142 vvs->fwd_cnt += pkt->len; 143 } 144 145 void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt) 146 { 147 spin_lock_bh(&vvs->tx_lock); 148 pkt->hdr.fwd_cnt = cpu_to_le32(vvs->fwd_cnt); 149 pkt->hdr.buf_alloc = cpu_to_le32(vvs->buf_alloc); 150 spin_unlock_bh(&vvs->tx_lock); 151 } 152 EXPORT_SYMBOL_GPL(virtio_transport_inc_tx_pkt); 153 154 u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 credit) 155 { 156 u32 ret; 157 158 spin_lock_bh(&vvs->tx_lock); 159 ret = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); 160 if (ret > credit) 161 ret = credit; 162 vvs->tx_cnt += ret; 163 spin_unlock_bh(&vvs->tx_lock); 164 165 return ret; 166 } 167 EXPORT_SYMBOL_GPL(virtio_transport_get_credit); 168 169 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit) 170 { 171 spin_lock_bh(&vvs->tx_lock); 172 vvs->tx_cnt -= credit; 173 spin_unlock_bh(&vvs->tx_lock); 174 } 175 EXPORT_SYMBOL_GPL(virtio_transport_put_credit); 176 177 static int virtio_transport_send_credit_update(struct vsock_sock *vsk, 178 int type, 179 struct virtio_vsock_hdr *hdr) 180 { 181 struct virtio_vsock_pkt_info info = { 182 .op = VIRTIO_VSOCK_OP_CREDIT_UPDATE, 183 .type = type, 184 .vsk = vsk, 185 }; 186 187 return virtio_transport_send_pkt_info(vsk, &info); 188 } 189 190 static ssize_t 191 virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, 192 struct msghdr *msg, 193 size_t len) 194 { 195 struct virtio_vsock_sock *vvs = vsk->trans; 196 struct virtio_vsock_pkt *pkt; 197 size_t bytes, total = 0; 198 int err = -EFAULT; 199 200 spin_lock_bh(&vvs->rx_lock); 201 while (total < len && !list_empty(&vvs->rx_queue)) { 202 pkt = list_first_entry(&vvs->rx_queue, 203 struct virtio_vsock_pkt, list); 204 205 bytes = len - total; 206 if (bytes > pkt->len - pkt->off) 207 bytes = pkt->len - pkt->off; 208 209 /* sk_lock is held by caller so no one else can dequeue. 210 * Unlock rx_lock since memcpy_to_msg() may sleep. 211 */ 212 spin_unlock_bh(&vvs->rx_lock); 213 214 err = memcpy_to_msg(msg, pkt->buf + pkt->off, bytes); 215 if (err) 216 goto out; 217 218 spin_lock_bh(&vvs->rx_lock); 219 220 total += bytes; 221 pkt->off += bytes; 222 if (pkt->off == pkt->len) { 223 virtio_transport_dec_rx_pkt(vvs, pkt); 224 list_del(&pkt->list); 225 virtio_transport_free_pkt(pkt); 226 } 227 } 228 spin_unlock_bh(&vvs->rx_lock); 229 230 /* Send a credit pkt to peer */ 231 virtio_transport_send_credit_update(vsk, VIRTIO_VSOCK_TYPE_STREAM, 232 NULL); 233 234 return total; 235 236 out: 237 if (total) 238 err = total; 239 return err; 240 } 241 242 ssize_t 243 virtio_transport_stream_dequeue(struct vsock_sock *vsk, 244 struct msghdr *msg, 245 size_t len, int flags) 246 { 247 if (flags & MSG_PEEK) 248 return -EOPNOTSUPP; 249 250 return virtio_transport_stream_do_dequeue(vsk, msg, len); 251 } 252 EXPORT_SYMBOL_GPL(virtio_transport_stream_dequeue); 253 254 int 255 virtio_transport_dgram_dequeue(struct vsock_sock *vsk, 256 struct msghdr *msg, 257 size_t len, int flags) 258 { 259 return -EOPNOTSUPP; 260 } 261 EXPORT_SYMBOL_GPL(virtio_transport_dgram_dequeue); 262 263 s64 virtio_transport_stream_has_data(struct vsock_sock *vsk) 264 { 265 struct virtio_vsock_sock *vvs = vsk->trans; 266 s64 bytes; 267 268 spin_lock_bh(&vvs->rx_lock); 269 bytes = vvs->rx_bytes; 270 spin_unlock_bh(&vvs->rx_lock); 271 272 return bytes; 273 } 274 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_data); 275 276 static s64 virtio_transport_has_space(struct vsock_sock *vsk) 277 { 278 struct virtio_vsock_sock *vvs = vsk->trans; 279 s64 bytes; 280 281 bytes = vvs->peer_buf_alloc - (vvs->tx_cnt - vvs->peer_fwd_cnt); 282 if (bytes < 0) 283 bytes = 0; 284 285 return bytes; 286 } 287 288 s64 virtio_transport_stream_has_space(struct vsock_sock *vsk) 289 { 290 struct virtio_vsock_sock *vvs = vsk->trans; 291 s64 bytes; 292 293 spin_lock_bh(&vvs->tx_lock); 294 bytes = virtio_transport_has_space(vsk); 295 spin_unlock_bh(&vvs->tx_lock); 296 297 return bytes; 298 } 299 EXPORT_SYMBOL_GPL(virtio_transport_stream_has_space); 300 301 int virtio_transport_do_socket_init(struct vsock_sock *vsk, 302 struct vsock_sock *psk) 303 { 304 struct virtio_vsock_sock *vvs; 305 306 vvs = kzalloc(sizeof(*vvs), GFP_KERNEL); 307 if (!vvs) 308 return -ENOMEM; 309 310 vsk->trans = vvs; 311 vvs->vsk = vsk; 312 if (psk) { 313 struct virtio_vsock_sock *ptrans = psk->trans; 314 315 vvs->buf_size = ptrans->buf_size; 316 vvs->buf_size_min = ptrans->buf_size_min; 317 vvs->buf_size_max = ptrans->buf_size_max; 318 vvs->peer_buf_alloc = ptrans->peer_buf_alloc; 319 } else { 320 vvs->buf_size = VIRTIO_VSOCK_DEFAULT_BUF_SIZE; 321 vvs->buf_size_min = VIRTIO_VSOCK_DEFAULT_MIN_BUF_SIZE; 322 vvs->buf_size_max = VIRTIO_VSOCK_DEFAULT_MAX_BUF_SIZE; 323 } 324 325 vvs->buf_alloc = vvs->buf_size; 326 327 spin_lock_init(&vvs->rx_lock); 328 spin_lock_init(&vvs->tx_lock); 329 INIT_LIST_HEAD(&vvs->rx_queue); 330 331 return 0; 332 } 333 EXPORT_SYMBOL_GPL(virtio_transport_do_socket_init); 334 335 u64 virtio_transport_get_buffer_size(struct vsock_sock *vsk) 336 { 337 struct virtio_vsock_sock *vvs = vsk->trans; 338 339 return vvs->buf_size; 340 } 341 EXPORT_SYMBOL_GPL(virtio_transport_get_buffer_size); 342 343 u64 virtio_transport_get_min_buffer_size(struct vsock_sock *vsk) 344 { 345 struct virtio_vsock_sock *vvs = vsk->trans; 346 347 return vvs->buf_size_min; 348 } 349 EXPORT_SYMBOL_GPL(virtio_transport_get_min_buffer_size); 350 351 u64 virtio_transport_get_max_buffer_size(struct vsock_sock *vsk) 352 { 353 struct virtio_vsock_sock *vvs = vsk->trans; 354 355 return vvs->buf_size_max; 356 } 357 EXPORT_SYMBOL_GPL(virtio_transport_get_max_buffer_size); 358 359 void virtio_transport_set_buffer_size(struct vsock_sock *vsk, u64 val) 360 { 361 struct virtio_vsock_sock *vvs = vsk->trans; 362 363 if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) 364 val = VIRTIO_VSOCK_MAX_BUF_SIZE; 365 if (val < vvs->buf_size_min) 366 vvs->buf_size_min = val; 367 if (val > vvs->buf_size_max) 368 vvs->buf_size_max = val; 369 vvs->buf_size = val; 370 vvs->buf_alloc = val; 371 } 372 EXPORT_SYMBOL_GPL(virtio_transport_set_buffer_size); 373 374 void virtio_transport_set_min_buffer_size(struct vsock_sock *vsk, u64 val) 375 { 376 struct virtio_vsock_sock *vvs = vsk->trans; 377 378 if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) 379 val = VIRTIO_VSOCK_MAX_BUF_SIZE; 380 if (val > vvs->buf_size) 381 vvs->buf_size = val; 382 vvs->buf_size_min = val; 383 } 384 EXPORT_SYMBOL_GPL(virtio_transport_set_min_buffer_size); 385 386 void virtio_transport_set_max_buffer_size(struct vsock_sock *vsk, u64 val) 387 { 388 struct virtio_vsock_sock *vvs = vsk->trans; 389 390 if (val > VIRTIO_VSOCK_MAX_BUF_SIZE) 391 val = VIRTIO_VSOCK_MAX_BUF_SIZE; 392 if (val < vvs->buf_size) 393 vvs->buf_size = val; 394 vvs->buf_size_max = val; 395 } 396 EXPORT_SYMBOL_GPL(virtio_transport_set_max_buffer_size); 397 398 int 399 virtio_transport_notify_poll_in(struct vsock_sock *vsk, 400 size_t target, 401 bool *data_ready_now) 402 { 403 if (vsock_stream_has_data(vsk)) 404 *data_ready_now = true; 405 else 406 *data_ready_now = false; 407 408 return 0; 409 } 410 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_in); 411 412 int 413 virtio_transport_notify_poll_out(struct vsock_sock *vsk, 414 size_t target, 415 bool *space_avail_now) 416 { 417 s64 free_space; 418 419 free_space = vsock_stream_has_space(vsk); 420 if (free_space > 0) 421 *space_avail_now = true; 422 else if (free_space == 0) 423 *space_avail_now = false; 424 425 return 0; 426 } 427 EXPORT_SYMBOL_GPL(virtio_transport_notify_poll_out); 428 429 int virtio_transport_notify_recv_init(struct vsock_sock *vsk, 430 size_t target, struct vsock_transport_recv_notify_data *data) 431 { 432 return 0; 433 } 434 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_init); 435 436 int virtio_transport_notify_recv_pre_block(struct vsock_sock *vsk, 437 size_t target, struct vsock_transport_recv_notify_data *data) 438 { 439 return 0; 440 } 441 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_block); 442 443 int virtio_transport_notify_recv_pre_dequeue(struct vsock_sock *vsk, 444 size_t target, struct vsock_transport_recv_notify_data *data) 445 { 446 return 0; 447 } 448 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_pre_dequeue); 449 450 int virtio_transport_notify_recv_post_dequeue(struct vsock_sock *vsk, 451 size_t target, ssize_t copied, bool data_read, 452 struct vsock_transport_recv_notify_data *data) 453 { 454 return 0; 455 } 456 EXPORT_SYMBOL_GPL(virtio_transport_notify_recv_post_dequeue); 457 458 int virtio_transport_notify_send_init(struct vsock_sock *vsk, 459 struct vsock_transport_send_notify_data *data) 460 { 461 return 0; 462 } 463 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_init); 464 465 int virtio_transport_notify_send_pre_block(struct vsock_sock *vsk, 466 struct vsock_transport_send_notify_data *data) 467 { 468 return 0; 469 } 470 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_block); 471 472 int virtio_transport_notify_send_pre_enqueue(struct vsock_sock *vsk, 473 struct vsock_transport_send_notify_data *data) 474 { 475 return 0; 476 } 477 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_pre_enqueue); 478 479 int virtio_transport_notify_send_post_enqueue(struct vsock_sock *vsk, 480 ssize_t written, struct vsock_transport_send_notify_data *data) 481 { 482 return 0; 483 } 484 EXPORT_SYMBOL_GPL(virtio_transport_notify_send_post_enqueue); 485 486 u64 virtio_transport_stream_rcvhiwat(struct vsock_sock *vsk) 487 { 488 struct virtio_vsock_sock *vvs = vsk->trans; 489 490 return vvs->buf_size; 491 } 492 EXPORT_SYMBOL_GPL(virtio_transport_stream_rcvhiwat); 493 494 bool virtio_transport_stream_is_active(struct vsock_sock *vsk) 495 { 496 return true; 497 } 498 EXPORT_SYMBOL_GPL(virtio_transport_stream_is_active); 499 500 bool virtio_transport_stream_allow(u32 cid, u32 port) 501 { 502 return true; 503 } 504 EXPORT_SYMBOL_GPL(virtio_transport_stream_allow); 505 506 int virtio_transport_dgram_bind(struct vsock_sock *vsk, 507 struct sockaddr_vm *addr) 508 { 509 return -EOPNOTSUPP; 510 } 511 EXPORT_SYMBOL_GPL(virtio_transport_dgram_bind); 512 513 bool virtio_transport_dgram_allow(u32 cid, u32 port) 514 { 515 return false; 516 } 517 EXPORT_SYMBOL_GPL(virtio_transport_dgram_allow); 518 519 int virtio_transport_connect(struct vsock_sock *vsk) 520 { 521 struct virtio_vsock_pkt_info info = { 522 .op = VIRTIO_VSOCK_OP_REQUEST, 523 .type = VIRTIO_VSOCK_TYPE_STREAM, 524 .vsk = vsk, 525 }; 526 527 return virtio_transport_send_pkt_info(vsk, &info); 528 } 529 EXPORT_SYMBOL_GPL(virtio_transport_connect); 530 531 int virtio_transport_shutdown(struct vsock_sock *vsk, int mode) 532 { 533 struct virtio_vsock_pkt_info info = { 534 .op = VIRTIO_VSOCK_OP_SHUTDOWN, 535 .type = VIRTIO_VSOCK_TYPE_STREAM, 536 .flags = (mode & RCV_SHUTDOWN ? 537 VIRTIO_VSOCK_SHUTDOWN_RCV : 0) | 538 (mode & SEND_SHUTDOWN ? 539 VIRTIO_VSOCK_SHUTDOWN_SEND : 0), 540 .vsk = vsk, 541 }; 542 543 return virtio_transport_send_pkt_info(vsk, &info); 544 } 545 EXPORT_SYMBOL_GPL(virtio_transport_shutdown); 546 547 int 548 virtio_transport_dgram_enqueue(struct vsock_sock *vsk, 549 struct sockaddr_vm *remote_addr, 550 struct msghdr *msg, 551 size_t dgram_len) 552 { 553 return -EOPNOTSUPP; 554 } 555 EXPORT_SYMBOL_GPL(virtio_transport_dgram_enqueue); 556 557 ssize_t 558 virtio_transport_stream_enqueue(struct vsock_sock *vsk, 559 struct msghdr *msg, 560 size_t len) 561 { 562 struct virtio_vsock_pkt_info info = { 563 .op = VIRTIO_VSOCK_OP_RW, 564 .type = VIRTIO_VSOCK_TYPE_STREAM, 565 .msg = msg, 566 .pkt_len = len, 567 .vsk = vsk, 568 }; 569 570 return virtio_transport_send_pkt_info(vsk, &info); 571 } 572 EXPORT_SYMBOL_GPL(virtio_transport_stream_enqueue); 573 574 void virtio_transport_destruct(struct vsock_sock *vsk) 575 { 576 struct virtio_vsock_sock *vvs = vsk->trans; 577 578 kfree(vvs); 579 } 580 EXPORT_SYMBOL_GPL(virtio_transport_destruct); 581 582 static int virtio_transport_reset(struct vsock_sock *vsk, 583 struct virtio_vsock_pkt *pkt) 584 { 585 struct virtio_vsock_pkt_info info = { 586 .op = VIRTIO_VSOCK_OP_RST, 587 .type = VIRTIO_VSOCK_TYPE_STREAM, 588 .reply = !!pkt, 589 .vsk = vsk, 590 }; 591 592 /* Send RST only if the original pkt is not a RST pkt */ 593 if (pkt && le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) 594 return 0; 595 596 return virtio_transport_send_pkt_info(vsk, &info); 597 } 598 599 /* Normally packets are associated with a socket. There may be no socket if an 600 * attempt was made to connect to a socket that does not exist. 601 */ 602 static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt) 603 { 604 struct virtio_vsock_pkt_info info = { 605 .op = VIRTIO_VSOCK_OP_RST, 606 .type = le16_to_cpu(pkt->hdr.type), 607 .reply = true, 608 }; 609 610 /* Send RST only if the original pkt is not a RST pkt */ 611 if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) 612 return 0; 613 614 pkt = virtio_transport_alloc_pkt(&info, 0, 615 le64_to_cpu(pkt->hdr.dst_cid), 616 le32_to_cpu(pkt->hdr.dst_port), 617 le64_to_cpu(pkt->hdr.src_cid), 618 le32_to_cpu(pkt->hdr.src_port)); 619 if (!pkt) 620 return -ENOMEM; 621 622 return virtio_transport_get_ops()->send_pkt(pkt); 623 } 624 625 static void virtio_transport_wait_close(struct sock *sk, long timeout) 626 { 627 if (timeout) { 628 DEFINE_WAIT_FUNC(wait, woken_wake_function); 629 630 add_wait_queue(sk_sleep(sk), &wait); 631 632 do { 633 if (sk_wait_event(sk, &timeout, 634 sock_flag(sk, SOCK_DONE), &wait)) 635 break; 636 } while (!signal_pending(current) && timeout); 637 638 remove_wait_queue(sk_sleep(sk), &wait); 639 } 640 } 641 642 static void virtio_transport_do_close(struct vsock_sock *vsk, 643 bool cancel_timeout) 644 { 645 struct sock *sk = sk_vsock(vsk); 646 647 sock_set_flag(sk, SOCK_DONE); 648 vsk->peer_shutdown = SHUTDOWN_MASK; 649 if (vsock_stream_has_data(vsk) <= 0) 650 sk->sk_state = SS_DISCONNECTING; 651 sk->sk_state_change(sk); 652 653 if (vsk->close_work_scheduled && 654 (!cancel_timeout || cancel_delayed_work(&vsk->close_work))) { 655 vsk->close_work_scheduled = false; 656 657 vsock_remove_sock(vsk); 658 659 /* Release refcnt obtained when we scheduled the timeout */ 660 sock_put(sk); 661 } 662 } 663 664 static void virtio_transport_close_timeout(struct work_struct *work) 665 { 666 struct vsock_sock *vsk = 667 container_of(work, struct vsock_sock, close_work.work); 668 struct sock *sk = sk_vsock(vsk); 669 670 sock_hold(sk); 671 lock_sock(sk); 672 673 if (!sock_flag(sk, SOCK_DONE)) { 674 (void)virtio_transport_reset(vsk, NULL); 675 676 virtio_transport_do_close(vsk, false); 677 } 678 679 vsk->close_work_scheduled = false; 680 681 release_sock(sk); 682 sock_put(sk); 683 } 684 685 /* User context, vsk->sk is locked */ 686 static bool virtio_transport_close(struct vsock_sock *vsk) 687 { 688 struct sock *sk = &vsk->sk; 689 690 if (!(sk->sk_state == SS_CONNECTED || 691 sk->sk_state == SS_DISCONNECTING)) 692 return true; 693 694 /* Already received SHUTDOWN from peer, reply with RST */ 695 if ((vsk->peer_shutdown & SHUTDOWN_MASK) == SHUTDOWN_MASK) { 696 (void)virtio_transport_reset(vsk, NULL); 697 return true; 698 } 699 700 if ((sk->sk_shutdown & SHUTDOWN_MASK) != SHUTDOWN_MASK) 701 (void)virtio_transport_shutdown(vsk, SHUTDOWN_MASK); 702 703 if (sock_flag(sk, SOCK_LINGER) && !(current->flags & PF_EXITING)) 704 virtio_transport_wait_close(sk, sk->sk_lingertime); 705 706 if (sock_flag(sk, SOCK_DONE)) { 707 return true; 708 } 709 710 sock_hold(sk); 711 INIT_DELAYED_WORK(&vsk->close_work, 712 virtio_transport_close_timeout); 713 vsk->close_work_scheduled = true; 714 schedule_delayed_work(&vsk->close_work, VSOCK_CLOSE_TIMEOUT); 715 return false; 716 } 717 718 void virtio_transport_release(struct vsock_sock *vsk) 719 { 720 struct sock *sk = &vsk->sk; 721 bool remove_sock = true; 722 723 lock_sock(sk); 724 if (sk->sk_type == SOCK_STREAM) 725 remove_sock = virtio_transport_close(vsk); 726 release_sock(sk); 727 728 if (remove_sock) 729 vsock_remove_sock(vsk); 730 } 731 EXPORT_SYMBOL_GPL(virtio_transport_release); 732 733 static int 734 virtio_transport_recv_connecting(struct sock *sk, 735 struct virtio_vsock_pkt *pkt) 736 { 737 struct vsock_sock *vsk = vsock_sk(sk); 738 int err; 739 int skerr; 740 741 switch (le16_to_cpu(pkt->hdr.op)) { 742 case VIRTIO_VSOCK_OP_RESPONSE: 743 sk->sk_state = SS_CONNECTED; 744 sk->sk_socket->state = SS_CONNECTED; 745 vsock_insert_connected(vsk); 746 sk->sk_state_change(sk); 747 break; 748 case VIRTIO_VSOCK_OP_INVALID: 749 break; 750 case VIRTIO_VSOCK_OP_RST: 751 skerr = ECONNRESET; 752 err = 0; 753 goto destroy; 754 default: 755 skerr = EPROTO; 756 err = -EINVAL; 757 goto destroy; 758 } 759 return 0; 760 761 destroy: 762 virtio_transport_reset(vsk, pkt); 763 sk->sk_state = SS_UNCONNECTED; 764 sk->sk_err = skerr; 765 sk->sk_error_report(sk); 766 return err; 767 } 768 769 static int 770 virtio_transport_recv_connected(struct sock *sk, 771 struct virtio_vsock_pkt *pkt) 772 { 773 struct vsock_sock *vsk = vsock_sk(sk); 774 struct virtio_vsock_sock *vvs = vsk->trans; 775 int err = 0; 776 777 switch (le16_to_cpu(pkt->hdr.op)) { 778 case VIRTIO_VSOCK_OP_RW: 779 pkt->len = le32_to_cpu(pkt->hdr.len); 780 pkt->off = 0; 781 782 spin_lock_bh(&vvs->rx_lock); 783 virtio_transport_inc_rx_pkt(vvs, pkt); 784 list_add_tail(&pkt->list, &vvs->rx_queue); 785 spin_unlock_bh(&vvs->rx_lock); 786 787 sk->sk_data_ready(sk); 788 return err; 789 case VIRTIO_VSOCK_OP_CREDIT_UPDATE: 790 sk->sk_write_space(sk); 791 break; 792 case VIRTIO_VSOCK_OP_SHUTDOWN: 793 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_RCV) 794 vsk->peer_shutdown |= RCV_SHUTDOWN; 795 if (le32_to_cpu(pkt->hdr.flags) & VIRTIO_VSOCK_SHUTDOWN_SEND) 796 vsk->peer_shutdown |= SEND_SHUTDOWN; 797 if (vsk->peer_shutdown == SHUTDOWN_MASK && 798 vsock_stream_has_data(vsk) <= 0) 799 sk->sk_state = SS_DISCONNECTING; 800 if (le32_to_cpu(pkt->hdr.flags)) 801 sk->sk_state_change(sk); 802 break; 803 case VIRTIO_VSOCK_OP_RST: 804 virtio_transport_do_close(vsk, true); 805 break; 806 default: 807 err = -EINVAL; 808 break; 809 } 810 811 virtio_transport_free_pkt(pkt); 812 return err; 813 } 814 815 static void 816 virtio_transport_recv_disconnecting(struct sock *sk, 817 struct virtio_vsock_pkt *pkt) 818 { 819 struct vsock_sock *vsk = vsock_sk(sk); 820 821 if (le16_to_cpu(pkt->hdr.op) == VIRTIO_VSOCK_OP_RST) 822 virtio_transport_do_close(vsk, true); 823 } 824 825 static int 826 virtio_transport_send_response(struct vsock_sock *vsk, 827 struct virtio_vsock_pkt *pkt) 828 { 829 struct virtio_vsock_pkt_info info = { 830 .op = VIRTIO_VSOCK_OP_RESPONSE, 831 .type = VIRTIO_VSOCK_TYPE_STREAM, 832 .remote_cid = le64_to_cpu(pkt->hdr.src_cid), 833 .remote_port = le32_to_cpu(pkt->hdr.src_port), 834 .reply = true, 835 .vsk = vsk, 836 }; 837 838 return virtio_transport_send_pkt_info(vsk, &info); 839 } 840 841 /* Handle server socket */ 842 static int 843 virtio_transport_recv_listen(struct sock *sk, struct virtio_vsock_pkt *pkt) 844 { 845 struct vsock_sock *vsk = vsock_sk(sk); 846 struct vsock_sock *vchild; 847 struct sock *child; 848 849 if (le16_to_cpu(pkt->hdr.op) != VIRTIO_VSOCK_OP_REQUEST) { 850 virtio_transport_reset(vsk, pkt); 851 return -EINVAL; 852 } 853 854 if (sk_acceptq_is_full(sk)) { 855 virtio_transport_reset(vsk, pkt); 856 return -ENOMEM; 857 } 858 859 child = __vsock_create(sock_net(sk), NULL, sk, GFP_KERNEL, 860 sk->sk_type, 0); 861 if (!child) { 862 virtio_transport_reset(vsk, pkt); 863 return -ENOMEM; 864 } 865 866 sk->sk_ack_backlog++; 867 868 lock_sock_nested(child, SINGLE_DEPTH_NESTING); 869 870 child->sk_state = SS_CONNECTED; 871 872 vchild = vsock_sk(child); 873 vsock_addr_init(&vchild->local_addr, le64_to_cpu(pkt->hdr.dst_cid), 874 le32_to_cpu(pkt->hdr.dst_port)); 875 vsock_addr_init(&vchild->remote_addr, le64_to_cpu(pkt->hdr.src_cid), 876 le32_to_cpu(pkt->hdr.src_port)); 877 878 vsock_insert_connected(vchild); 879 vsock_enqueue_accept(sk, child); 880 virtio_transport_send_response(vchild, pkt); 881 882 release_sock(child); 883 884 sk->sk_data_ready(sk); 885 return 0; 886 } 887 888 static bool virtio_transport_space_update(struct sock *sk, 889 struct virtio_vsock_pkt *pkt) 890 { 891 struct vsock_sock *vsk = vsock_sk(sk); 892 struct virtio_vsock_sock *vvs = vsk->trans; 893 bool space_available; 894 895 /* buf_alloc and fwd_cnt is always included in the hdr */ 896 spin_lock_bh(&vvs->tx_lock); 897 vvs->peer_buf_alloc = le32_to_cpu(pkt->hdr.buf_alloc); 898 vvs->peer_fwd_cnt = le32_to_cpu(pkt->hdr.fwd_cnt); 899 space_available = virtio_transport_has_space(vsk); 900 spin_unlock_bh(&vvs->tx_lock); 901 return space_available; 902 } 903 904 /* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex 905 * lock. 906 */ 907 void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt) 908 { 909 struct sockaddr_vm src, dst; 910 struct vsock_sock *vsk; 911 struct sock *sk; 912 bool space_available; 913 914 vsock_addr_init(&src, le64_to_cpu(pkt->hdr.src_cid), 915 le32_to_cpu(pkt->hdr.src_port)); 916 vsock_addr_init(&dst, le64_to_cpu(pkt->hdr.dst_cid), 917 le32_to_cpu(pkt->hdr.dst_port)); 918 919 trace_virtio_transport_recv_pkt(src.svm_cid, src.svm_port, 920 dst.svm_cid, dst.svm_port, 921 le32_to_cpu(pkt->hdr.len), 922 le16_to_cpu(pkt->hdr.type), 923 le16_to_cpu(pkt->hdr.op), 924 le32_to_cpu(pkt->hdr.flags), 925 le32_to_cpu(pkt->hdr.buf_alloc), 926 le32_to_cpu(pkt->hdr.fwd_cnt)); 927 928 if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) { 929 (void)virtio_transport_reset_no_sock(pkt); 930 goto free_pkt; 931 } 932 933 /* The socket must be in connected or bound table 934 * otherwise send reset back 935 */ 936 sk = vsock_find_connected_socket(&src, &dst); 937 if (!sk) { 938 sk = vsock_find_bound_socket(&dst); 939 if (!sk) { 940 (void)virtio_transport_reset_no_sock(pkt); 941 goto free_pkt; 942 } 943 } 944 945 vsk = vsock_sk(sk); 946 947 space_available = virtio_transport_space_update(sk, pkt); 948 949 lock_sock(sk); 950 951 /* Update CID in case it has changed after a transport reset event */ 952 vsk->local_addr.svm_cid = dst.svm_cid; 953 954 if (space_available) 955 sk->sk_write_space(sk); 956 957 switch (sk->sk_state) { 958 case VSOCK_SS_LISTEN: 959 virtio_transport_recv_listen(sk, pkt); 960 virtio_transport_free_pkt(pkt); 961 break; 962 case SS_CONNECTING: 963 virtio_transport_recv_connecting(sk, pkt); 964 virtio_transport_free_pkt(pkt); 965 break; 966 case SS_CONNECTED: 967 virtio_transport_recv_connected(sk, pkt); 968 break; 969 case SS_DISCONNECTING: 970 virtio_transport_recv_disconnecting(sk, pkt); 971 virtio_transport_free_pkt(pkt); 972 break; 973 default: 974 virtio_transport_free_pkt(pkt); 975 break; 976 } 977 release_sock(sk); 978 979 /* Release refcnt obtained when we fetched this socket out of the 980 * bound or connected list. 981 */ 982 sock_put(sk); 983 return; 984 985 free_pkt: 986 virtio_transport_free_pkt(pkt); 987 } 988 EXPORT_SYMBOL_GPL(virtio_transport_recv_pkt); 989 990 void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt) 991 { 992 kfree(pkt->buf); 993 kfree(pkt); 994 } 995 EXPORT_SYMBOL_GPL(virtio_transport_free_pkt); 996 997 MODULE_LICENSE("GPL v2"); 998 MODULE_AUTHOR("Asias He"); 999 MODULE_DESCRIPTION("common code for virtio vsock"); 1000